From 54f2da51e05b66b02d0ad44299b8166150db1a98 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:01:14 -0700 Subject: [PATCH 01/30] misc: arc2face attempt 1 --- examples/text_to_image/train_text_to_image.py | 385 +++++++++++++++--- 1 file changed, 332 insertions(+), 53 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 4138d1b46329..d452d7346138 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -50,7 +50,6 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module - if is_wandb_available(): import wandb @@ -131,7 +130,13 @@ def save_model_card( inference=True, ) - tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"] + tags = [ + "stable-diffusion", + "stable-diffusion-diffusers", + "text-to-image", + "diffusers", + "diffusers-training", + ] model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(repo_folder, "README.md")) @@ -165,7 +170,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight images = [] for i in range(len(args.validation_prompts)): with torch.autocast("cuda"): - image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + image = pipeline( + args.validation_prompts[i], num_inference_steps=20, generator=generator + ).images[0] images.append(image) @@ -194,7 +201,10 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( - "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + "--input_perturbation", + type=float, + default=0, + help="The scale of input perturbation. Recommended 0.1.", ) parser.add_argument( "--pretrained_model_name_or_path", @@ -243,7 +253,10 @@ def parse_args(): ), ) parser.add_argument( - "--image_column", type=str, default="image", help="The column of the dataset containing an image." + "--image_column", + type=str, + default="image", + help="The column of the dataset containing an image.", ) parser.add_argument( "--caption_column", @@ -265,7 +278,9 @@ def parse_args(): type=str, default=None, nargs="+", - help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + help=( + "A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`." + ), ) parser.add_argument( "--output_dir", @@ -304,7 +319,10 @@ def parse_args(): help="whether to randomly flip images horizontally", ) parser.add_argument( - "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + "--train_batch_size", + type=int, + default=16, + help="Batch size (per device) for the training dataloader.", ) parser.add_argument("--num_train_epochs", type=int, default=100) parser.add_argument( @@ -346,7 +364,10 @@ def parse_args(): ), ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( "--snr_gamma", @@ -356,7 +377,9 @@ def parse_args(): "More details here: https://arxiv.org/abs/2303.09556.", ) parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", ) parser.add_argument( "--allow_tf32", @@ -385,13 +408,39 @@ def parse_args(): "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) parser.add_argument( "--prediction_type", type=str, @@ -433,7 +482,12 @@ def parse_args(): ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) parser.add_argument( "--checkpointing_steps", type=int, @@ -459,7 +513,9 @@ def parse_args(): ), ) parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
+ "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", ) parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") parser.add_argument( @@ -494,6 +550,118 @@ def parse_args(): return args +from typing import Optional, Tuple, Union + +from transformers.models.clip.modeling_clip import ( + BaseModelOutputWithPooling, + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, +) + +SKS_TOKEN = 48136 + + +def text_encoder_fwd( + text_transformer, + input_ids: Optional[torch.Tensor] = None, + face_embeddings: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else text_transformer.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else text_transformer.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else text_transformer.config.use_return_dict + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + try: + sks_token_index = int(torch.where(input_ids == SKS_TOKEN)[1][0]) + except IndexError: + raise ValueError("your prompt doesn't have sks in it!!!!") + + hidden_states = text_transformer.embeddings(input_ids=input_ids, position_ids=position_ids) + + pad_left = 0 + pad_right = 768 - 512 + hidden_states[:, sks_token_index, :] = face_embeddings + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = text_transformer.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = text_transformer.final_layer_norm(last_hidden_state) + + if text_transformer.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. 
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + ( + input_ids.to(dtype=torch.int, device=last_hidden_state.device) + == text_transformer.eos_token_id + ) + .int() + .argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + def main(): args = parse_args() @@ -514,7 +682,9 @@ def main(): ) logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -550,20 +720,28 @@ def main(): if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, ).repo_id # Load scheduler, tokenizer and models. - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, ) def deepspeed_zero_init_disabled_context_manager(): """ returns either a context list that includes one that will disable zero.Init or an empty context list """ - deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + deepspeed_plugin = ( + AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + ) if deepspeed_plugin is None: return [] @@ -580,27 +758,52 @@ def deepspeed_zero_init_disabled_context_manager(): # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
with ContextManagers(deepspeed_zero_init_disabled_context_manager()): text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.non_ema_revision, + ) + + from insightface.app import FaceAnalysis + + app = FaceAnalysis( + name="buffalo_l", + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) + app.prepare(ctx_id=0, det_size=(640, 640)) - # Freeze vae and text_encoder and set unet to trainable + # Freeze vae and set unet and TE to trainable vae.requires_grad_(False) - text_encoder.requires_grad_(False) unet.train() + text_encoder.train() + text_encoder.requires_grad_(False) + text_encoder.text_model.encoder.requires_grad_(True) # Create EMA for the unet. if args.use_ema: ema_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant, + ) + ema_unet = EMAModel( + ema_unet.parameters(), + model_cls=UNet2DConditionModel, + model_config=ema_unet.config, ) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -631,7 +834,9 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + load_model = EMAModel.from_pretrained( + os.path.join(input_dir, "unet_ema"), UNet2DConditionModel + ) ema_unet.load_state_dict(load_model.state_dict()) ema_unet.to(accelerator.device) del load_model @@ -660,7 +865,10 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Initialize the optimizer @@ -677,7 +885,7 @@ def load_model_hook(models, input_dir): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - unet.parameters(), + unet.parameters() + text_encoder.text_model.encoder.parameters(), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -747,16 +955,30 @@ def tokenize_captions(examples, is_train=True): f"Caption column `{caption_column}` should contain either strings or lists of strings." ) inputs = tokenizer( - captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + captions, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", ) return inputs.input_ids # Preprocessing the datasets. 
train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), - transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.Resize( + args.resolution, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(args.resolution) + if args.center_crop + else transforms.RandomCrop(args.resolution) + ), + ( + transforms.RandomHorizontalFlip() + if args.random_flip + else transforms.Lambda(lambda x: x) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -770,7 +992,9 @@ def preprocess_train(examples): with accelerator.main_process_first(): if args.max_train_samples is not None: - dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + dataset["train"] = ( + dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + ) # Set the training transforms train_dataset = dataset["train"].with_transform(preprocess_train) @@ -791,7 +1015,9 @@ def collate_fn(examples): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -826,7 +1052,9 @@ def collate_fn(examples): vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -846,13 +1074,17 @@ def unwrap_model(model): return model # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}" + ) logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 @@ -898,6 +1130,28 @@ def unwrap_model(model): train_loss = 0.0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): + with torch.no_grad(): + batch_size = batch["pixel_values"].shape[0] + face_embeddings = [] + for batch_index in range(batch_size): + faces = app.get(batch["pixel_values"][batch_index]) + try: + found_face = faces[0] + except IndexError: + break + + face_embedding = torch.from_numpy(found_face.normed_embedding).unsqueeze(0) + face_embedding = F.pad(face_embedding, (pad_left, pad_right)) + + face_embeddings.append(face_embedding) + + if len(face_embeddings) != batch_size: + print("Skipping batch due to no face found in one of the images") + continue + + # (bs, 1, 768) + face_embeddings = torch.cat(face_embeddings, dim=0) + # Convert images to latent space latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor @@ -907,13 +1161,19 @@ def unwrap_model(model): if args.noise_offset: # https://www.crosslabs.org//blog/diffusion-with-offset-noise noise += args.noise_offset * torch.randn( - (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + (latents.shape[0], latents.shape[1], 1, 1), + device=latents.device, ) if args.input_perturbation: new_noise = noise + args.input_perturbation * torch.randn_like(noise) bsz = latents.shape[0] # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint( + 0, + noise_scheduler.config.num_train_timesteps, + (bsz,), + device=latents.device, + ) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep @@ -924,7 +1184,11 @@ def unwrap_model(model): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] + encoder_hidden_states = text_encoder_fwd( + text_encoder.text_model, + batch["input_ids"], + face_embeddings=face_embeddings, + )[0] # Get the target for loss depending on the prediction type if args.prediction_type is not None: @@ -936,10 +1200,14 @@ def unwrap_model(model): elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) # Predict the noise residual and compute loss - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] + model_pred = unet( + noisy_latents, timesteps, encoder_hidden_states, return_dict=False + )[0] if args.snr_gamma is None: loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -948,9 +1216,9 @@ def unwrap_model(model): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. 
snr = compute_snr(noise_scheduler, timesteps) - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] + mse_loss_weights = torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": @@ -997,17 +1265,24 @@ def unwrap_model(model): logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) shutil.rmtree(removing_checkpoint) save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } progress_bar.set_postfix(**logs) if global_step >= args.max_train_steps: @@ -1068,7 +1343,11 @@ def unwrap_model(model): for i in range(len(args.validation_prompts)): with torch.autocast("cuda"): - image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + image = pipeline( + args.validation_prompts[i], + num_inference_steps=20, + generator=generator, + ).images[0] images.append(image) if args.push_to_hub: From a89e23decd36926c633d3eb8a1d5f79b753f9c94 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:08:44 -0700 Subject: [PATCH 02/30] fix param concat --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index d452d7346138..8f4951c060a1 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -885,7 +885,7 @@ def load_model_hook(models, input_dir): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - unet.parameters() + text_encoder.text_model.encoder.parameters(), + list(unet.parameters()) + list(text_encoder.text_model.encoder.parameters()), lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, From 2ad3e678cc6a1a18d8eb15baa4ce841f0cce9fb3 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:12:06 -0700 Subject: [PATCH 03/30] fixate prompt --- examples/text_to_image/train_text_to_image.py | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8f4951c060a1..8c2cc9a04806 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -931,15 +931,6 @@ def load_model_hook(models, input_dir): raise ValueError( f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" ) - if args.caption_column is None: - caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] - else: - caption_column = args.caption_column - if 
caption_column not in column_names: - raise ValueError( - f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" - ) - # Preprocessing the datasets. # We need to tokenize input captions and transform the images. def tokenize_captions(examples, is_train=True): @@ -987,7 +978,6 @@ def tokenize_captions(examples, is_train=True): def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] examples["pixel_values"] = [train_transforms(image) for image in images] - examples["input_ids"] = tokenize_captions(examples) return examples with accelerator.main_process_first(): @@ -1001,8 +991,7 @@ def preprocess_train(examples): def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = torch.stack([example["input_ids"] for example in examples]) - return {"pixel_values": pixel_values, "input_ids": input_ids} + return {"pixel_values": pixel_values} # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( @@ -1125,6 +1114,14 @@ def unwrap_model(model): # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) + with torch.no_grad(): + input_ids = tokenizer( + ["a photo of sks person"] * args.train_batch_size, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + )["input_ids"] for epoch in range(first_epoch, args.num_train_epochs): train_loss = 0.0 @@ -1141,7 +1138,7 @@ def unwrap_model(model): break face_embedding = torch.from_numpy(found_face.normed_embedding).unsqueeze(0) - face_embedding = F.pad(face_embedding, (pad_left, pad_right)) + face_embedding = F.pad(face_embedding, (0, 768 - 512)) face_embeddings.append(face_embedding) @@ -1186,7 +1183,7 @@ def unwrap_model(model): # Get the text embedding for conditioning encoder_hidden_states = text_encoder_fwd( text_encoder.text_model, - batch["input_ids"], + input_ids, face_embeddings=face_embeddings, )[0] From 646c1e413f479810e7da621f8f6d7c5b31cccbb0 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:14:40 -0700 Subject: [PATCH 04/30] face embedding to numpy --- examples/text_to_image/train_text_to_image.py | 143 +++++++++++++----- 1 file changed, 109 insertions(+), 34 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8c2cc9a04806..db348b0e11cf 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -42,10 +42,20 @@ from transformers.utils import ContextManagers import diffusers -from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, compute_snr -from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid +from diffusers.utils import ( + check_min_version, + deprecate, + is_wandb_available, + make_image_grid, +) from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.torch_utils import is_compiled_module @@ -142,7 +152,9 @@ def save_model_card( model_card.save(os.path.join(repo_folder, 
"README.md")) -def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): +def log_validation( + vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch +): logger.info("Running validation... ") pipeline = StableDiffusionPipeline.from_pretrained( @@ -179,7 +191,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + tracker.writer.add_images( + "validation", np_images, epoch, dataformats="NHWC" + ) elif tracker.name == "wandb": tracker.log( { @@ -294,7 +308,9 @@ def parse_args(): default=None, help="The directory where the downloaded models and datasets will be stored.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -389,7 +405,9 @@ def parse_args(): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) - parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--use_ema", action="store_true", help="Whether to use EMA model." + ) parser.add_argument( "--non_ema_revision", type=str, @@ -429,7 +447,9 @@ def parse_args(): default=1e-08, help="Epsilon value for the Adam optimizer", ) - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." + ) parser.add_argument( "--push_to_hub", action="store_true", @@ -517,7 +537,9 @@ def parse_args(): action="store_true", help="Whether or not to use xformers.", ) - parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--noise_offset", type=float, default=0, help="The scale of noise offset." 
+ ) parser.add_argument( "--validation_epochs", type=int, @@ -586,7 +608,9 @@ def text_encoder_fwd( else text_transformer.config.output_hidden_states ) return_dict = ( - return_dict if return_dict is not None else text_transformer.config.use_return_dict + return_dict + if return_dict is not None + else text_transformer.config.use_return_dict ) if input_ids is None: @@ -599,7 +623,9 @@ def text_encoder_fwd( except IndexError: raise ValueError("your prompt doesn't have sks in it!!!!") - hidden_states = text_transformer.embeddings(input_ids=input_ids, position_ids=position_ids) + hidden_states = text_transformer.embeddings( + input_ids=input_ids, position_ids=position_ids + ) pad_left = 0 pad_right = 768 - 512 @@ -636,7 +662,9 @@ def text_encoder_fwd( # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 pooled_output = last_hidden_state[ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), - input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( + dim=-1 + ), ] else: # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) @@ -740,7 +768,9 @@ def deepspeed_zero_init_disabled_context_manager(): returns either a context list that includes one that will disable zero.Init or an empty context list """ deepspeed_plugin = ( - AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + AcceleratorState().deepspeed_plugin + if accelerate.state.is_initialized() + else None ) if deepspeed_plugin is None: return [] @@ -816,7 +846,9 @@ def deepspeed_zero_init_disabled_context_manager(): ) unet.enable_xformers_memory_efficient_attention() else: - raise ValueError("xformers is not available. Make sure it is installed correctly") + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): @@ -846,7 +878,9 @@ def load_model_hook(models, input_dir): model = models.pop() # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + load_model = UNet2DConditionModel.from_pretrained( + input_dir, subfolder="unet" + ) model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) @@ -924,13 +958,16 @@ def load_model_hook(models, input_dir): # 6. Get the column names for input/target. dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) if args.image_column is None: - image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + image_column = ( + dataset_columns[0] if dataset_columns is not None else column_names[0] + ) else: image_column = args.image_column if image_column not in column_names: raise ValueError( f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" ) + # Preprocessing the datasets. # We need to tokenize input captions and transform the images. 
def tokenize_captions(examples, is_train=True): @@ -983,7 +1020,9 @@ def preprocess_train(examples): with accelerator.main_process_first(): if args.max_train_samples is not None: dataset["train"] = ( - dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + dataset["train"] + .shuffle(seed=args.seed) + .select(range(args.max_train_samples)) ) # Set the training transforms train_dataset = dataset["train"].with_transform(preprocess_train) @@ -1064,7 +1103,9 @@ def unwrap_model(model): # Train! total_batch_size = ( - args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps ) logger.info("***** Running training *****") @@ -1131,26 +1172,34 @@ def unwrap_model(model): batch_size = batch["pixel_values"].shape[0] face_embeddings = [] for batch_index in range(batch_size): - faces = app.get(batch["pixel_values"][batch_index]) + tensor_img = batch["pixel_values"][batch_index] + tensor_img.permute(1, 2, 0).detach().numpy() + + faces = app.get(tensor_img) try: found_face = faces[0] except IndexError: break - face_embedding = torch.from_numpy(found_face.normed_embedding).unsqueeze(0) + face_embedding = torch.from_numpy( + found_face.normed_embedding + ).unsqueeze(0) face_embedding = F.pad(face_embedding, (0, 768 - 512)) - face_embeddings.append(face_embedding) if len(face_embeddings) != batch_size: - print("Skipping batch due to no face found in one of the images") + print( + "Skipping batch due to no face found in one of the images" + ) continue # (bs, 1, 768) face_embeddings = torch.cat(face_embeddings, dim=0) # Convert images to latent space - latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() + latents = vae.encode( + batch["pixel_values"].to(weight_dtype) + ).latent_dist.sample() latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents @@ -1162,7 +1211,9 @@ def unwrap_model(model): device=latents.device, ) if args.input_perturbation: - new_noise = noise + args.input_perturbation * torch.randn_like(noise) + new_noise = noise + args.input_perturbation * torch.randn_like( + noise + ) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint( @@ -1176,7 +1227,9 @@ def unwrap_model(model): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.input_perturbation: - noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + noisy_latents = noise_scheduler.add_noise( + latents, new_noise, timesteps + ) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) @@ -1190,7 +1243,9 @@ def unwrap_model(model): # Get the target for loss depending on the prediction type if args.prediction_type is not None: # set prediction_type of scheduler if defined - noise_scheduler.register_to_config(prediction_type=args.prediction_type) + noise_scheduler.register_to_config( + prediction_type=args.prediction_type + ) if noise_scheduler.config.prediction_type == "epsilon": target = noise @@ -1207,7 +1262,9 @@ def unwrap_model(model): )[0] if args.snr_gamma is None: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 
# Since we predict the noise instead of x_0, the original formulation is slightly changed. @@ -1221,8 +1278,13 @@ def unwrap_model(model): elif noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = mse_loss_weights / (snr + 1) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="none" + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ) loss = loss.mean() # Gather the losses across all processes for logging (if we use distributed training). @@ -1251,12 +1313,18 @@ def unwrap_model(model): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( @@ -1272,7 +1340,9 @@ def unwrap_model(model): ) shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -1286,7 +1356,10 @@ def unwrap_model(model): break if accelerator.is_main_process: - if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if ( + args.validation_prompts is not None + and epoch % args.validation_epochs == 0 + ): if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
ema_unet.store(unet.parameters()) @@ -1336,7 +1409,9 @@ def unwrap_model(model): if args.seed is None: generator = None else: - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + generator = torch.Generator(device=accelerator.device).manual_seed( + args.seed + ) for i in range(len(args.validation_prompts)): with torch.autocast("cuda"): From f279478c69d762434379e7fc6390351bbc78d65d Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:15:26 -0700 Subject: [PATCH 05/30] move the tensor to host cpu --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index db348b0e11cf..7ee49e9e2d8e 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1173,7 +1173,7 @@ def unwrap_model(model): face_embeddings = [] for batch_index in range(batch_size): tensor_img = batch["pixel_values"][batch_index] - tensor_img.permute(1, 2, 0).detach().numpy() + tensor_img.cpu().permute(1, 2, 0).detach().numpy() faces = app.get(tensor_img) try: From 6e03b0e2365ee7e6451acb663e22d19d408a228c Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:16:27 -0700 Subject: [PATCH 06/30] more --- examples/text_to_image/train_text_to_image.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 7ee49e9e2d8e..93b6b65fe7cb 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1173,9 +1173,10 @@ def unwrap_model(model): face_embeddings = [] for batch_index in range(batch_size): tensor_img = batch["pixel_values"][batch_index] - tensor_img.cpu().permute(1, 2, 0).detach().numpy() + numpy_img = tensor_img.cpu().permute(1, 2, 0).detach().numpy() - faces = app.get(tensor_img) + faces = app.get(numpy_img) + print(f"detected {len(faces)} faces!") try: found_face = faces[0] except IndexError: From d6d5a519db3b9d09f3de1cbc6a4474926488260c Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:17:18 -0700 Subject: [PATCH 07/30] normalize the numbers --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 93b6b65fe7cb..88ee2f5c8935 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1173,7 +1173,7 @@ def unwrap_model(model): face_embeddings = [] for batch_index in range(batch_size): tensor_img = batch["pixel_values"][batch_index] - numpy_img = tensor_img.cpu().permute(1, 2, 0).detach().numpy() + numpy_img = (tensor_img * 255).cpu().permute(1, 2, 0).detach().numpy() faces = app.get(numpy_img) print(f"detected {len(faces)} faces!") From ed375a80847e18b54fc59c4e3c2ab6b4a79ceae3 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:18:17 -0700 Subject: [PATCH 08/30] move the face embeddings back to gpu --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 88ee2f5c8935..382da0b3c51b 100644 --- a/examples/text_to_image/train_text_to_image.py +++ 
b/examples/text_to_image/train_text_to_image.py @@ -1195,7 +1195,7 @@ def unwrap_model(model): continue # (bs, 1, 768) - face_embeddings = torch.cat(face_embeddings, dim=0) + face_embeddings = torch.cat(face_embeddings, dim=0).to("cuda") # Convert images to latent space latents = vae.encode( From 5a93f4bea8c8714b030f02df1a9db5b1b76376a6 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:20:35 -0700 Subject: [PATCH 09/30] move the inpuit ids to to cuda --- examples/text_to_image/train_text_to_image.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 382da0b3c51b..edb5b8707922 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1162,7 +1162,7 @@ def unwrap_model(model): padding="max_length", truncation=True, return_tensors="pt", - )["input_ids"] + )["input_ids"].to("cuda") for epoch in range(first_epoch, args.num_train_epochs): train_loss = 0.0 @@ -1173,7 +1173,9 @@ def unwrap_model(model): face_embeddings = [] for batch_index in range(batch_size): tensor_img = batch["pixel_values"][batch_index] - numpy_img = (tensor_img * 255).cpu().permute(1, 2, 0).detach().numpy() + numpy_img = ( + (tensor_img * 255).cpu().permute(1, 2, 0).detach().numpy() + ) faces = app.get(numpy_img) print(f"detected {len(faces)} faces!") @@ -1194,8 +1196,8 @@ def unwrap_model(model): ) continue - # (bs, 1, 768) - face_embeddings = torch.cat(face_embeddings, dim=0).to("cuda") + # (bs, 1, 768) + face_embeddings = torch.cat(face_embeddings, dim=0).to("cuda") # Convert images to latent space latents = vae.encode( From c660ab14cd7eebd62f4e0bd49a42955910b77c28 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:24:54 -0700 Subject: [PATCH 10/30] consolidate parameters --- examples/text_to_image/train_text_to_image.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index edb5b8707922..ef2ca5fe6e69 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -820,6 +820,7 @@ def deepspeed_zero_init_disabled_context_manager(): text_encoder.train() text_encoder.requires_grad_(False) text_encoder.text_model.encoder.requires_grad_(True) + training_parameters = list(unet.parameters()) + list(text_encoder.text_model.encoder.parameters()) # Create EMA for the unet. 
if args.use_ema: @@ -919,7 +920,7 @@ def load_model_hook(models, input_dir): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - list(unet.parameters()) + list(text_encoder.text_model.encoder.parameters()), + training_parameters, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -1297,7 +1298,7 @@ def unwrap_model(model): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + accelerator.clip_grad_norm_(training_parameters, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -1305,7 +1306,7 @@ def unwrap_model(model): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: - ema_unet.step(unet.parameters()) + ema_unet.step(training_parameters) progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) @@ -1365,8 +1366,8 @@ def unwrap_model(model): ): if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) + ema_unet.store(training_parameters) + ema_unet.copy_to(training_parameters) log_validation( vae, text_encoder, @@ -1379,14 +1380,14 @@ def unwrap_model(model): ) if args.use_ema: # Switch back to the original UNet parameters. - ema_unet.restore(unet.parameters()) + ema_unet.restore(training_parameters) # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unwrap_model(unet) if args.use_ema: - ema_unet.copy_to(unet.parameters()) + ema_unet.copy_to(training_parameters) pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, From c3c4a20bdad1a0703219e5c044b3a2bb6bd01a52 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:30:09 -0700 Subject: [PATCH 11/30] add validation steps --- examples/text_to_image/train_text_to_image.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index ef2ca5fe6e69..d29c3f8141cd 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -546,6 +546,12 @@ def parse_args(): default=5, help="Run validation every X epochs.", ) + parser.add_argument( + "--validation_steps", + type=int, + default=50, + help="Run validation every X steps.", + ) parser.add_argument( "--tracker_project_name", type=str, @@ -1362,7 +1368,7 @@ def unwrap_model(model): if accelerator.is_main_process: if ( args.validation_prompts is not None - and epoch % args.validation_epochs == 0 + and global_step % args.validation_steps == 0 ): if args.use_ema: # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
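A minimal, self-contained sketch of the idea the patches up to this point implement: detect a face with insightface, zero-pad its 512-d ArcFace embedding to the CLIP hidden size, and overwrite the token embedding at a placeholder word ("sks" at this stage of the series) before the prompt runs through the text encoder. The checkpoint id and image path below are illustrative assumptions, not values from the patches; the training script itself wraps this logic in text_encoder_fwd and, from the next patch on, in CLIPTextModelWrapper / project_face_embs.

import cv2
import torch
import torch.nn.functional as F
from insightface.app import FaceAnalysis
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"  # placeholder checkpoint, not taken from the patches
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# 1. Normalized 512-d ArcFace embedding of the identity image (hypothetical path).
app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))
faces = app.get(cv2.imread("identity.jpg"))
assert faces, "no face detected in the example image"
face_emb = torch.from_numpy(faces[0].normed_embedding)  # (512,)

# 2. Zero-pad to the text encoder's hidden size (768 for the SD 1.5 text encoder).
face_emb = F.pad(face_emb, (0, text_encoder.config.hidden_size - 512))

# 3. Tokenize a prompt that contains the placeholder word.
input_ids = tokenizer(
    "a photo of sks person",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids
placeholder_id = tokenizer.encode("sks", add_special_tokens=False)[0]  # 48136, the SKS_TOKEN constant in the script

# 4. Overwrite the placeholder's token embedding; the CLIP encoder then runs as usual
#    (the script does this inside its custom text-encoder forward pass).
with torch.no_grad():
    token_embs = text_encoder.text_model.embeddings.token_embedding(input_ids)
    token_embs[input_ids == placeholder_id] = face_emb

The next patch generalizes this by subclassing CLIPTextModel (CLIPTextModelWrapper), so the same injection can be reused unchanged at inference time through the pipeline's text_encoder.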
From 2c2b5d558a9c58722e174ea2c776867581df50c0 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:40:28 -0700 Subject: [PATCH 12/30] change the implementation --- examples/text_to_image/README.md | 16 ++ examples/text_to_image/train_text_to_image.py | 234 +++++++++--------- 2 files changed, 129 insertions(+), 121 deletions(-) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index f2931d3f347e..10180347435d 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -76,6 +76,22 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \ +accelerate launch --mixed_precision="bf16" train_text_to_image.py \ + --pretrained_model_name_or_path=SG161222/Realistic_Vision_V6.0_B1_noVAE \ + --dataset_name=nielsr/CelebA-faces \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="arc2face-attempt-1" \ + --validation_steps=50 \ + --validation_prompt="a person with a hat" \ + --report_to="wandb" + + To run on your own training files prepare the dataset according to the format required by `datasets`, you can find the instructions for how to do that in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata). If you wish to use custom loading logic, you should modify the script, we have left pointers for that in the training script. diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index d29c3f8141cd..a702123808a5 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -578,123 +578,116 @@ def parse_args(): return args -from typing import Optional, Tuple, Union +import torch +from transformers import CLIPTextModel +from typing import Any, Callable, Dict, Optional, Tuple, Union, List +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask + + +class CLIPTextModelWrapper(CLIPTextModel): + # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812 + # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them. 
+ def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + input_token_embs: Optional[torch.Tensor] = None, + return_token_embs: Optional[bool] = False, + ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]: + + if return_token_embs: + return self.text_model.embeddings.token_embedding(input_ids) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs) + + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.text_model.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) -from transformers.models.clip.modeling_clip import ( - BaseModelOutputWithPooling, - _create_4d_causal_attention_mask, - _prepare_4d_attention_mask, -) + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.text_model.final_layer_norm(last_hidden_state) + + if self.text_model.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. 
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id) + .int() + .argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) -SKS_TOKEN = 48136 - - -def text_encoder_fwd( - text_transformer, - input_ids: Optional[torch.Tensor] = None, - face_embeddings: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, -) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - Returns: - - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else text_transformer.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else text_transformer.config.output_hidden_states - ) - return_dict = ( - return_dict - if return_dict is not None - else text_transformer.config.use_return_dict - ) - - if input_ids is None: - raise ValueError("You have to specify input_ids") - - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - try: - sks_token_index = int(torch.where(input_ids == SKS_TOKEN)[1][0]) - except IndexError: - raise ValueError("your prompt doesn't have sks in it!!!!") - - hidden_states = text_transformer.embeddings( - input_ids=input_ids, position_ids=position_ids - ) - - pad_left = 0 - pad_right = 768 - 512 - hidden_states[:, sks_token_index, :] = face_embeddings - - # CLIP's text model uses causal mask, prepare it here. 
- # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _create_4d_causal_attention_mask( - input_shape, hidden_states.dtype, device=hidden_states.device - ) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) - - encoder_outputs = text_transformer.encoder( - inputs_embeds=hidden_states, - attention_mask=attention_mask, - causal_attention_mask=causal_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - last_hidden_state = encoder_outputs[0] - last_hidden_state = text_transformer.final_layer_norm(last_hidden_state) - - if text_transformer.eos_token_id == 2: - # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. - # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added - # ------------------------------------------------------------ - # text_embeds.shape = [batch_size, sequence_length, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 - pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), - input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( - dim=-1 - ), - ] - else: - # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) - pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), - # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) - ( - input_ids.to(dtype=torch.int, device=last_hidden_state.device) - == text_transformer.eos_token_id - ) - .int() - .argmax(dim=-1), - ] +import torch +import torch.nn.functional as F - if not return_dict: - return (last_hidden_state, pooled_output) + encoder_outputs[1:] +def project_face_embs(input_ids, text_encoder, face_embs, arcface_token_id): - return BaseModelOutputWithPooling( - last_hidden_state=last_hidden_state, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) + ''' + face_embs: (N, 512) normalized ArcFace embeddings + ''' + token_embs = text_encoder(input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True) + token_embs[input_ids==arcface_token_id] = face_embs + + prompt_embeds = text_encoder( + input_ids=input_ids, + input_token_embs=token_embs + )[0] + + return prompt_embeds def main(): args = parse_args() @@ -793,7 +786,7 @@ def deepspeed_zero_init_disabled_context_manager(): # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - text_encoder = CLIPTextModel.from_pretrained( + text_encoder = CLIPTextModelWrapper.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, @@ -814,10 +807,7 @@ def deepspeed_zero_init_disabled_context_manager(): from insightface.app import FaceAnalysis - app = FaceAnalysis( - name="buffalo_l", - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], - ) + app = FaceAnalysis(name='antelopev2', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) app.prepare(ctx_id=0, det_size=(640, 640)) # Freeze vae and set unet and TE to trainable @@ -1164,12 +1154,13 @@ def unwrap_model(model): ) with torch.no_grad(): input_ids = tokenizer( - ["a photo of sks person"] * args.train_batch_size, + ["photo of a id person"] * args.train_batch_size, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt", )["input_ids"].to("cuda") + arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0] for epoch in range(first_epoch, args.num_train_epochs): train_loss = 0.0 @@ -1244,10 +1235,11 @@ def unwrap_model(model): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder_fwd( - text_encoder.text_model, + encoder_hidden_states = project_face_embs( input_ids, - face_embeddings=face_embeddings, + text_encoder, + face_embeddings, + arcface_token_id, )[0] # Get the target for loss depending on the prediction type From 3eb8f4281d51fa054893e9acaefaaf54642046c8 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:46:54 -0700 Subject: [PATCH 13/30] inference attempt --- examples/text_to_image/train_text_to_image.py | 55 +++++++++++++++++-- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index a702123808a5..d6c86147e62c 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -153,8 +153,9 @@ def save_model_card( def log_validation( - vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch + vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch, app, ): + from PIL import Image logger.info("Running validation... 
") pipeline = StableDiffusionPipeline.from_pretrained( @@ -181,9 +182,24 @@ def log_validation( images = [] for i in range(len(args.validation_prompts)): + validation_image = args.validation_prompts[i] + img = np.array(Image.open(validation_image))[:,:,::-1] + faces = app.get(img) + if len(faces) == 0: + images.append(Image.new("RGB", (512, 512), (255, 255, 255))) + continue + + faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) + id_emb = torch.tensor(faces['embedding'], dtype=torch.float)[None].to("cuda") + id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding + id_emb = project_face_embs_inf(pipeline, id_emb) # pass throught the encoder + with torch.autocast("cuda"): image = pipeline( - args.validation_prompts[i], num_inference_steps=20, generator=generator + prompt_embeds=id_emb, + num_inference_steps=50, + guidance_scale=3.0, + generator=generator ).images[0] images.append(image) @@ -673,6 +689,34 @@ def forward( import torch import torch.nn.functional as F +@torch.no_grad() +def project_face_embs_inf(pipeline, face_embs): + + ''' + face_embs: (N, 512) normalized ArcFace embeddings + ''' + + arcface_token_id = pipeline.tokenizer.encode("id", add_special_tokens=False)[0] + + input_ids = pipeline.tokenizer( + "photo of a id person", + truncation=True, + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids.to(pipeline.device) + + face_embs_padded = F.pad(face_embs, (0, pipeline.text_encoder.config.hidden_size-512), "constant", 0) + token_embs = pipeline.text_encoder(input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True) + token_embs[input_ids==arcface_token_id] = face_embs_padded + + prompt_embeds = pipeline.text_encoder( + input_ids=input_ids, + input_token_embs=token_embs + )[0] + + return prompt_embeds + def project_face_embs(input_ids, text_encoder, face_embs, arcface_token_id): ''' @@ -1177,11 +1221,11 @@ def unwrap_model(model): faces = app.get(numpy_img) print(f"detected {len(faces)} faces!") - try: - found_face = faces[0] - except IndexError: + if len(faces) == 0: + print("Skipping batch due to no face found in one of the images") break + found_face = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) face_embedding = torch.from_numpy( found_face.normed_embedding ).unsqueeze(0) @@ -1375,6 +1419,7 @@ def unwrap_model(model): accelerator, weight_dtype, global_step, + app, ) if args.use_ema: # Switch back to the original UNet parameters. 
From 1db0097465bcef0324a084942b7b827221605ec8 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:50:42 -0700 Subject: [PATCH 14/30] transformers stuff --- examples/text_to_image/README.md | 2 +- examples/text_to_image/train_text_to_image.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 10180347435d..3854f854acc5 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -88,7 +88,7 @@ accelerate launch --mixed_precision="bf16" train_text_to_image.py \ --lr_scheduler="constant" --lr_warmup_steps=0 \ --output_dir="arc2face-attempt-1" \ --validation_steps=50 \ - --validation_prompt="a person with a hat" \ + --validation_prompt="taylor.jpg" \ --report_to="wandb" diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index d6c86147e62c..ea69bb9924bd 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -598,7 +598,7 @@ def parse_args(): from transformers import CLIPTextModel from typing import Any, Callable, Dict, Optional, Tuple, Union, List from transformers.modeling_outputs import BaseModelOutputWithPooling -from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask +from transformers.models.clip.modeling_clip import _create_4d_causal_attention_mask, _prepare_4d_attention_mask class CLIPTextModelWrapper(CLIPTextModel): @@ -637,11 +637,13 @@ def forward( # CLIP's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) encoder_outputs = self.text_model.encoder( inputs_embeds=hidden_states, From 2891672d0f66b8f81a00bdba4536f47531578349 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:51:48 -0700 Subject: [PATCH 15/30] go back to buffalo --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index ea69bb9924bd..60a94d7cdba7 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -853,7 +853,7 @@ def deepspeed_zero_init_disabled_context_manager(): from insightface.app import FaceAnalysis - app = FaceAnalysis(name='antelopev2', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) app.prepare(ctx_id=0, det_size=(640, 640)) # Freeze vae and set unet and TE to trainable From a128bbd457f3ba3e5a64e8090bfbf77db17d26ae Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:54:03 -0700 Subject: [PATCH 16/30] bfloat --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py 
b/examples/text_to_image/train_text_to_image.py index 60a94d7cdba7..d5cb13244375 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -190,7 +190,7 @@ def log_validation( continue faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) - id_emb = torch.tensor(faces['embedding'], dtype=torch.float)[None].to("cuda") + id_emb = torch.tensor(faces['embedding'], dtype=torch.bfloat16)[None].to("cuda") id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding id_emb = project_face_embs_inf(pipeline, id_emb) # pass throught the encoder @@ -1241,7 +1241,7 @@ def unwrap_model(model): continue # (bs, 1, 768) - face_embeddings = torch.cat(face_embeddings, dim=0).to("cuda") + face_embeddings = torch.cat(face_embeddings, dtype=weight_dtype, dim=0).to("cuda") # Convert images to latent space latents = vae.encode( From c1132f3e84475cb7abd4de25d0b08f94728c4470 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:54:33 -0700 Subject: [PATCH 17/30] stuff --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index d5cb13244375..9af798d8cfd2 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1241,7 +1241,7 @@ def unwrap_model(model): continue # (bs, 1, 768) - face_embeddings = torch.cat(face_embeddings, dtype=weight_dtype, dim=0).to("cuda") + face_embeddings = torch.cat(face_embeddings, dim=0).to(device="cuda", dtype=weight_dtype) # Convert images to latent space latents = vae.encode( From 17716ba1350ff6cac1513e44583696199760778e Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 14:55:42 -0700 Subject: [PATCH 18/30] shapes --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 9af798d8cfd2..8dfa16a64d79 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1286,7 +1286,7 @@ def unwrap_model(model): text_encoder, face_embeddings, arcface_token_id, - )[0] + ) # Get the target for loss depending on the prediction type if args.prediction_type is not None: From 1d8688c39aae720470d8e5448217b7b8f31d717e Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 15:04:04 -0700 Subject: [PATCH 19/30] x --- examples/text_to_image/train_text_to_image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8dfa16a64d79..0fadaf6178d4 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1404,6 +1404,7 @@ def unwrap_model(model): break if accelerator.is_main_process: + print(global_step, args.validation_prompts) if ( args.validation_prompts is not None and global_step % args.validation_steps == 0 From 1f5ccffa8a107e5e3c7a5542eeb90f4629f0b1ca Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 15:06:37 -0700 Subject: [PATCH 20/30] x --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py 
b/examples/text_to_image/train_text_to_image.py index 0fadaf6178d4..3a647e2f3966 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1403,8 +1403,8 @@ def unwrap_model(model): if global_step >= args.max_train_steps: break + print(accelerator.is_main_process) if accelerator.is_main_process: - print(global_step, args.validation_prompts) if ( args.validation_prompts is not None and global_step % args.validation_steps == 0 From ff16519412cb5bc28c0ef9dda68d39db71a1b5c1 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 15:07:50 -0700 Subject: [PATCH 21/30] lol --- examples/text_to_image/train_text_to_image.py | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 3a647e2f3966..10280e4ddecf 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1403,30 +1403,30 @@ def unwrap_model(model): if global_step >= args.max_train_steps: break - print(accelerator.is_main_process) - if accelerator.is_main_process: - if ( - args.validation_prompts is not None - and global_step % args.validation_steps == 0 - ): - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(training_parameters) - ema_unet.copy_to(training_parameters) - log_validation( - vae, - text_encoder, - tokenizer, - unet, - args, - accelerator, - weight_dtype, - global_step, - app, - ) - if args.use_ema: - # Switch back to the original UNet parameters. - ema_unet.restore(training_parameters) + if accelerator.is_main_process: + if ( + args.validation_prompts is not None + and global_step % args.validation_steps == 0 + ): + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(training_parameters) + ema_unet.copy_to(training_parameters) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + app, + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(training_parameters) + # Create the pipeline using the trained modules and save it. 
accelerator.wait_for_everyone() From 348305f5d0880a18ee9d45f6d946f28f12bc2b29 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 15:34:07 -0700 Subject: [PATCH 22/30] use taylor's picture if no face is found --- examples/text_to_image/train_text_to_image.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 10280e4ddecf..c6afcfaeee51 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1208,6 +1208,13 @@ def unwrap_model(model): )["input_ids"].to("cuda") arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0] + with torch.no_grad(): + from PIL import Image + + validation_image = args.validation_prompts[0] + img = np.array(Image.open(validation_image))[:,:,::-1] + taylor_faces = app.get(img) + for epoch in range(first_epoch, args.num_train_epochs): train_loss = 0.0 for step, batch in enumerate(train_dataloader): @@ -1224,8 +1231,8 @@ def unwrap_model(model): faces = app.get(numpy_img) print(f"detected {len(faces)} faces!") if len(faces) == 0: - print("Skipping batch due to no face found in one of the images") - break + print("replacing with taylor's face") + faces = taylor_faces found_face = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) face_embedding = torch.from_numpy( From 8c33fe1619a1b9e41f2656bb1d673b6c0c4a8598 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 15:36:21 -0700 Subject: [PATCH 23/30] try something --- examples/text_to_image/train_text_to_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index c6afcfaeee51..21f1de4dfccb 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1200,7 +1200,7 @@ def unwrap_model(model): ) with torch.no_grad(): input_ids = tokenizer( - ["photo of a id person"] * args.train_batch_size, + ["photo of a id person"] * 1, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, From 5368f8ea4ee4d3515bc1431535b0a0853d807f79 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 15:38:26 -0700 Subject: [PATCH 24/30] lol --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 21f1de4dfccb..2f15337e61e4 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -725,7 +725,7 @@ def project_face_embs(input_ids, text_encoder, face_embs, arcface_token_id): face_embs: (N, 512) normalized ArcFace embeddings ''' - token_embs = text_encoder(input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True) + token_embs = text_encoder(input_ids=input_ids, return_token_embs=True) token_embs[input_ids==arcface_token_id] = face_embs prompt_embeds = text_encoder( @@ -1200,7 +1200,7 @@ def unwrap_model(model): ) with torch.no_grad(): input_ids = tokenizer( - ["photo of a id person"] * 1, + ["photo of a id person"] * args.train_batch_size, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, From c0fa2fa4b29dcb611c28cc73b36057df104b3adc Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya 
Date: Thu, 28 Mar 2024 15:45:23 -0700 Subject: [PATCH 25/30] lol --- examples/text_to_image/train_text_to_image.py | 170 +++++++++++++----- 1 file changed, 123 insertions(+), 47 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 2f15337e61e4..214bb77228e9 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -153,9 +153,18 @@ def save_model_card( def log_validation( - vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch, app, + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + epoch, + app, ): from PIL import Image + logger.info("Running validation... ") pipeline = StableDiffusionPipeline.from_pretrained( @@ -181,29 +190,52 @@ def log_validation( generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] + image_similarity = [] for i in range(len(args.validation_prompts)): validation_image = args.validation_prompts[i] - img = np.array(Image.open(validation_image))[:,:,::-1] + img = np.array(Image.open(validation_image))[:, :, ::-1] faces = app.get(img) if len(faces) == 0: images.append(Image.new("RGB", (512, 512), (255, 255, 255))) continue - faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) - id_emb = torch.tensor(faces['embedding'], dtype=torch.bfloat16)[None].to("cuda") - id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding - id_emb = project_face_embs_inf(pipeline, id_emb) # pass throught the encoder + faces = sorted( + faces, + key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), + )[ + -1 + ] # select largest face (if more than one detected) + id_emb = torch.tensor(faces["embedding"], dtype=torch.bfloat16)[None].to("cuda") + id_emb = id_emb / torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding + id_emb = project_face_embs_inf(pipeline, id_emb) # pass throught the encoder with torch.autocast("cuda"): image = pipeline( prompt_embeds=id_emb, num_inference_steps=50, guidance_scale=3.0, - generator=generator + generator=generator, ).images[0] images.append(image) + face_2 = np.array(image)[:, :, ::-1] + faces_2 = app.get(face_2) + if len(faces_2) == 0: + continue + + faces_2 = sorted( + faces_2, + key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), + )[ + -1 + ] # select largest face (if more than one detected) + # cosine similarity between embeddings + sim = np.dot(faces["embedding"], faces_2["embedding"]) / ( + np.linalg.norm(faces["embedding"]) * np.linalg.norm(faces_2["embedding"]) + ) + image_similarity.append(sim) + for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) @@ -216,7 +248,8 @@ def log_validation( "validation": [ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") for i, image in enumerate(images) - ] + ], + "image_similarity": image_similarity, } ) else: @@ -598,7 +631,10 @@ def parse_args(): from transformers import CLIPTextModel from typing import Any, Callable, Dict, Optional, Tuple, Union, List from transformers.modeling_outputs import BaseModelOutputWithPooling -from transformers.models.clip.modeling_clip import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from transformers.models.clip.modeling_clip import ( + _create_4d_causal_attention_mask, + 
_prepare_4d_attention_mask, +) class CLIPTextModelWrapper(CLIPTextModel): @@ -619,13 +655,25 @@ def forward( if return_token_embs: return self.text_model.embeddings.token_embedding(input_ids) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) - output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.text_model.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.text_model.config.output_hidden_states + ) + return_dict = ( + return_dict + if return_dict is not None + else self.text_model.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict if input_ids is None: raise ValueError("You have to specify input_ids") @@ -633,7 +681,11 @@ def forward( input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) - hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs) + hidden_states = self.text_model.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=input_token_embs, + ) # CLIP's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 @@ -643,7 +695,9 @@ def forward( # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype + ) encoder_outputs = self.text_model.encoder( inputs_embeds=hidden_states, @@ -665,15 +719,24 @@ def forward( # take features from the eot embedding (eot_token is the highest number in each sequence) # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), - input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( + dim=-1 + ), ] else: # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) pooled_output = last_hidden_state[ - torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) - (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id) + ( + input_ids.to(dtype=torch.int, device=last_hidden_state.device) + == self.text_model.eos_token_id + ) .int() .argmax(dim=-1), ] @@ -688,53 +751,55 @@ def forward( attentions=encoder_outputs.attentions, ) + import torch import torch.nn.functional as F + @torch.no_grad() def project_face_embs_inf(pipeline, face_embs): - - ''' + """ face_embs: (N, 512) normalized ArcFace embeddings - 
''' + """ arcface_token_id = pipeline.tokenizer.encode("id", add_special_tokens=False)[0] input_ids = pipeline.tokenizer( - "photo of a id person", - truncation=True, - padding="max_length", - max_length=pipeline.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids.to(pipeline.device) + "photo of a id person", + truncation=True, + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids.to(pipeline.device) - face_embs_padded = F.pad(face_embs, (0, pipeline.text_encoder.config.hidden_size-512), "constant", 0) - token_embs = pipeline.text_encoder(input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True) - token_embs[input_ids==arcface_token_id] = face_embs_padded + face_embs_padded = F.pad( + face_embs, (0, pipeline.text_encoder.config.hidden_size - 512), "constant", 0 + ) + token_embs = pipeline.text_encoder( + input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True + ) + token_embs[input_ids == arcface_token_id] = face_embs_padded prompt_embeds = pipeline.text_encoder( - input_ids=input_ids, - input_token_embs=token_embs + input_ids=input_ids, input_token_embs=token_embs )[0] return prompt_embeds -def project_face_embs(input_ids, text_encoder, face_embs, arcface_token_id): - ''' +def project_face_embs(input_ids, text_encoder, face_embs, arcface_token_id): + """ face_embs: (N, 512) normalized ArcFace embeddings - ''' + """ token_embs = text_encoder(input_ids=input_ids, return_token_embs=True) - token_embs[input_ids==arcface_token_id] = face_embs + token_embs[input_ids == arcface_token_id] = face_embs - prompt_embeds = text_encoder( - input_ids=input_ids, - input_token_embs=token_embs - )[0] + prompt_embeds = text_encoder(input_ids=input_ids, input_token_embs=token_embs)[0] return prompt_embeds + def main(): args = parse_args() @@ -853,7 +918,9 @@ def deepspeed_zero_init_disabled_context_manager(): from insightface.app import FaceAnalysis - app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + app = FaceAnalysis( + name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) app.prepare(ctx_id=0, det_size=(640, 640)) # Freeze vae and set unet and TE to trainable @@ -862,7 +929,9 @@ def deepspeed_zero_init_disabled_context_manager(): text_encoder.train() text_encoder.requires_grad_(False) text_encoder.text_model.encoder.requires_grad_(True) - training_parameters = list(unet.parameters()) + list(text_encoder.text_model.encoder.parameters()) + training_parameters = list(unet.parameters()) + list( + text_encoder.text_model.encoder.parameters() + ) # Create EMA for the unet. 
if args.use_ema: @@ -1212,7 +1281,7 @@ def unwrap_model(model): from PIL import Image validation_image = args.validation_prompts[0] - img = np.array(Image.open(validation_image))[:,:,::-1] + img = np.array(Image.open(validation_image))[:, :, ::-1] taylor_faces = app.get(img) for epoch in range(first_epoch, args.num_train_epochs): @@ -1234,7 +1303,13 @@ def unwrap_model(model): print("replacing with taylor's face") faces = taylor_faces - found_face = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) + found_face = sorted( + faces, + key=lambda x: (x["bbox"][2] - x["bbox"][0]) + * (x["bbox"][3] - x["bbox"][1]), + )[ + -1 + ] # select largest face (if more than one detected) face_embedding = torch.from_numpy( found_face.normed_embedding ).unsqueeze(0) @@ -1248,7 +1323,9 @@ def unwrap_model(model): continue # (bs, 1, 768) - face_embeddings = torch.cat(face_embeddings, dim=0).to(device="cuda", dtype=weight_dtype) + face_embeddings = torch.cat(face_embeddings, dim=0).to( + device="cuda", dtype=weight_dtype + ) # Convert images to latent space latents = vae.encode( @@ -1434,7 +1511,6 @@ def unwrap_model(model): # Switch back to the original UNet parameters. ema_unet.restore(training_parameters) - # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: From aa337e71756c4b1d90e07c336ef4ec156387834f Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 15:50:58 -0700 Subject: [PATCH 26/30] lol --- examples/text_to_image/train_text_to_image.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 214bb77228e9..9702ebf358c2 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -234,7 +234,7 @@ def log_validation( sim = np.dot(faces["embedding"], faces_2["embedding"]) / ( np.linalg.norm(faces["embedding"]) * np.linalg.norm(faces_2["embedding"]) ) - image_similarity.append(sim) + image_similarity.append(float(sim)) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -249,7 +249,10 @@ def log_validation( wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") for i, image in enumerate(images) ], - "image_similarity": image_similarity, + } + | { + f"image_similarity_{n}": sim + for n, sim in enumerate(image_similarity) } ) else: From fb7a9aa61f683e4d878726af78be668ac1700969 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 16:33:16 -0700 Subject: [PATCH 27/30] ml magic --- examples/text_to_image/train_text_to_image.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 9702ebf358c2..95e5657794db 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1293,6 +1293,7 @@ def unwrap_model(model): with accelerator.accumulate(unet): with torch.no_grad(): batch_size = batch["pixel_values"].shape[0] + indices = [] face_embeddings = [] for batch_index in range(batch_size): tensor_img = batch["pixel_values"][batch_index] @@ -1303,8 +1304,8 @@ def unwrap_model(model): faces = app.get(numpy_img) print(f"detected {len(faces)} faces!") if len(faces) == 0: - print("replacing with taylor's face") - faces = taylor_faces + 
print("couldn't detect faces") + continue found_face = sorted( faces, @@ -1318,17 +1319,19 @@ def unwrap_model(model): ).unsqueeze(0) face_embedding = F.pad(face_embedding, (0, 768 - 512)) face_embeddings.append(face_embedding) - - if len(face_embeddings) != batch_size: - print( - "Skipping batch due to no face found in one of the images" - ) - continue + indices.append(batch_index) # (bs, 1, 768) face_embeddings = torch.cat(face_embeddings, dim=0).to( device="cuda", dtype=weight_dtype ) + # Recompute the pixel values from the indexes in face embeddings + batch["pixel_values"] = batch["pixel_values"][list(indices)] + batch["token_ids"] = input_ids[: batch["pixel_values"].shape[0]] + + if len(face_embeddings) != batch["pixel_values"].shape[0]: + print("Skipping batch due to no face found in one of the images") + continue # Convert images to latent space latents = vae.encode( From 121b02596ef11788c48d45e4de8aff4cf4aaaceb Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 16:34:59 -0700 Subject: [PATCH 28/30] x --- examples/text_to_image/train_text_to_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 95e5657794db..8a5ac4b9a342 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -1327,7 +1327,7 @@ def unwrap_model(model): ) # Recompute the pixel values from the indexes in face embeddings batch["pixel_values"] = batch["pixel_values"][list(indices)] - batch["token_ids"] = input_ids[: batch["pixel_values"].shape[0]] + batch_input_ids = input_ids[: batch["pixel_values"].shape[0]] if len(face_embeddings) != batch["pixel_values"].shape[0]: print("Skipping batch due to no face found in one of the images") @@ -1372,7 +1372,7 @@ def unwrap_model(model): # Get the text embedding for conditioning encoder_hidden_states = project_face_embs( - input_ids, + batch_input_ids, text_encoder, face_embeddings, arcface_token_id, From f4172c1b53a1f43a3fccd710395018d5efcc9e77 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Thu, 28 Mar 2024 18:23:43 -0700 Subject: [PATCH 29/30] lol --- examples/text_to_image/train_text_to_image.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8a5ac4b9a342..894ebbc92306 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -932,8 +932,11 @@ def deepspeed_zero_init_disabled_context_manager(): text_encoder.train() text_encoder.requires_grad_(False) text_encoder.text_model.encoder.requires_grad_(True) - training_parameters = list(unet.parameters()) + list( - text_encoder.text_model.encoder.parameters() + text_encoder.text_model.final_layer_norm.requires_grad_(True) + training_parameters = ( + list(unet.parameters()) + + list(text_encoder.text_model.encoder.parameters()) + + list(text_encoder.text_model.final_layer_norm.parameters()) ) # Create EMA for the unet. 
From a8a102d4866239fdadb47b2e94cefe74b7d60fd9 Mon Sep 17 00:00:00 2001 From: Batuhan Taskaya Date: Fri, 29 Mar 2024 11:37:18 -0700 Subject: [PATCH 30/30] stuff --- examples/text_to_image/train_text_to_image.py | 26 ++----------------- 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 894ebbc92306..54dfdb72a56b 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -922,7 +922,8 @@ def deepspeed_zero_init_disabled_context_manager(): from insightface.app import FaceAnalysis app = FaceAnalysis( - name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + name="buffalo_l", + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) app.prepare(ctx_id=0, det_size=(640, 640)) @@ -1086,29 +1087,6 @@ def load_model_hook(models, input_dir): f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" ) - # Preprocessing the datasets. - # We need to tokenize input captions and transform the images. - def tokenize_captions(examples, is_train=True): - captions = [] - for caption in examples[caption_column]: - if isinstance(caption, str): - captions.append(caption) - elif isinstance(caption, (list, np.ndarray)): - # take a random caption if there are multiple - captions.append(random.choice(caption) if is_train else caption[0]) - else: - raise ValueError( - f"Caption column `{caption_column}` should contain either strings or lists of strings." - ) - inputs = tokenizer( - captions, - max_length=tokenizer.model_max_length, - padding="max_length", - truncation=True, - return_tensors="pt", - ) - return inputs.input_ids - # Preprocessing the datasets. train_transforms = transforms.Compose( [