diff --git a/examples/conditional_image_generation/README.md b/examples/conditional_image_generation/README.md new file mode 100644 index 000000000000..9cd10c35a621 --- /dev/null +++ b/examples/conditional_image_generation/README.md @@ -0,0 +1,103 @@ +## Training examples + +Creating a training image set is [described in a different document](https://huggingface.co/docs/datasets/image_process#image-datasets). + +### Installing the dependencies + +Before running the scipts, make sure to install the library's training dependencies: + +```bash +pip install diffusers[training] accelerate datasets +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +### conditional example + +TODO: prepare examples + +### Using your own data + +To use your own dataset, there are 2 ways: +- you can either provide your own folder as `--train_data_dir` +- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument. + +Below, we explain both in more detail. + +#### Provide the dataset as a folder + +If you provide your own folders with images, the script expects the following directory structure: + +```bash +data_dir/xxx.png +data_dir/xxy.png +data_dir/[...]/xxz.png +``` + +In other words, the script will take care of gathering all images inside the folder. You can then run the script like this: + +```bash +accelerate launch train_conditional.py \ + --train_data_dir \ + +``` + +Internally, the script will use the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature which will automatically turn the folders into 🤗 Dataset objects. + +#### Upload your data to the hub, as a (possibly private) repo + +It's very easy (and convenient) to upload your image dataset to the hub using the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder) feature available in 🤗 Datasets. Simply do the following: + +```python +from datasets import load_dataset + +# example 1: local folder +dataset = load_dataset("imagefolder", data_dir="path_to_your_folder") + +# example 2: local files (suppoted formats are tar, gzip, zip, xz, rar, zstd) +dataset = load_dataset("imagefolder", data_files="path_to_zip_file") + +# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd) +dataset = load_dataset("imagefolder", data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip") + +# example 4: providing several splits +dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]}) +``` + +`ImageFolder` will create an `image` column containing the PIL-encoded images. + +Next, push it to the hub! + +```python +# assuming you have ran the huggingface-cli login command in a terminal +dataset.push_to_hub("name_of_your_dataset") + +# if you want to push to a private repo, simply pass private=True: +dataset.push_to_hub("name_of_your_dataset", private=True) +``` + +and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub. + +More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets). + +#### How to use in the pipeline + +```python +# make sure you're logged in with `huggingface-cli login` +from torch import autocast +from diffusers import StableDiffusionPipeline + +# Replace it to model that you want to use. +unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_auth_token=True) + +pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", unet=unet use_auth_token=True) +pipe = pipe.to("cuda") + +prompt = "a photo of an astronaut riding a horse on mars" +with autocast("cuda"): + image = pipe(prompt)["sample"][0] +``` \ No newline at end of file diff --git a/examples/conditional_image_generation/dataset_example/huggingface.png b/examples/conditional_image_generation/dataset_example/huggingface.png new file mode 100644 index 000000000000..c4f5bbd66df2 Binary files /dev/null and b/examples/conditional_image_generation/dataset_example/huggingface.png differ diff --git a/examples/conditional_image_generation/requirements.txt b/examples/conditional_image_generation/requirements.txt new file mode 100644 index 000000000000..bbc690556020 --- /dev/null +++ b/examples/conditional_image_generation/requirements.txt @@ -0,0 +1,3 @@ +accelerate +torchvision +datasets diff --git a/examples/conditional_image_generation/train_conditional.py b/examples/conditional_image_generation/train_conditional.py new file mode 100644 index 000000000000..eb333c86779f --- /dev/null +++ b/examples/conditional_image_generation/train_conditional.py @@ -0,0 +1,265 @@ +import argparse +import math +import os + +import torch +import torch.nn.functional as F + +from accelerate import Accelerator +from accelerate.logging import get_logger +from datasets import load_dataset +# from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers import DDPMPipeline, DDPMScheduler, UNet2DConditionModel +from diffusers.hub_utils import init_git_repo, push_to_hub +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomHorizontalFlip, + Resize, + ToTensor, +) +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + + +logger = get_logger(__name__) + +def main(args): + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + # FIXME implement training script + model = UNet2DConditionModel( + sample_size=args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + block_out_channels=(128, 128, 256, 256, 512, 512), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt") + optimizer = torch.optim.AdamW( + model.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # it is needed to generate tokenized input to train. + text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + augmentations = Compose( + [ + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(args.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) + + if args.dataset_name is not None: + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + use_auth_token=True if args.use_auth_token else None, + split="train", + ) + else: + dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + + def transforms(examples): + images = [augmentations(image.convert("RGB")) for image in examples["image"]] + return {"input": images} + + dataset.set_transform(transforms) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, + ) + + model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler + ) + + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + + ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay) + + if args.push_to_hub: + repo = init_git_repo(args, at_init=True) + + if accelerator.is_main_process: + run = os.path.split(__file__)[-1].split(".")[0] + accelerator.init_trackers(run) + + global_step = 0 + for epoch in range(args.num_epochs): + model.train() + progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + clean_images = batch["input"] + # Sample noise that we'll add to the images + noise = torch.randn(clean_images.shape).to(clean_images.device) + bsz = clean_images.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device + ).long() + + # FIXME The input should probably select the appropriate one from the dataset. + # Sample a text input + uncond_input = tokenizer( + [""] * args.eval_batch_size, padding="max_length", max_length=77, return_tensors="pt" + ) + uncond_embeddings = text_encoder(uncond_input.input_ids.to(clean_images.device))[0] + hidden_state = uncond_embeddings + + # Add noise to the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) + + with accelerator.accumulate(model): + # Predict the noise residual + # FIXME Implement a successfully trainable model and training script + noise_pred = model(noisy_images, timesteps, encoder_hidden_states=hidden_state)["sample"] + loss = F.mse_loss(noise_pred, noise) + accelerator.backward(loss) + + accelerator.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + if args.use_ema: + ema_model.step(model) + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} + if args.use_ema: + logs["ema_decay"] = ema_model.decay + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + progress_bar.close() + + accelerator.wait_for_everyone() + + # Generate sample images for visual inspection + if accelerator.is_main_process: + if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + pipeline = DDPMPipeline( + unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), + scheduler=noise_scheduler, + ) + + generator = torch.manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"] + + # denormalize the images and save to tensorboard + images_processed = (images * 255).round().astype("uint8") + accelerator.trackers[0].writer.add_images( + "test_samples", images_processed.transpose(0, 3, 1, 2), epoch + ) + + if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: + # save the model + if args.push_to_hub: + push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) + else: + pipeline.save_pretrained(args.output_dir) + accelerator.wait_for_everyone() + + accelerator.end_training() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--dataset_name", type=str, default=None) + parser.add_argument("--dataset_config_name", type=str, default=None) + parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.") + parser.add_argument("--output_dir", type=str, default="ddpm-model-64") + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument("--cache_dir", type=str, default=None) + parser.add_argument("--resolution", type=int, default=64) + parser.add_argument("--train_batch_size", type=int, default=16) + parser.add_argument("--eval_batch_size", type=int, default=16) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--save_images_epochs", type=int, default=10) + parser.add_argument("--save_model_epochs", type=int, default=10) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--lr_scheduler", type=str, default="cosine") + parser.add_argument("--lr_warmup_steps", type=int, default=500) + parser.add_argument("--adam_beta1", type=float, default=0.95) + parser.add_argument("--adam_beta2", type=float, default=0.999) + parser.add_argument("--adam_weight_decay", type=float, default=1e-6) + parser.add_argument("--adam_epsilon", type=float, default=1e-08) + parser.add_argument("--use_ema", action="store_true", default=True) + parser.add_argument("--ema_inv_gamma", type=float, default=1.0) + parser.add_argument("--ema_power", type=float, default=3 / 4) + parser.add_argument("--ema_max_decay", type=float, default=0.9999) + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--use_auth_token", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--hub_private_repo", action="store_true") + parser.add_argument("--logging_dir", type=str, default="logs") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("You must specify either a dataset name from the hub or a train data directory.") + + main(args)