diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index f2931d3f347e..3854f854acc5 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -76,6 +76,22 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
+accelerate launch --mixed_precision="bf16" train_text_to_image.py \
+  --pretrained_model_name_or_path=SG161222/Realistic_Vision_V6.0_B1_noVAE \
+  --dataset_name=nielsr/CelebA-faces \
+  --resolution=512 --center_crop --random_flip \
+  --train_batch_size=1 \
+  --gradient_checkpointing \
+  --max_train_steps=15000 \
+  --learning_rate=1e-05 \
+  --max_grad_norm=1 \
+  --lr_scheduler="constant" --lr_warmup_steps=0 \
+  --output_dir="arc2face-attempt-1" \
+  --validation_steps=50 \
+  --validation_prompts="taylor.jpg" \
+  --report_to="wandb"
+
+
 To run on your own training files, prepare the dataset in the format required by `datasets`; the instructions are in this [document](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder-with-metadata), and a minimal sketch of that layout follows below. If you wish to use custom loading logic, modify the script; we have left pointers for that in the training script.
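For reference, a minimal sketch of the `imagefolder` layout that the linked document describes. The directory name, file names, and the extra `text` column are illustrative assumptions; the caption column is optional for this script, which conditions on face embeddings rather than captions:

```python
# my_faces/train/metadata.jsonl sits next to the images and adds extra columns, e.g.:
#   {"file_name": "0001.png", "text": "photo of a id person"}
#   {"file_name": "0002.png", "text": "photo of a id person"}
from datasets import load_dataset

# Loads the images plus any metadata columns; "train" is the split folder name.
dataset = load_dataset("imagefolder", data_dir="./my_faces")
print(dataset["train"][0])  # {"image": <PIL.Image>, "text": "photo of a id person"}
```
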
") pipeline = StableDiffusionPipeline.from_pretrained( @@ -163,23 +190,69 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] + image_similarity = [] for i in range(len(args.validation_prompts)): + validation_image = args.validation_prompts[i] + img = np.array(Image.open(validation_image))[:, :, ::-1] + faces = app.get(img) + if len(faces) == 0: + images.append(Image.new("RGB", (512, 512), (255, 255, 255))) + continue + + faces = sorted( + faces, + key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), + )[ + -1 + ] # select largest face (if more than one detected) + id_emb = torch.tensor(faces["embedding"], dtype=torch.bfloat16)[None].to("cuda") + id_emb = id_emb / torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding + id_emb = project_face_embs_inf(pipeline, id_emb) # pass throught the encoder + with torch.autocast("cuda"): - image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + image = pipeline( + prompt_embeds=id_emb, + num_inference_steps=50, + guidance_scale=3.0, + generator=generator, + ).images[0] images.append(image) + face_2 = np.array(image)[:, :, ::-1] + faces_2 = app.get(face_2) + if len(faces_2) == 0: + continue + + faces_2 = sorted( + faces_2, + key=lambda x: (x["bbox"][2] - x["bbox"][0]) * (x["bbox"][3] - x["bbox"][1]), + )[ + -1 + ] # select largest face (if more than one detected) + # cosine similarity between embeddings + sim = np.dot(faces["embedding"], faces_2["embedding"]) / ( + np.linalg.norm(faces["embedding"]) * np.linalg.norm(faces_2["embedding"]) + ) + image_similarity.append(float(sim)) + for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + tracker.writer.add_images( + "validation", np_images, epoch, dataformats="NHWC" + ) elif tracker.name == "wandb": tracker.log( { "validation": [ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") for i, image in enumerate(images) - ] + ], + } + | { + f"image_similarity_{n}": sim + for n, sim in enumerate(image_similarity) } ) else: @@ -194,7 +267,10 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( - "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + "--input_perturbation", + type=float, + default=0, + help="The scale of input perturbation. Recommended 0.1.", ) parser.add_argument( "--pretrained_model_name_or_path", @@ -243,7 +319,10 @@ def parse_args(): ), ) parser.add_argument( - "--image_column", type=str, default="image", help="The column of the dataset containing an image." + "--image_column", + type=str, + default="image", + help="The column of the dataset containing an image.", ) parser.add_argument( "--caption_column", @@ -265,7 +344,9 @@ def parse_args(): type=str, default=None, nargs="+", - help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + help=( + "A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`." 
+ ), ) parser.add_argument( "--output_dir", @@ -279,7 +360,9 @@ def parse_args(): default=None, help="The directory where the downloaded models and datasets will be stored.", ) - parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--seed", type=int, default=None, help="A seed for reproducible training." + ) parser.add_argument( "--resolution", type=int, @@ -304,7 +387,10 @@ def parse_args(): help="whether to randomly flip images horizontally", ) parser.add_argument( - "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + "--train_batch_size", + type=int, + default=16, + help="Batch size (per device) for the training dataloader.", ) parser.add_argument("--num_train_epochs", type=int, default=100) parser.add_argument( @@ -346,7 +432,10 @@ def parse_args(): ), ) parser.add_argument( - "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + "--lr_warmup_steps", + type=int, + default=500, + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( "--snr_gamma", @@ -356,7 +445,9 @@ def parse_args(): "More details here: https://arxiv.org/abs/2303.09556.", ) parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes.", ) parser.add_argument( "--allow_tf32", @@ -366,7 +457,9 @@ def parse_args(): " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) - parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--use_ema", action="store_true", help="Whether to use EMA model." + ) parser.add_argument( "--non_ema_revision", type=str, @@ -385,13 +478,41 @@ def parse_args(): "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") - parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") - parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") - parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--adam_beta1", + type=float, + default=0.9, + help="The beta1 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_beta2", + type=float, + default=0.999, + help="The beta2 parameter for the Adam optimizer.", + ) + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 
+ ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Whether or not to push the model to the Hub.", + ) + parser.add_argument( + "--hub_token", + type=str, + default=None, + help="The token to use to push to the Model Hub.", + ) parser.add_argument( "--prediction_type", type=str, @@ -433,7 +554,12 @@ def parse_args(): ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) - parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="For distributed training: local_rank", + ) parser.add_argument( "--checkpointing_steps", type=int, @@ -459,15 +585,25 @@ def parse_args(): ), ) parser.add_argument( - "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + "--enable_xformers_memory_efficient_attention", + action="store_true", + help="Whether or not to use xformers.", + ) + parser.add_argument( + "--noise_offset", type=float, default=0, help="The scale of noise offset." ) - parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") parser.add_argument( "--validation_epochs", type=int, default=5, help="Run validation every X epochs.", ) + parser.add_argument( + "--validation_steps", + type=int, + default=50, + help="Run validation every X steps.", + ) parser.add_argument( "--tracker_project_name", type=str, @@ -494,6 +630,179 @@ def parse_args(): return args +import torch +from transformers import CLIPTextModel +from typing import Any, Callable, Dict, Optional, Tuple, Union, List +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.models.clip.modeling_clip import ( + _create_4d_causal_attention_mask, + _prepare_4d_attention_mask, +) + + +class CLIPTextModelWrapper(CLIPTextModel): + # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812 + # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them. + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + input_token_embs: Optional[torch.Tensor] = None, + return_token_embs: Optional[bool] = False, + ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]: + + if return_token_embs: + return self.text_model.embeddings.token_embedding(input_ids) + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.text_model.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.text_model.config.output_hidden_states + ) + return_dict = ( + return_dict + if return_dict is not None + else self.text_model.config.use_return_dict + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.text_model.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=input_token_embs, + ) + + # CLIP's text model uses causal mask, prepare it here. 
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _create_4d_causal_attention_mask( + input_shape, hidden_states.dtype, device=hidden_states.device + ) + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype + ) + + encoder_outputs = self.text_model.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.text_model.final_layer_norm(last_hidden_state) + + if self.text_model.eos_token_id == 2: + # The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here. + # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added + # ------------------------------------------------------------ + # text_embeds.shape = [batch_size, sequence_length, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 + pooled_output = last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( + dim=-1 + ), + ] + else: + # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) + pooled_output = last_hidden_state[ + torch.arange( + last_hidden_state.shape[0], device=last_hidden_state.device + ), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + ( + input_ids.to(dtype=torch.int, device=last_hidden_state.device) + == self.text_model.eos_token_id + ) + .int() + .argmax(dim=-1), + ] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +import torch +import torch.nn.functional as F + + +@torch.no_grad() +def project_face_embs_inf(pipeline, face_embs): + """ + face_embs: (N, 512) normalized ArcFace embeddings + """ + + arcface_token_id = pipeline.tokenizer.encode("id", add_special_tokens=False)[0] + + input_ids = pipeline.tokenizer( + "photo of a id person", + truncation=True, + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids.to(pipeline.device) + + face_embs_padded = F.pad( + face_embs, (0, pipeline.text_encoder.config.hidden_size - 512), "constant", 0 + ) + token_embs = pipeline.text_encoder( + input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True + ) + token_embs[input_ids == arcface_token_id] = face_embs_padded + + prompt_embeds = pipeline.text_encoder( + input_ids=input_ids, input_token_embs=token_embs + )[0] + + return prompt_embeds + + +def project_face_embs(input_ids, text_encoder, face_embs, arcface_token_id): + """ + face_embs: (N, 512) normalized ArcFace embeddings + """ + + token_embs = text_encoder(input_ids=input_ids, return_token_embs=True) + token_embs[input_ids == arcface_token_id] = face_embs + + 
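+    # The "id" placeholder positions now hold the (already padded) ArcFace
+    # vectors; a second encoder pass contextualizes them into prompt embeddings.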
prompt_embeds = text_encoder(input_ids=input_ids, input_token_embs=token_embs)[0] + + return prompt_embeds + + def main(): args = parse_args() @@ -514,7 +823,9 @@ def main(): ) logging_dir = os.path.join(args.output_dir, args.logging_dir) - accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + accelerator_project_config = ProjectConfiguration( + project_dir=args.output_dir, logging_dir=logging_dir + ) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, @@ -550,20 +861,30 @@ def main(): if args.push_to_hub: repo_id = create_repo( - repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, ).repo_id # Load scheduler, tokenizer and models. - noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, ) def deepspeed_zero_init_disabled_context_manager(): """ returns either a context list that includes one that will disable zero.Init or an empty context list """ - deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + deepspeed_plugin = ( + AcceleratorState().deepspeed_plugin + if accelerate.state.is_initialized() + else None + ) if deepspeed_plugin is None: return [] @@ -579,28 +900,59 @@ def deepspeed_zero_init_disabled_context_manager(): # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
with ContextManagers(deepspeed_zero_init_disabled_context_manager()): - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + text_encoder = CLIPTextModelWrapper.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.non_ema_revision, + ) + + from insightface.app import FaceAnalysis + + app = FaceAnalysis( + name="buffalo_l", + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) + app.prepare(ctx_id=0, det_size=(640, 640)) - # Freeze vae and text_encoder and set unet to trainable + # Freeze vae and set unet and TE to trainable vae.requires_grad_(False) - text_encoder.requires_grad_(False) unet.train() + text_encoder.train() + text_encoder.requires_grad_(False) + text_encoder.text_model.encoder.requires_grad_(True) + text_encoder.text_model.final_layer_norm.requires_grad_(True) + training_parameters = ( + list(unet.parameters()) + + list(text_encoder.text_model.encoder.parameters()) + + list(text_encoder.text_model.final_layer_norm.parameters()) + ) # Create EMA for the unet. if args.use_ema: ema_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant, + ) + ema_unet = EMAModel( + ema_unet.parameters(), + model_cls=UNet2DConditionModel, + model_config=ema_unet.config, ) - ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): @@ -613,7 +965,9 @@ def deepspeed_zero_init_disabled_context_manager(): ) unet.enable_xformers_memory_efficient_attention() else: - raise ValueError("xformers is not available. Make sure it is installed correctly") + raise ValueError( + "xformers is not available. 
Make sure it is installed correctly" + ) # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): @@ -631,7 +985,9 @@ def save_model_hook(models, weights, output_dir): def load_model_hook(models, input_dir): if args.use_ema: - load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + load_model = EMAModel.from_pretrained( + os.path.join(input_dir, "unet_ema"), UNet2DConditionModel + ) ema_unet.load_state_dict(load_model.state_dict()) ema_unet.to(accelerator.device) del load_model @@ -641,7 +997,9 @@ def load_model_hook(models, input_dir): model = models.pop() # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + load_model = UNet2DConditionModel.from_pretrained( + input_dir, subfolder="unet" + ) model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) @@ -660,7 +1018,10 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) # Initialize the optimizer @@ -677,7 +1038,7 @@ def load_model_hook(models, input_dir): optimizer_cls = torch.optim.AdamW optimizer = optimizer_cls( - unet.parameters(), + training_parameters, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -716,47 +1077,32 @@ def load_model_hook(models, input_dir): # 6. Get the column names for input/target. dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) if args.image_column is None: - image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + image_column = ( + dataset_columns[0] if dataset_columns is not None else column_names[0] + ) else: image_column = args.image_column if image_column not in column_names: raise ValueError( f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" ) - if args.caption_column is None: - caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] - else: - caption_column = args.caption_column - if caption_column not in column_names: - raise ValueError( - f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" - ) - - # Preprocessing the datasets. - # We need to tokenize input captions and transform the images. - def tokenize_captions(examples, is_train=True): - captions = [] - for caption in examples[caption_column]: - if isinstance(caption, str): - captions.append(caption) - elif isinstance(caption, (list, np.ndarray)): - # take a random caption if there are multiple - captions.append(random.choice(caption) if is_train else caption[0]) - else: - raise ValueError( - f"Caption column `{caption_column}` should contain either strings or lists of strings." - ) - inputs = tokenizer( - captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" - ) - return inputs.input_ids # Preprocessing the datasets. 
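+    # Note that caption tokenization was removed above: the conditioning signal
+    # is now the per-image ArcFace embedding computed inside the training loop.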
train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), - transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + transforms.Resize( + args.resolution, interpolation=transforms.InterpolationMode.BILINEAR + ), + ( + transforms.CenterCrop(args.resolution) + if args.center_crop + else transforms.RandomCrop(args.resolution) + ), + ( + transforms.RandomHorizontalFlip() + if args.random_flip + else transforms.Lambda(lambda x: x) + ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -765,20 +1111,22 @@ def tokenize_captions(examples, is_train=True): def preprocess_train(examples): images = [image.convert("RGB") for image in examples[image_column]] examples["pixel_values"] = [train_transforms(image) for image in images] - examples["input_ids"] = tokenize_captions(examples) return examples with accelerator.main_process_first(): if args.max_train_samples is not None: - dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + dataset["train"] = ( + dataset["train"] + .shuffle(seed=args.seed) + .select(range(args.max_train_samples)) + ) # Set the training transforms train_dataset = dataset["train"].with_transform(preprocess_train) def collate_fn(examples): pixel_values = torch.stack([example["pixel_values"] for example in examples]) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = torch.stack([example["input_ids"] for example in examples]) - return {"pixel_values": pixel_values, "input_ids": input_ids} + return {"pixel_values": pixel_values} # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( @@ -791,7 +1139,9 @@ def collate_fn(examples): # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -826,7 +1176,9 @@ def collate_fn(examples): vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -846,13 +1198,19 @@ def unwrap_model(model): return model # Train! - total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + total_batch_size = ( + args.train_batch_size + * accelerator.num_processes + * args.gradient_accumulation_steps + ) logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") - logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(
+        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
+    )
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
     global_step = 0
@@ -893,13 +1251,73 @@ def unwrap_model(model):
         # Only show the progress bar once on each machine.
         disable=not accelerator.is_local_main_process,
     )
+    # Every example uses the same pseudo-prompt; only the "id" token's embedding
+    # changes per image, so the input ids can be tokenized once up front.
+    with torch.no_grad():
+        input_ids = tokenizer(
+            ["photo of a id person"] * args.train_batch_size,
+            max_length=tokenizer.model_max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+        )["input_ids"].to("cuda")
+        arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]
 
     for epoch in range(first_epoch, args.num_train_epochs):
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
+                with torch.no_grad():
+                    batch_size = batch["pixel_values"].shape[0]
+                    indices = []
+                    face_embeddings = []
+                    for batch_index in range(batch_size):
+                        tensor_img = batch["pixel_values"][batch_index]
+                        # de-normalize from [-1, 1] back to [0, 255] and flip
+                        # RGB -> BGR, which is what the insightface detector expects
+                        numpy_img = (
+                            ((tensor_img * 0.5 + 0.5).clamp(0, 1) * 255)
+                            .cpu()
+                            .permute(1, 2, 0)
+                            .numpy()
+                            .astype(np.uint8)[:, :, ::-1]
+                        )
+
+                        faces = app.get(numpy_img)
+                        if len(faces) == 0:
+                            logger.warning("no face detected; dropping image from the batch")
+                            continue
+
+                        found_face = sorted(
+                            faces,
+                            key=lambda x: (x["bbox"][2] - x["bbox"][0])
+                            * (x["bbox"][3] - x["bbox"][1]),
+                        )[-1]  # select the largest face (if more than one is detected)
+                        face_embedding = torch.from_numpy(
+                            found_face.normed_embedding
+                        ).unsqueeze(0)
+                        # zero-pad the 512-d ArcFace embedding to CLIP's 768-d hidden size
+                        face_embedding = F.pad(face_embedding, (0, 768 - 512))
+                        face_embeddings.append(face_embedding)
+                        indices.append(batch_index)
+
+                if len(indices) == 0:
+                    # torch.cat below would fail on an empty list
+                    logger.warning("Skipping batch: no face found in any of the images")
+                    continue
+
+                with torch.no_grad():
+                    # (bs, 768)
+                    face_embeddings = torch.cat(face_embeddings, dim=0).to(
+                        device="cuda", dtype=weight_dtype
+                    )
+                    # keep only the images in which a face was found
+                    batch["pixel_values"] = batch["pixel_values"][list(indices)]
+                    batch_input_ids = input_ids[: batch["pixel_values"].shape[0]]
+
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+                latents = vae.encode(
+                    batch["pixel_values"].to(weight_dtype)
+                ).latent_dist.sample()
                 latents = latents * vae.config.scaling_factor
 
                 # Sample noise that we'll add to the latents
@@ -907,57 +1325,85 @@ def unwrap_model(model):
                 if args.noise_offset:
                     # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                     noise += args.noise_offset * torch.randn(
-                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+                        (latents.shape[0], latents.shape[1], 1, 1),
+                        device=latents.device,
                     )
                 if args.input_perturbation:
-                    new_noise = noise + args.input_perturbation * torch.randn_like(noise)
+                    new_noise = noise + args.input_perturbation * torch.randn_like(
+                        noise
+                    )
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
-                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = torch.randint(
+                    0,
+                    noise_scheduler.config.num_train_timesteps,
+                    (bsz,),
+                    device=latents.device,
+                )
                 timesteps = timesteps.long()
 
                 # Add noise to the latents according to 
the noise magnitude at each timestep # (this is the forward diffusion process) if args.input_perturbation: - noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + noisy_latents = noise_scheduler.add_noise( + latents, new_noise, timesteps + ) else: noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] + encoder_hidden_states = project_face_embs( + batch_input_ids, + text_encoder, + face_embeddings, + arcface_token_id, + ) # Get the target for loss depending on the prediction type if args.prediction_type is not None: # set prediction_type of scheduler if defined - noise_scheduler.register_to_config(prediction_type=args.prediction_type) + noise_scheduler.register_to_config( + prediction_type=args.prediction_type + ) if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) # Predict the noise residual and compute loss - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] + model_pred = unet( + noisy_latents, timesteps, encoder_hidden_states, return_dict=False + )[0] if args.snr_gamma is None: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) else: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( - dim=1 - )[0] + mse_loss_weights = torch.stack( + [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] if noise_scheduler.config.prediction_type == "epsilon": mse_loss_weights = mse_loss_weights / snr elif noise_scheduler.config.prediction_type == "v_prediction": mse_loss_weights = mse_loss_weights / (snr + 1) - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="none" + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ) loss = loss.mean() # Gather the losses across all processes for logging (if we use distributed training). 
@@ -967,7 +1413,7 @@ def unwrap_model(model): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + accelerator.clip_grad_norm_(training_parameters, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -975,7 +1421,7 @@ def unwrap_model(model): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: - ema_unet.step(unet.parameters()) + ema_unet.step(training_parameters) progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) @@ -986,59 +1432,78 @@ def unwrap_model(model): # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) - checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] - checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + checkpoints = [ + d for d in checkpoints if d.startswith("checkpoint") + ] + checkpoints = sorted( + checkpoints, key=lambda x: int(x.split("-")[1]) + ) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: - num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + num_to_remove = ( + len(checkpoints) - args.checkpoints_total_limit + 1 + ) removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) - logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + logger.info( + f"removing checkpoints: {', '.join(removing_checkpoints)}" + ) for removing_checkpoint in removing_checkpoints: - removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + removing_checkpoint = os.path.join( + args.output_dir, removing_checkpoint + ) shutil.rmtree(removing_checkpoint) - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + save_path = os.path.join( + args.output_dir, f"checkpoint-{global_step}" + ) accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } progress_bar.set_postfix(**logs) if global_step >= args.max_train_steps: break - if accelerator.is_main_process: - if args.validation_prompts is not None and epoch % args.validation_epochs == 0: - if args.use_ema: - # Store the UNet parameters temporarily and load the EMA parameters to perform inference. - ema_unet.store(unet.parameters()) - ema_unet.copy_to(unet.parameters()) - log_validation( - vae, - text_encoder, - tokenizer, - unet, - args, - accelerator, - weight_dtype, - global_step, - ) - if args.use_ema: - # Switch back to the original UNet parameters. - ema_unet.restore(unet.parameters()) + if accelerator.is_main_process: + if ( + args.validation_prompts is not None + and global_step % args.validation_steps == 0 + ): + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
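+                        # Note: ema_unet's shadow parameters were created from
+                        # unet.parameters() only, so copy_to below updates just the
+                        # UNet entries at the head of training_parameters; the
+                        # text-encoder entries keep their current training values.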
+                        ema_unet.store(training_parameters)
+                        ema_unet.copy_to(training_parameters)
+                        log_validation(
+                            vae,
+                            text_encoder,
+                            tokenizer,
+                            unet,
+                            args,
+                            accelerator,
+                            weight_dtype,
+                            global_step,
+                            app,
+                        )
+                        if args.use_ema:
+                            # Switch back to the original training parameters.
+                            ema_unet.restore(training_parameters)
 
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unwrap_model(unet)
         if args.use_ema:
-            ema_unet.copy_to(unet.parameters())
+            # only the UNet entries of training_parameters are EMA-tracked
+            ema_unet.copy_to(training_parameters)
 
         pipeline = StableDiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
@@ -1064,11 +1529,17 @@ def unwrap_model(model):
         if args.seed is None:
             generator = None
         else:
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+            generator = torch.Generator(device=accelerator.device).manual_seed(
+                args.seed
+            )
 
         for i in range(len(args.validation_prompts)):
+            # validation_prompts hold face-image paths here, so condition on the
+            # projected ArcFace embedding instead of passing the path as text.
+            from PIL import Image
+
+            img = np.array(Image.open(args.validation_prompts[i]))[:, :, ::-1]
+            faces = app.get(img)
+            if len(faces) == 0:
+                continue
+            face = sorted(
+                faces,
+                key=lambda x: (x["bbox"][2] - x["bbox"][0])
+                * (x["bbox"][3] - x["bbox"][1]),
+            )[-1]
+            id_emb = torch.tensor(face["embedding"], dtype=torch.bfloat16)[None].to(
+                accelerator.device
+            )
+            id_emb = id_emb / torch.norm(id_emb, dim=1, keepdim=True)
+            id_emb = project_face_embs_inf(pipeline, id_emb)
             with torch.autocast("cuda"):
-                image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+                image = pipeline(
+                    prompt_embeds=id_emb,
+                    num_inference_steps=20,
+                    generator=generator,
+                ).images[0]
             images.append(image)
 
         if args.push_to_hub:
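After training, the saved pipeline can be driven directly by an ArcFace embedding. A minimal inference sketch under stated assumptions: the output directory from the README command above, the `CLIPTextModelWrapper` and `project_face_embs_inf` helpers defined in this script (the module import path is hypothetical), and a `reference_face.jpg` containing a detectable face:

```python
import numpy as np
import torch
from PIL import Image
from insightface.app import FaceAnalysis
from diffusers import StableDiffusionPipeline

# Hypothetical import path; both helpers live in train_text_to_image.py above.
from train_text_to_image import CLIPTextModelWrapper, project_face_embs_inf

# Load the fine-tuned text encoder and pipeline from the training output_dir.
text_encoder = CLIPTextModelWrapper.from_pretrained(
    "arc2face-attempt-1", subfolder="text_encoder", torch_dtype=torch.bfloat16
)
pipe = StableDiffusionPipeline.from_pretrained(
    "arc2face-attempt-1", text_encoder=text_encoder, torch_dtype=torch.bfloat16
).to("cuda")

app = FaceAnalysis(name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

img = np.array(Image.open("reference_face.jpg"))[:, :, ::-1]  # RGB -> BGR for insightface
face = sorted(  # pick the largest detected face
    app.get(img), key=lambda f: (f["bbox"][2] - f["bbox"][0]) * (f["bbox"][3] - f["bbox"][1])
)[-1]

id_emb = torch.tensor(face["embedding"], dtype=torch.bfloat16)[None].to("cuda")
id_emb = id_emb / torch.norm(id_emb, dim=1, keepdim=True)
prompt_embeds = project_face_embs_inf(pipe, id_emb)

image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=50, guidance_scale=3.0).images[0]
image.save("generated.png")
```

The sampler settings mirror `log_validation` above (50 steps, guidance 3.0); since the identity is injected through `prompt_embeds`, no text prompt is passed.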