diff --git a/.gitignore b/.gitignore
index cf8183463613..1964b12f941e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,4 +163,5 @@ tags
 *.lock
 
 # DS_Store (MacOS)
-.DS_Store
\ No newline at end of file
+.DS_Store
+*.png
diff --git a/examples/progressive_distillation/colab.py b/examples/progressive_distillation/colab.py
new file mode 100644
index 000000000000..511481142c8c
--- /dev/null
+++ b/examples/progressive_distillation/colab.py
@@ -0,0 +1,93 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class TrainingConfig:
+    image_size = 128  # the generated image resolution
+    train_batch_size = 16
+    eval_batch_size = 16  # how many images to sample during evaluation
+    num_epochs = 50
+    gradient_accumulation_steps = 1
+    learning_rate = 1e-4
+    lr_warmup_steps = 500
+    save_image_epochs = 10
+    save_model_epochs = 30
+    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
+    output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub
+
+    push_to_hub = True  # whether to upload the saved model to the HF Hub
+    hub_private_repo = False
+    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
+    seed = 0
+
+
+config = TrainingConfig()
+
+from datasets import load_dataset
+
+config.dataset_name = "huggan/smithsonian_butterflies_subset"
+dataset = load_dataset(config.dataset_name, split="train")
+from torchvision import transforms
+
+preprocess = transforms.Compose(
+    [
+        transforms.Resize((config.image_size, config.image_size)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+
+def transform(examples):
+    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+    return {"images": images}
+
+
+dataset.set_transform(transform)
+import torch
+import os
+
+from diffusers import UNet2DModel, DistillationPipeline, DDPMPipeline, DDPMScheduler, DDIMPipeline, DDIMScheduler
+from accelerate import Accelerator
+
+
+teacher = UNet2DModel.from_pretrained("bglick13/ddim-butterflies-128-v-diffusion", subfolder="unet")
+
+# accelerator = Accelerator(
+#     mixed_precision=config.mixed_precision,
+#     gradient_accumulation_steps=config.gradient_accumulation_steps,
+#     log_with="tensorboard",
+#     logging_dir=os.path.join(config.output_dir, "logs"),
+# )
+# teacher = accelerator.prepare(teacher)
+distiller = DistillationPipeline()
+n_teacher_trainsteps = 1000
+new_teacher, distilled_ema, distill_accelerator = distiller(
+    teacher,
+    n_teacher_trainsteps,
+    dataset,
+    epochs=100,
+    batch_size=32,
+    mixed_precision="fp16",
+    sample_every=1,
+    gamma=0.0,
+    lr=1e-4,
+)
+new_scheduler = DDIMScheduler(
+    num_train_timesteps=500, beta_schedule="squaredcos_cap_v2", variance_type="v_diffusion", prediction_type="v"
+)
+pipeline = DDIMPipeline(
+    unet=distill_accelerator.unwrap_model(distilled_ema.averaged_model),
+    scheduler=new_scheduler,
+)
+
+# run pipeline in inference (sample random noise and denoise)
+images = pipeline(batch_size=4, output_type="numpy", generator=torch.manual_seed(0)).images
+
+# denormalize the images before saving
+images_processed = (images * 255).round().astype("uint8")
+from PIL import Image
+
+img = Image.fromarray(images_processed[0])
+img.save("denoised.png")
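+
+
+# (Illustrative sketch, not DistillationPipeline's actual API: each progressive-
+# distillation round trains a student to match two teacher DDIM steps with a
+# single step of its own. `ddim_step_v` below is a hypothetical helper showing
+# one deterministic step under the v-parameterization used above.)
+def ddim_step_v(unet, z_t, t, alpha_t, sigma_t, alpha_s, sigma_s):
+    v_pred = unet(z_t, t).sample  # model predicts v = alpha_t * eps - sigma_t * x0
+    x_pred = alpha_t * z_t - sigma_t * v_pred  # recover the clean-image estimate ...
+    eps_pred = sigma_t * z_t + alpha_t * v_pred  # ... and the noise estimate
+    return alpha_s * x_pred + sigma_s * eps_pred  # move to the lower noise level s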
+ "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n", + "WARNING:torch.distributed.elastic.multiprocessing.redirects:NOTE: Redirects are currently not supported in Windows or MacOs.\n" + ] + } + ], + "source": [ + "import torch\n", + "from PIL import Image\n", + "from diffusers import AutoencoderKL, UNet2DModel, DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, DistillationPipeline\n", + "from diffusers.optimization import get_scheduler\n", + "from diffusers.training_utils import EMAModel\n", + "import math\n", + "import requests\n", + "from torchvision.transforms import (\n", + " CenterCrop,\n", + " Compose,\n", + " InterpolationMode,\n", + " Normalize,\n", + " RandomHorizontalFlip,\n", + " Resize,\n", + " ToTensor,\n", + " ToPILImage\n", + ")\n", + "from torch.utils.data import Dataset\n", + "from accelerate import Accelerator\n", + "import utils\n", + "from tqdm import tqdm\n", + "import torch.nn.functional as F\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.manual_seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "training_config = utils.DiffusionTrainingArgs()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Load an image of my dog for this example\n", + "\n", + "image_url = \"https://i.imgur.com/IJcs4Aa.jpeg\"\n", + "image = Image.open(requests.get(image_url, stream=True).raw)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the transforms to apply to the image for training\n", + "augmentations = utils.get_train_transforms(training_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class SingleImageDataset(Dataset):\n", + " def __init__(self, image, batch_size):\n", + " self.image = image\n", + " self.batch_size = batch_size\n", + "\n", + " def __len__(self):\n", + " return self.batch_size\n", + "\n", + " def __getitem__(self, idx):\n", + " return self.image\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "train_image = augmentations(image.convert(\"RGB\"))\n", + "train_dataset = SingleImageDataset(train_image, training_config.batch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2b23b591496741a299b75e4e9448b29a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading: 0%| | 0.00/455M [00:00 {N // 2}\")\n", + " teacher, distilled_ema, distill_accelrator = distiller(teacher, N, train_dataset, epochs=300, batch_size=training_config.batch_size)\n", + " N = N // 2\n", + " new_scheduler = DDPMScheduler(num_train_timesteps=N, beta_schedule=\"squaredcos_cap_v2\")\n", + " pipeline = DDPMPipeline(\n", + " unet=distill_accelrator.unwrap_model(distilled_ema.averaged_model if training_config.use_ema else teacher),\n", + " scheduler=new_scheduler,\n", + " )\n", + "\n", + " # run pipeline in 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Display train image for reference\n",
+    "train_image_display = train_image * 0.5 + 0.5\n",
+    "train_image_display = ToPILImage()(train_image_display)\n",
+    "display(train_image_display)\n",
+    "\n",
+    "for i, image in enumerate(distilled_images):\n",
+    "    print(f\"Distilled image {i}\")\n",
+    "    display(Image.fromarray(image))\n",
+    "    Image.fromarray(image).save(f\"distilled_{i}.png\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "display(Image.fromarray(images_processed[0]))\n",
+    "display(Image.fromarray(images_processed[1]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.10.6 ('diffusers')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.6"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "77f6871a522595648ebba7232d315a2f946cc4cd5f56470cb61e517ec9b94e2e"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/progressive_distillation/train_butterflies.py b/examples/progressive_distillation/train_butterflies.py
new file mode 100644
index 000000000000..aa3fffdd9026
--- /dev/null
+++ b/examples/progressive_distillation/train_butterflies.py
@@ -0,0 +1,239 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class TrainingConfig:
+    image_size = 128  # the generated image resolution
+    train_batch_size = 16
+    eval_batch_size = 16  # how many images to sample during evaluation
+    num_epochs = 50
+    gradient_accumulation_steps = 1
+    learning_rate = 5e-5
+    lr_warmup_steps = 500
+    save_image_epochs = 10
+    save_model_epochs = 30
+    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
+    output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub
+
+    push_to_hub = False  # whether to upload the saved model to the HF Hub
+    hub_private_repo = False
+    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
+    seed = 0
+
+
+config = TrainingConfig()
+from datasets import load_dataset
+
+config.dataset_name = "huggan/smithsonian_butterflies_subset"
+dataset = load_dataset(config.dataset_name, split="train")
+
+import matplotlib.pyplot as plt
+
+from torchvision import transforms
+
+preprocess = transforms.Compose(
+    [
+        transforms.Resize((config.image_size, config.image_size)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+
+def transform(examples):
+    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+    return {"images": images}
+
+
+dataset.set_transform(transform)
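+
+# (Illustrative sanity check, not in the original script: after `set_transform`,
+# slicing the dataset yields 3x128x128 tensors normalized to roughly [-1, 1].)
+example_images = dataset[:2]["images"]
+print(len(example_images), example_images[0].shape)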
examples["image"]] + return {"images": images} + + +dataset.set_transform(transform) + +import torch + +train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True) + +from diffusers import UNet2DModel + + +model = UNet2DModel( + sample_size=config.image_size, # the target image resolution + in_channels=3, # the number of input channels, 3 for RGB images + out_channels=3, # the number of output channels + layers_per_block=2, # how many ResNet layers to use per UNet block + block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channes for each UNet block + down_block_types=( + "DownBlock2D", # a regular ResNet downsampling block + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", # a regular ResNet upsampling block + "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), +) + +from diffusers import DDPMScheduler + +noise_scheduler = DDPMScheduler( + num_train_timesteps=1000, + beta_schedule="squaredcos_cap_v2", + variance_type="v_diffusion", +) + +import torch +import torch.nn.functional as F + +from PIL import Image + +optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) + + +from diffusers.optimization import get_cosine_schedule_with_warmup + +lr_scheduler = get_cosine_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=config.lr_warmup_steps, + num_training_steps=(len(train_dataloader) * config.num_epochs), +) + +from diffusers import DDPMPipeline + +import math + + +def make_grid(images, rows, cols): + w, h = images[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + for i, image in enumerate(images): + grid.paste(image, box=(i % cols * w, i // cols * h)) + return grid + + +def evaluate(config, epoch, pipeline): + # Sample some images from random noise (this is the backward diffusion process). + # The default pipeline output type is `List[PIL.Image]` + images = pipeline( + batch_size=config.eval_batch_size, + generator=torch.manual_seed(config.seed), + ).images + + # Make a grid out of the images + image_grid = make_grid(images, rows=4, cols=4) + + # Save the images + test_dir = os.path.join(config.output_dir, "samples") + os.makedirs(test_dir, exist_ok=True) + image_grid.save(f"{test_dir}/{epoch:04d}.png") + + +from accelerate import Accelerator +from diffusers.hub_utils import init_git_repo, push_to_hub + +from tqdm.auto import tqdm +import os + + +def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler): + # Initialize accelerator and tensorboard logging + accelerator = Accelerator( + mixed_precision=config.mixed_precision, + gradient_accumulation_steps=config.gradient_accumulation_steps, + log_with="tensorboard", + logging_dir=os.path.join(config.output_dir, "logs"), + ) + if accelerator.is_main_process: + if config.push_to_hub: + repo = init_git_repo(config, at_init=True) + accelerator.init_trackers("train_example") + + # Prepare everything + # There is no specific order to remember, you just need to unpack the + # objects in the same order you gave them to the prepare method. 
+
+import torch
+import torch.nn.functional as F
+
+from PIL import Image
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+
+
+from diffusers.optimization import get_cosine_schedule_with_warmup
+
+lr_scheduler = get_cosine_schedule_with_warmup(
+    optimizer=optimizer,
+    num_warmup_steps=config.lr_warmup_steps,
+    num_training_steps=(len(train_dataloader) * config.num_epochs),
+)
+
+from diffusers import DDPMPipeline
+
+import math
+
+
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols * w, i // cols * h))
+    return grid
+
+
+def evaluate(config, epoch, pipeline):
+    # Sample some images from random noise (this is the backward diffusion process).
+    # The default pipeline output type is `List[PIL.Image]`
+    images = pipeline(
+        batch_size=config.eval_batch_size,
+        generator=torch.manual_seed(config.seed),
+    ).images
+
+    # Make a grid out of the images
+    image_grid = make_grid(images, rows=4, cols=4)
+
+    # Save the images
+    test_dir = os.path.join(config.output_dir, "samples")
+    os.makedirs(test_dir, exist_ok=True)
+    image_grid.save(f"{test_dir}/{epoch:04d}.png")
+
+
+from accelerate import Accelerator
+from diffusers.hub_utils import init_git_repo, push_to_hub
+
+from tqdm.auto import tqdm
+import os
+
+
+def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
+    # Initialize accelerator and tensorboard logging
+    accelerator = Accelerator(
+        mixed_precision=config.mixed_precision,
+        gradient_accumulation_steps=config.gradient_accumulation_steps,
+        log_with="tensorboard",
+        logging_dir=os.path.join(config.output_dir, "logs"),
+    )
+    if accelerator.is_main_process:
+        if config.push_to_hub:
+            repo = init_git_repo(config, at_init=True)
+        accelerator.init_trackers("train_example")
+
+    # Prepare everything
+    # There is no specific order to remember, you just need to unpack the
+    # objects in the same order you gave them to the prepare method.
+    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, lr_scheduler
+    )
+
+    global_step = 0
+
+    # Now you train the model
+    for epoch in range(config.num_epochs):
+        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar.set_description(f"Epoch {epoch}")
+
+        for step, batch in enumerate(train_dataloader):
+            clean_images = batch["images"]
+            # Sample noise to add to the images
+            noise = torch.randn(clean_images.shape).to(clean_images.device)
+            bs = clean_images.shape[0]
+
+            # Sample a random timestep for each image
+            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()
+
+            with accelerator.accumulate(model):
+                # Predict v from the noisy sample z_t (v-diffusion objective)
+                alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
+                z_t = alpha_t * clean_images + sigma_t * noise
+                noise_pred = model(z_t, timesteps).sample
+                # v-target: alpha_t * eps - sigma_t * x0
+                v = alpha_t * noise - sigma_t * clean_images
+                loss = F.mse_loss(noise_pred, v)
+                accelerator.backward(loss)
+
+                accelerator.clip_grad_norm_(model.parameters(), 1.0)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            progress_bar.update(1)
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+            global_step += 1
+
+        # After each epoch you optionally sample some demo images with evaluate() and save the model
+        if accelerator.is_main_process:
+            pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
+
+            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
+                evaluate(config, epoch, pipeline)
+
+            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
+                if config.push_to_hub:
+                    push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
+                else:
+                    pipeline.save_pretrained(config.output_dir)
+
+
+"""## Let's train!
+
+Let's launch the training (including multi-GPU training) from the notebook using Accelerate's `notebook_launcher` function:
+"""
+
+from accelerate import notebook_launcher
+
+args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
+
+train_loop(*args)
+
+"""Let's have a look at the final image grid produced by the trained diffusion model:"""
+
+import glob
+
+sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
+Image.open(sample_images[-1])
+
+"""Not bad! There's room for improvement of course, so feel free to play with the hyperparameters, model definition and image augmentations 🤗
+
+If you've chosen to upload the model to the Hugging Face Hub, its repository should now look like so:
+https://huggingface.co/anton-l/ddpm-butterflies-128
+
+If you want to dive deeper into the code, we also have more advanced training scripts with features like Exponential Moving Average of model weights here:
+
+https://github.com/huggingface/diffusers/tree/main/examples
+"""
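+
+# (Illustrative sketch, assuming the EMA behaviour of `EMAModel` used by the
+# teacher notebook below: a shadow copy of the weights is blended toward the
+# live weights each step, and sampling uses the shadow copy. `ema_update` is a
+# hypothetical helper, not a diffusers API.)
+def ema_update(shadow_params, params, decay):
+    for s, p in zip(shadow_params, params):
+        s.mul_(decay).add_(p.detach(), alpha=1 - decay)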
diff --git a/examples/progressive_distillation/train_teacher_model.ipynb b/examples/progressive_distillation/train_teacher_model.ipynb
new file mode 100644
index 000000000000..ec1a47d25945
--- /dev/null
+++ b/examples/progressive_distillation/train_teacher_model.ipynb
@@ -0,0 +1,835 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
+     ]
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "from PIL import Image\n",
+    "from diffusers import AutoencoderKL, UNet2DModel, DDPMScheduler, DDPMPipeline\n",
+    "from diffusers.optimization import get_scheduler\n",
+    "from diffusers.training_utils import EMAModel\n",
+    "import math\n",
+    "import requests\n",
+    "from torchvision.transforms import (\n",
+    "    CenterCrop,\n",
+    "    Compose,\n",
+    "    InterpolationMode,\n",
+    "    Normalize,\n",
+    "    RandomHorizontalFlip,\n",
+    "    Resize,\n",
+    "    ToTensor,\n",
+    ")\n",
+    "from accelerate import Accelerator\n",
+    "import utils\n",
+    "from tqdm import tqdm\n",
+    "import torch.nn.functional as F"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training_config = utils.DiffusionTrainingArgs()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load an image of my dog for this example\n",
+    "\n",
+    "image_url = \"https://i.imgur.com/IJcs4Aa.jpeg\"\n",
+    "image = Image.open(requests.get(image_url, stream=True).raw)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define the transforms to apply to the image for training\n",
+    "augmentations = utils.get_train_transforms(training_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_image = augmentations(image.convert(\"RGB\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "accelerator = Accelerator(\n",
+    "    gradient_accumulation_steps=training_config.gradient_accumulation_steps,\n",
+    "    mixed_precision=training_config.mixed_precision,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model = utils.get_unet(training_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "noise_scheduler = DDPMScheduler(num_train_timesteps=1000)\n",
+    "optimizer = torch.optim.AdamW(\n",
+    "    model.parameters(),\n",
+    "    lr=training_config.learning_rate,\n",
+    "    betas=(training_config.adam_beta1, training_config.adam_beta2),\n",
+    "    weight_decay=training_config.adam_weight_decay,\n",
+    "    eps=training_config.adam_epsilon,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
"outputs": [], + "source": [ + "lr_scheduler = get_scheduler(\n", + " training_config.lr_scheduler,\n", + " optimizer=optimizer,\n", + " num_warmup_steps=training_config.lr_warmup_steps,\n", + " num_training_steps=(training_config.num_epochs) // training_config.gradient_accumulation_steps,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "model, optimizer, train_image, lr_scheduler = accelerator.prepare(\n", + " model, optimizer, train_image, lr_scheduler\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "num_update_steps_per_epoch = math.ceil(training_config.batch_size / training_config.gradient_accumulation_steps)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "ema_model = EMAModel(model, inv_gamma=training_config.ema_inv_gamma, power=training_config.ema_power, max_value=training_config.ema_max_decay)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "if accelerator.is_main_process:\n", + " run = \"train.py\"\n", + " accelerator.init_trackers(run)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 0: 6%|โ–‹ | 1/16 [00:00<00:06, 2.16it/s, ema_decay=0, loss=1.09, lr=2e-7, step=1]\n", + "Epoch 1: 6%|โ–‹ | 1/16 [00:00<00:03, 4.52it/s, ema_decay=0, loss=1.09, lr=4e-7, step=2]\n", + "Epoch 2: 6%|โ–‹ | 1/16 [00:00<00:02, 5.01it/s, ema_decay=0.405, loss=1.09, lr=6e-7, step=3]\n", + "Epoch 3: 6%|โ–‹ | 1/16 [00:00<00:02, 5.04it/s, ema_decay=0.561, loss=1.09, lr=8e-7, step=4]\n", + "Epoch 4: 6%|โ–‹ | 1/16 [00:00<00:02, 5.16it/s, ema_decay=0.646, loss=1.08, lr=1e-6, step=5]\n", + "Epoch 5: 6%|โ–‹ | 1/16 [00:00<00:03, 4.83it/s, ema_decay=0.701, loss=1.08, lr=1.2e-6, step=6]\n", + "Epoch 6: 6%|โ–‹ | 1/16 [00:00<00:02, 5.37it/s, ema_decay=0.739, loss=1.08, lr=1.4e-6, step=7]\n", + "Epoch 7: 6%|โ–‹ | 1/16 [00:00<00:02, 5.53it/s, ema_decay=0.768, loss=1.07, lr=1.6e-6, step=8]\n", + "Epoch 8: 6%|โ–‹ | 1/16 [00:00<00:02, 5.51it/s, ema_decay=0.79, loss=1.08, lr=1.8e-6, step=9]\n", + "Epoch 9: 6%|โ–‹ | 1/16 [00:00<00:03, 4.77it/s, ema_decay=0.808, loss=1.06, lr=2e-6, step=10]\n", + "Epoch 10: 6%|โ–‹ | 1/16 [00:00<00:02, 5.09it/s, ema_decay=0.822, loss=1.06, lr=2.2e-6, step=11]\n", + "Epoch 11: 6%|โ–‹ | 1/16 [00:00<00:03, 4.37it/s, ema_decay=0.834, loss=1.05, lr=2.4e-6, step=12]\n", + "Epoch 12: 6%|โ–‹ | 1/16 [00:00<00:02, 5.33it/s, ema_decay=0.845, loss=1.06, lr=2.6e-6, step=13]\n", + "Epoch 13: 6%|โ–‹ | 1/16 [00:00<00:02, 5.29it/s, ema_decay=0.854, loss=1.05, lr=2.8e-6, step=14]\n", + "Epoch 14: 6%|โ–‹ | 1/16 [00:00<00:02, 5.20it/s, ema_decay=0.862, loss=1.04, lr=3e-6, step=15]\n", + "Epoch 15: 6%|โ–‹ | 1/16 [00:00<00:03, 4.99it/s, ema_decay=0.869, loss=1.03, lr=3.2e-6, step=16]\n", + "Epoch 16: 6%|โ–‹ | 1/16 [00:00<00:02, 5.03it/s, ema_decay=0.875, loss=1.02, lr=3.4e-6, step=17]\n", + "Epoch 17: 6%|โ–‹ | 1/16 [00:00<00:03, 4.72it/s, ema_decay=0.881, loss=1.01, lr=3.6e-6, step=18]\n", + "Epoch 18: 6%|โ–‹ | 1/16 [00:00<00:02, 5.14it/s, ema_decay=0.886, loss=1, lr=3.8e-6, step=19]\n", + "Epoch 19: 6%|โ–‹ | 1/16 [00:00<00:02, 5.45it/s, ema_decay=0.89, loss=0.995, lr=4e-6, step=20]\n", + "Epoch 20: 6%|โ–‹ | 1/16 [00:00<00:02, 5.10it/s, ema_decay=0.894, loss=0.981, lr=4.2e-6, step=21]\n", + "Epoch 21: 6%|โ–‹ | 1/16 
[00:00<00:02, 5.24it/s, ema_decay=0.898, loss=0.965, lr=4.4e-6, step=22]\n", + "Epoch 22: 6%|โ–‹ | 1/16 [00:00<00:02, 5.16it/s, ema_decay=0.902, loss=0.966, lr=4.6e-6, step=23]\n", + "Epoch 23: 6%|โ–‹ | 1/16 [00:00<00:02, 5.18it/s, ema_decay=0.905, loss=0.944, lr=4.8e-6, step=24]\n", + "Epoch 24: 6%|โ–‹ | 1/16 [00:00<00:02, 5.05it/s, ema_decay=0.908, loss=0.944, lr=5e-6, step=25]\n", + "Epoch 25: 6%|โ–‹ | 1/16 [00:00<00:03, 4.95it/s, ema_decay=0.911, loss=0.936, lr=5.2e-6, step=26]\n", + "Epoch 26: 6%|โ–‹ | 1/16 [00:00<00:02, 5.05it/s, ema_decay=0.913, loss=0.905, lr=5.4e-6, step=27]\n", + "Epoch 27: 6%|โ–‹ | 1/16 [00:00<00:02, 5.30it/s, ema_decay=0.916, loss=0.89, lr=5.6e-6, step=28]\n", + "Epoch 28: 6%|โ–‹ | 1/16 [00:00<00:02, 5.33it/s, ema_decay=0.918, loss=0.882, lr=5.8e-6, step=29]\n", + "Epoch 29: 6%|โ–‹ | 1/16 [00:00<00:02, 5.29it/s, ema_decay=0.92, loss=0.869, lr=6e-6, step=30]\n", + "Epoch 30: 6%|โ–‹ | 1/16 [00:00<00:02, 5.07it/s, ema_decay=0.922, loss=0.858, lr=6.2e-6, step=31]\n", + "Epoch 31: 6%|โ–‹ | 1/16 [00:00<00:02, 5.37it/s, ema_decay=0.924, loss=0.861, lr=6.4e-6, step=32]\n", + "Epoch 32: 6%|โ–‹ | 1/16 [00:00<00:02, 5.45it/s, ema_decay=0.926, loss=0.856, lr=6.6e-6, step=33]\n", + "Epoch 33: 6%|โ–‹ | 1/16 [00:00<00:02, 5.46it/s, ema_decay=0.927, loss=0.835, lr=6.8e-6, step=34]\n", + "Epoch 34: 6%|โ–‹ | 1/16 [00:00<00:03, 4.51it/s, ema_decay=0.929, loss=0.812, lr=7e-6, step=35]\n", + "Epoch 35: 6%|โ–‹ | 1/16 [00:00<00:02, 5.46it/s, ema_decay=0.931, loss=0.785, lr=7.2e-6, step=36]\n", + "Epoch 36: 6%|โ–‹ | 1/16 [00:00<00:02, 5.05it/s, ema_decay=0.932, loss=0.778, lr=7.4e-6, step=37]\n", + "Epoch 37: 6%|โ–‹ | 1/16 [00:00<00:02, 5.01it/s, ema_decay=0.933, loss=0.793, lr=7.6e-6, step=38]\n", + "Epoch 38: 6%|โ–‹ | 1/16 [00:00<00:02, 5.09it/s, ema_decay=0.935, loss=0.737, lr=7.8e-6, step=39]\n", + "Epoch 39: 6%|โ–‹ | 1/16 [00:00<00:02, 5.17it/s, ema_decay=0.936, loss=0.714, lr=8e-6, step=40]\n", + "Epoch 40: 6%|โ–‹ | 1/16 [00:00<00:02, 5.02it/s, ema_decay=0.937, loss=0.728, lr=8.2e-6, step=41]\n", + "Epoch 41: 6%|โ–‹ | 1/16 [00:00<00:02, 5.05it/s, ema_decay=0.938, loss=0.689, lr=8.4e-6, step=42]\n", + "Epoch 42: 6%|โ–‹ | 1/16 [00:00<00:03, 4.48it/s, ema_decay=0.939, loss=0.699, lr=8.6e-6, step=43]\n", + "Epoch 43: 6%|โ–‹ | 1/16 [00:00<00:03, 4.96it/s, ema_decay=0.94, loss=0.663, lr=8.8e-6, step=44]\n", + "Epoch 44: 6%|โ–‹ | 1/16 [00:00<00:03, 4.96it/s, ema_decay=0.941, loss=0.661, lr=9e-6, step=45]\n", + "Epoch 45: 6%|โ–‹ | 1/16 [00:00<00:02, 5.19it/s, ema_decay=0.942, loss=0.632, lr=9.2e-6, step=46]\n", + "Epoch 46: 6%|โ–‹ | 1/16 [00:00<00:02, 5.23it/s, ema_decay=0.943, loss=0.597, lr=9.4e-6, step=47]\n", + "Epoch 47: 6%|โ–‹ | 1/16 [00:00<00:02, 5.45it/s, ema_decay=0.944, loss=0.613, lr=9.6e-6, step=48]\n", + "Epoch 48: 6%|โ–‹ | 1/16 [00:00<00:02, 5.46it/s, ema_decay=0.945, loss=0.597, lr=9.8e-6, step=49]\n", + "Epoch 49: 6%|โ–‹ | 1/16 [00:00<00:03, 4.68it/s, ema_decay=0.946, loss=0.612, lr=1e-5, step=50]\n", + "Epoch 50: 6%|โ–‹ | 1/16 [00:00<00:03, 4.64it/s, ema_decay=0.947, loss=0.601, lr=1.02e-5, step=51]\n", + "Epoch 51: 6%|โ–‹ | 1/16 [00:00<00:03, 4.50it/s, ema_decay=0.948, loss=0.532, lr=1.04e-5, step=52]\n", + "Epoch 52: 6%|โ–‹ | 1/16 [00:00<00:02, 5.28it/s, ema_decay=0.948, loss=0.495, lr=1.06e-5, step=53]\n", + "Epoch 53: 6%|โ–‹ | 1/16 [00:00<00:02, 5.45it/s, ema_decay=0.949, loss=0.516, lr=1.08e-5, step=54]\n", + "Epoch 54: 6%|โ–‹ | 1/16 [00:00<00:02, 5.38it/s, ema_decay=0.95, loss=0.493, lr=1.1e-5, step=55]\n", + "Epoch 55: 6%|โ–‹ | 1/16 [00:00<00:03, 4.44it/s, 
ema_decay=0.95, loss=0.461, lr=1.12e-5, step=56]\n", + "Epoch 56: 6%|โ–‹ | 1/16 [00:00<00:03, 4.91it/s, ema_decay=0.951, loss=0.462, lr=1.14e-5, step=57]\n", + "Epoch 57: 6%|โ–‹ | 1/16 [00:00<00:02, 5.22it/s, ema_decay=0.952, loss=0.521, lr=1.16e-5, step=58]\n", + "Epoch 58: 6%|โ–‹ | 1/16 [00:00<00:02, 5.08it/s, ema_decay=0.952, loss=0.431, lr=1.18e-5, step=59]\n", + "Epoch 59: 6%|โ–‹ | 1/16 [00:00<00:02, 5.31it/s, ema_decay=0.953, loss=0.418, lr=1.2e-5, step=60]\n", + "Epoch 60: 6%|โ–‹ | 1/16 [00:00<00:03, 4.59it/s, ema_decay=0.954, loss=0.41, lr=1.22e-5, step=61]\n", + "Epoch 61: 6%|โ–‹ | 1/16 [00:00<00:03, 4.72it/s, ema_decay=0.954, loss=0.376, lr=1.24e-5, step=62]\n", + "Epoch 62: 6%|โ–‹ | 1/16 [00:00<00:03, 3.85it/s, ema_decay=0.955, loss=0.362, lr=1.26e-5, step=63]\n", + "Epoch 63: 6%|โ–‹ | 1/16 [00:00<00:03, 4.20it/s, ema_decay=0.955, loss=0.352, lr=1.28e-5, step=64]\n", + "Epoch 64: 6%|โ–‹ | 1/16 [00:00<00:02, 5.28it/s, ema_decay=0.956, loss=0.348, lr=1.3e-5, step=65]\n", + "Epoch 65: 6%|โ–‹ | 1/16 [00:00<00:03, 4.45it/s, ema_decay=0.956, loss=0.327, lr=1.32e-5, step=66]\n", + "Epoch 66: 6%|โ–‹ | 1/16 [00:00<00:02, 5.05it/s, ema_decay=0.957, loss=0.372, lr=1.34e-5, step=67]\n", + "Epoch 67: 6%|โ–‹ | 1/16 [00:00<00:02, 5.17it/s, ema_decay=0.957, loss=0.299, lr=1.36e-5, step=68]\n", + "Epoch 68: 6%|โ–‹ | 1/16 [00:00<00:02, 5.20it/s, ema_decay=0.958, loss=0.306, lr=1.38e-5, step=69]\n", + "Epoch 69: 6%|โ–‹ | 1/16 [00:00<00:02, 5.19it/s, ema_decay=0.958, loss=0.373, lr=1.4e-5, step=70]\n", + "Epoch 70: 6%|โ–‹ | 1/16 [00:00<00:02, 5.36it/s, ema_decay=0.959, loss=0.34, lr=1.42e-5, step=71]\n", + "Epoch 71: 6%|โ–‹ | 1/16 [00:00<00:03, 4.27it/s, ema_decay=0.959, loss=0.239, lr=1.44e-5, step=72]\n", + "Epoch 72: 6%|โ–‹ | 1/16 [00:00<00:03, 4.34it/s, ema_decay=0.96, loss=0.295, lr=1.46e-5, step=73]\n", + "Epoch 73: 6%|โ–‹ | 1/16 [00:00<00:03, 4.82it/s, ema_decay=0.96, loss=0.225, lr=1.48e-5, step=74]\n", + "Epoch 74: 6%|โ–‹ | 1/16 [00:00<00:03, 4.96it/s, ema_decay=0.96, loss=0.295, lr=1.5e-5, step=75]\n", + "Epoch 75: 6%|โ–‹ | 1/16 [00:00<00:04, 3.02it/s, ema_decay=0.961, loss=0.203, lr=1.52e-5, step=76]\n", + "Epoch 76: 6%|โ–‹ | 1/16 [00:00<00:03, 4.44it/s, ema_decay=0.961, loss=0.203, lr=1.54e-5, step=77]\n", + "Epoch 77: 6%|โ–‹ | 1/16 [00:00<00:02, 5.10it/s, ema_decay=0.962, loss=0.204, lr=1.56e-5, step=78]\n", + "Epoch 78: 6%|โ–‹ | 1/16 [00:00<00:03, 4.92it/s, ema_decay=0.962, loss=0.176, lr=1.58e-5, step=79]\n", + "Epoch 79: 6%|โ–‹ | 1/16 [00:00<00:02, 5.04it/s, ema_decay=0.962, loss=0.197, lr=1.6e-5, step=80]\n", + "Epoch 80: 6%|โ–‹ | 1/16 [00:00<00:03, 4.66it/s, ema_decay=0.963, loss=0.2, lr=1.62e-5, step=81]\n", + "Epoch 81: 6%|โ–‹ | 1/16 [00:00<00:02, 5.09it/s, ema_decay=0.963, loss=0.284, lr=1.64e-5, step=82]\n", + "Epoch 82: 6%|โ–‹ | 1/16 [00:00<00:02, 5.33it/s, ema_decay=0.963, loss=0.211, lr=1.66e-5, step=83]\n", + "Epoch 83: 6%|โ–‹ | 1/16 [00:00<00:02, 5.14it/s, ema_decay=0.964, loss=0.133, lr=1.68e-5, step=84]\n", + "Epoch 84: 6%|โ–‹ | 1/16 [00:00<00:02, 5.07it/s, ema_decay=0.964, loss=0.151, lr=1.7e-5, step=85]\n", + "Epoch 85: 6%|โ–‹ | 1/16 [00:00<00:02, 5.13it/s, ema_decay=0.964, loss=0.172, lr=1.72e-5, step=86]\n", + "Epoch 86: 6%|โ–‹ | 1/16 [00:00<00:03, 4.73it/s, ema_decay=0.965, loss=0.18, lr=1.74e-5, step=87]\n", + "Epoch 87: 6%|โ–‹ | 1/16 [00:00<00:03, 4.83it/s, ema_decay=0.965, loss=0.116, lr=1.76e-5, step=88]\n", + "Epoch 88: 6%|โ–‹ | 1/16 [00:00<00:03, 4.89it/s, ema_decay=0.965, loss=0.124, lr=1.78e-5, step=89]\n", + "Epoch 89: 6%|โ–‹ | 1/16 [00:00<00:02, 5.06it/s, 
ema_decay=0.965, loss=0.137, lr=1.8e-5, step=90]\n", + "Epoch 90: 6%|โ–‹ | 1/16 [00:00<00:02, 5.31it/s, ema_decay=0.966, loss=0.189, lr=1.82e-5, step=91]\n", + "Epoch 91: 6%|โ–‹ | 1/16 [00:00<00:02, 5.15it/s, ema_decay=0.966, loss=0.131, lr=1.84e-5, step=92]\n", + "Epoch 92: 6%|โ–‹ | 1/16 [00:00<00:03, 4.89it/s, ema_decay=0.966, loss=0.184, lr=1.86e-5, step=93]\n", + "Epoch 93: 6%|โ–‹ | 1/16 [00:00<00:02, 5.08it/s, ema_decay=0.967, loss=0.157, lr=1.88e-5, step=94]\n", + "Epoch 94: 6%|โ–‹ | 1/16 [00:00<00:03, 4.25it/s, ema_decay=0.967, loss=0.128, lr=1.9e-5, step=95]\n", + "Epoch 95: 6%|โ–‹ | 1/16 [00:00<00:02, 5.12it/s, ema_decay=0.967, loss=0.12, lr=1.92e-5, step=96]\n", + "Epoch 96: 6%|โ–‹ | 1/16 [00:00<00:02, 5.01it/s, ema_decay=0.967, loss=0.127, lr=1.94e-5, step=97]\n", + "Epoch 97: 6%|โ–‹ | 1/16 [00:00<00:02, 5.30it/s, ema_decay=0.968, loss=0.189, lr=1.96e-5, step=98]\n", + "Epoch 98: 6%|โ–‹ | 1/16 [00:00<00:02, 5.07it/s, ema_decay=0.968, loss=0.125, lr=1.98e-5, step=99]\n", + "Epoch 99: 6%|โ–‹ | 1/16 [00:00<00:03, 4.55it/s, ema_decay=0.968, loss=0.162, lr=2e-5, step=100]\n", + "Epoch 100: 6%|โ–‹ | 1/16 [00:00<00:02, 5.14it/s, ema_decay=0.968, loss=0.0917, lr=2.02e-5, step=101]\n", + "Epoch 101: 6%|โ–‹ | 1/16 [00:00<00:03, 4.89it/s, ema_decay=0.969, loss=0.117, lr=2.04e-5, step=102]\n", + "Epoch 102: 6%|โ–‹ | 1/16 [00:00<00:02, 5.14it/s, ema_decay=0.969, loss=0.122, lr=2.06e-5, step=103]\n", + "Epoch 103: 6%|โ–‹ | 1/16 [00:00<00:03, 4.97it/s, ema_decay=0.969, loss=0.109, lr=2.08e-5, step=104]\n", + "Epoch 104: 6%|โ–‹ | 1/16 [00:00<00:03, 4.68it/s, ema_decay=0.969, loss=0.0759, lr=2.1e-5, step=105]\n", + "Epoch 105: 6%|โ–‹ | 1/16 [00:00<00:03, 4.39it/s, ema_decay=0.97, loss=0.086, lr=2.12e-5, step=106]\n", + "Epoch 106: 6%|โ–‹ | 1/16 [00:00<00:03, 4.57it/s, ema_decay=0.97, loss=0.0864, lr=2.14e-5, step=107]\n", + "Epoch 107: 6%|โ–‹ | 1/16 [00:00<00:02, 5.03it/s, ema_decay=0.97, loss=0.107, lr=2.16e-5, step=108]\n", + "Epoch 108: 6%|โ–‹ | 1/16 [00:00<00:02, 5.50it/s, ema_decay=0.97, loss=0.158, lr=2.18e-5, step=109]\n", + "Epoch 109: 6%|โ–‹ | 1/16 [00:00<00:02, 5.52it/s, ema_decay=0.97, loss=0.101, lr=2.2e-5, step=110]\n", + "Epoch 110: 6%|โ–‹ | 1/16 [00:00<00:02, 5.00it/s, ema_decay=0.971, loss=0.0762, lr=2.22e-5, step=111]\n", + "Epoch 111: 6%|โ–‹ | 1/16 [00:00<00:02, 5.03it/s, ema_decay=0.971, loss=0.0671, lr=2.24e-5, step=112]\n", + "Epoch 112: 6%|โ–‹ | 1/16 [00:00<00:02, 5.31it/s, ema_decay=0.971, loss=0.0693, lr=2.26e-5, step=113]\n", + "Epoch 113: 6%|โ–‹ | 1/16 [00:00<00:03, 4.95it/s, ema_decay=0.971, loss=0.139, lr=2.28e-5, step=114]\n", + "Epoch 114: 6%|โ–‹ | 1/16 [00:00<00:03, 4.55it/s, ema_decay=0.971, loss=0.111, lr=2.3e-5, step=115]\n", + "Epoch 115: 6%|โ–‹ | 1/16 [00:00<00:02, 5.25it/s, ema_decay=0.972, loss=0.0885, lr=2.32e-5, step=116]\n", + "Epoch 116: 6%|โ–‹ | 1/16 [00:00<00:03, 4.61it/s, ema_decay=0.972, loss=0.115, lr=2.34e-5, step=117]\n", + "Epoch 117: 6%|โ–‹ | 1/16 [00:00<00:02, 5.02it/s, ema_decay=0.972, loss=0.0874, lr=2.36e-5, step=118]\n", + "Epoch 118: 6%|โ–‹ | 1/16 [00:00<00:03, 4.87it/s, ema_decay=0.972, loss=0.151, lr=2.38e-5, step=119]\n", + "Epoch 119: 6%|โ–‹ | 1/16 [00:00<00:02, 5.25it/s, ema_decay=0.972, loss=0.0648, lr=2.4e-5, step=120]\n", + "Epoch 120: 6%|โ–‹ | 1/16 [00:00<00:03, 4.97it/s, ema_decay=0.972, loss=0.104, lr=2.42e-5, step=121]\n", + "Epoch 121: 6%|โ–‹ | 1/16 [00:00<00:02, 5.26it/s, ema_decay=0.973, loss=0.0634, lr=2.44e-5, step=122]\n", + "Epoch 122: 6%|โ–‹ | 1/16 [00:00<00:03, 4.95it/s, ema_decay=0.973, loss=0.061, lr=2.46e-5, 
step=123]\n", + "Epoch 123: 6%|โ–‹ | 1/16 [00:00<00:02, 5.13it/s, ema_decay=0.973, loss=0.099, lr=2.48e-5, step=124]\n", + "Epoch 124: 6%|โ–‹ | 1/16 [00:00<00:03, 4.30it/s, ema_decay=0.973, loss=0.0714, lr=2.5e-5, step=125]\n", + "Epoch 125: 6%|โ–‹ | 1/16 [00:00<00:02, 5.17it/s, ema_decay=0.973, loss=0.0665, lr=2.52e-5, step=126]\n", + "Epoch 126: 6%|โ–‹ | 1/16 [00:00<00:03, 5.00it/s, ema_decay=0.973, loss=0.0831, lr=2.54e-5, step=127]\n", + "Epoch 127: 6%|โ–‹ | 1/16 [00:00<00:02, 5.24it/s, ema_decay=0.974, loss=0.0801, lr=2.56e-5, step=128]\n", + "Epoch 128: 6%|โ–‹ | 1/16 [00:00<00:02, 5.17it/s, ema_decay=0.974, loss=0.0688, lr=2.58e-5, step=129]\n", + "Epoch 129: 6%|โ–‹ | 1/16 [00:00<00:02, 5.32it/s, ema_decay=0.974, loss=0.0724, lr=2.6e-5, step=130]\n", + "Epoch 130: 6%|โ–‹ | 1/16 [00:00<00:02, 5.14it/s, ema_decay=0.974, loss=0.0863, lr=2.62e-5, step=131]\n", + "Epoch 131: 6%|โ–‹ | 1/16 [00:00<00:02, 5.38it/s, ema_decay=0.974, loss=0.06, lr=2.64e-5, step=132]\n", + "Epoch 132: 6%|โ–‹ | 1/16 [00:00<00:03, 4.87it/s, ema_decay=0.974, loss=0.0578, lr=2.66e-5, step=133]\n", + "Epoch 133: 6%|โ–‹ | 1/16 [00:00<00:02, 5.40it/s, ema_decay=0.974, loss=0.0901, lr=2.68e-5, step=134]\n", + "Epoch 134: 6%|โ–‹ | 1/16 [00:00<00:02, 5.40it/s, ema_decay=0.975, loss=0.0557, lr=2.7e-5, step=135]\n", + "Epoch 135: 6%|โ–‹ | 1/16 [00:00<00:02, 5.14it/s, ema_decay=0.975, loss=0.0742, lr=2.72e-5, step=136]\n", + "Epoch 136: 6%|โ–‹ | 1/16 [00:00<00:03, 5.00it/s, ema_decay=0.975, loss=0.0627, lr=2.74e-5, step=137]\n", + "Epoch 137: 6%|โ–‹ | 1/16 [00:00<00:03, 4.99it/s, ema_decay=0.975, loss=0.0622, lr=2.76e-5, step=138]\n", + "Epoch 138: 6%|โ–‹ | 1/16 [00:00<00:03, 4.76it/s, ema_decay=0.975, loss=0.111, lr=2.78e-5, step=139]\n", + "Epoch 139: 6%|โ–‹ | 1/16 [00:00<00:02, 5.34it/s, ema_decay=0.975, loss=0.0709, lr=2.8e-5, step=140]\n", + "Epoch 140: 6%|โ–‹ | 1/16 [00:00<00:02, 5.37it/s, ema_decay=0.975, loss=0.057, lr=2.82e-5, step=141]\n", + "Epoch 141: 6%|โ–‹ | 1/16 [00:00<00:02, 5.12it/s, ema_decay=0.976, loss=0.0926, lr=2.84e-5, step=142]\n", + "Epoch 142: 6%|โ–‹ | 1/16 [00:00<00:02, 5.22it/s, ema_decay=0.976, loss=0.0583, lr=2.86e-5, step=143]\n", + "Epoch 143: 6%|โ–‹ | 1/16 [00:00<00:02, 5.21it/s, ema_decay=0.976, loss=0.059, lr=2.88e-5, step=144]\n", + "Epoch 144: 6%|โ–‹ | 1/16 [00:00<00:02, 5.05it/s, ema_decay=0.976, loss=0.0514, lr=2.9e-5, step=145]\n", + "Epoch 145: 6%|โ–‹ | 1/16 [00:00<00:02, 5.31it/s, ema_decay=0.976, loss=0.0561, lr=2.92e-5, step=146]\n", + "Epoch 146: 6%|โ–‹ | 1/16 [00:00<00:02, 5.04it/s, ema_decay=0.976, loss=0.071, lr=2.94e-5, step=147]\n", + "Epoch 147: 6%|โ–‹ | 1/16 [00:00<00:02, 5.29it/s, ema_decay=0.976, loss=0.0786, lr=2.96e-5, step=148]\n", + "Epoch 148: 6%|โ–‹ | 1/16 [00:00<00:03, 4.25it/s, ema_decay=0.976, loss=0.0524, lr=2.98e-5, step=149]\n", + "Epoch 149: 6%|โ–‹ | 1/16 [00:00<00:03, 4.41it/s, ema_decay=0.977, loss=0.115, lr=3e-5, step=150]\n", + "Epoch 150: 6%|โ–‹ | 1/16 [00:00<00:02, 5.33it/s, ema_decay=0.977, loss=0.0731, lr=3.02e-5, step=151]\n", + "Epoch 151: 6%|โ–‹ | 1/16 [00:00<00:03, 4.93it/s, ema_decay=0.977, loss=0.0455, lr=3.04e-5, step=152]\n", + "Epoch 152: 6%|โ–‹ | 1/16 [00:00<00:02, 5.42it/s, ema_decay=0.977, loss=0.0566, lr=3.06e-5, step=153]\n", + "Epoch 153: 6%|โ–‹ | 1/16 [00:00<00:02, 5.10it/s, ema_decay=0.977, loss=0.051, lr=3.08e-5, step=154]\n", + "Epoch 154: 6%|โ–‹ | 1/16 [00:00<00:02, 5.30it/s, ema_decay=0.977, loss=0.0746, lr=3.1e-5, step=155]\n", + "Epoch 155: 6%|โ–‹ | 1/16 [00:00<00:02, 5.06it/s, ema_decay=0.977, loss=0.0725, lr=3.12e-5, 
step=156]\n", + "Epoch 156: 6%|โ–‹ | 1/16 [00:00<00:02, 5.35it/s, ema_decay=0.977, loss=0.0661, lr=3.14e-5, step=157]\n", + "Epoch 157: 6%|โ–‹ | 1/16 [00:00<00:02, 5.14it/s, ema_decay=0.977, loss=0.0697, lr=3.16e-5, step=158]\n", + "Epoch 158: 6%|โ–‹ | 1/16 [00:00<00:02, 5.10it/s, ema_decay=0.978, loss=0.0526, lr=3.18e-5, step=159]\n", + "Epoch 159: 6%|โ–‹ | 1/16 [00:00<00:02, 5.01it/s, ema_decay=0.978, loss=0.0796, lr=3.2e-5, step=160]\n", + "Epoch 160: 6%|โ–‹ | 1/16 [00:00<00:03, 4.88it/s, ema_decay=0.978, loss=0.0802, lr=3.22e-5, step=161]\n", + "Epoch 161: 6%|โ–‹ | 1/16 [00:00<00:02, 5.09it/s, ema_decay=0.978, loss=0.113, lr=3.24e-5, step=162]\n", + "Epoch 162: 6%|โ–‹ | 1/16 [00:00<00:02, 5.28it/s, ema_decay=0.978, loss=0.0446, lr=3.26e-5, step=163]\n", + "Epoch 163: 6%|โ–‹ | 1/16 [00:00<00:02, 5.20it/s, ema_decay=0.978, loss=0.0558, lr=3.28e-5, step=164]\n", + "Epoch 164: 6%|โ–‹ | 1/16 [00:00<00:02, 5.34it/s, ema_decay=0.978, loss=0.0407, lr=3.3e-5, step=165]\n", + "Epoch 165: 6%|โ–‹ | 1/16 [00:00<00:03, 4.27it/s, ema_decay=0.978, loss=0.0429, lr=3.32e-5, step=166]\n", + "Epoch 166: 6%|โ–‹ | 1/16 [00:00<00:03, 4.24it/s, ema_decay=0.978, loss=0.0578, lr=3.34e-5, step=167]\n", + "Epoch 167: 6%|โ–‹ | 1/16 [00:00<00:02, 5.35it/s, ema_decay=0.978, loss=0.0424, lr=3.36e-5, step=168]\n", + "Epoch 168: 6%|โ–‹ | 1/16 [00:00<00:02, 5.35it/s, ema_decay=0.979, loss=0.0382, lr=3.38e-5, step=169]\n", + "Epoch 169: 6%|โ–‹ | 1/16 [00:00<00:02, 5.02it/s, ema_decay=0.979, loss=0.113, lr=3.4e-5, step=170]\n", + "Epoch 170: 6%|โ–‹ | 1/16 [00:00<00:02, 5.23it/s, ema_decay=0.979, loss=0.0452, lr=3.42e-5, step=171]\n", + "Epoch 171: 6%|โ–‹ | 1/16 [00:00<00:02, 5.28it/s, ema_decay=0.979, loss=0.0406, lr=3.44e-5, step=172]\n", + "Epoch 172: 6%|โ–‹ | 1/16 [00:00<00:03, 4.92it/s, ema_decay=0.979, loss=0.0613, lr=3.46e-5, step=173]\n", + "Epoch 173: 6%|โ–‹ | 1/16 [00:00<00:03, 4.78it/s, ema_decay=0.979, loss=0.0516, lr=3.48e-5, step=174]\n", + "Epoch 174: 6%|โ–‹ | 1/16 [00:00<00:03, 4.99it/s, ema_decay=0.979, loss=0.0604, lr=3.5e-5, step=175]\n", + "Epoch 175: 6%|โ–‹ | 1/16 [00:00<00:03, 4.92it/s, ema_decay=0.979, loss=0.0924, lr=3.52e-5, step=176]\n", + "Epoch 176: 6%|โ–‹ | 1/16 [00:00<00:03, 4.37it/s, ema_decay=0.979, loss=0.087, lr=3.54e-5, step=177]\n", + "Epoch 177: 6%|โ–‹ | 1/16 [00:00<00:03, 4.40it/s, ema_decay=0.979, loss=0.0669, lr=3.56e-5, step=178]\n", + "Epoch 178: 6%|โ–‹ | 1/16 [00:00<00:03, 4.60it/s, ema_decay=0.979, loss=0.0556, lr=3.58e-5, step=179]\n", + "Epoch 179: 6%|โ–‹ | 1/16 [00:00<00:03, 4.50it/s, ema_decay=0.98, loss=0.0723, lr=3.6e-5, step=180]\n", + "Epoch 180: 6%|โ–‹ | 1/16 [00:00<00:03, 4.50it/s, ema_decay=0.98, loss=0.0453, lr=3.62e-5, step=181]\n", + "Epoch 181: 6%|โ–‹ | 1/16 [00:00<00:02, 5.25it/s, ema_decay=0.98, loss=0.0455, lr=3.64e-5, step=182]\n", + "Epoch 182: 6%|โ–‹ | 1/16 [00:00<00:03, 4.47it/s, ema_decay=0.98, loss=0.0365, lr=3.66e-5, step=183]\n", + "Epoch 183: 6%|โ–‹ | 1/16 [00:00<00:02, 5.24it/s, ema_decay=0.98, loss=0.0486, lr=3.68e-5, step=184]\n", + "Epoch 184: 6%|โ–‹ | 1/16 [00:00<00:02, 5.30it/s, ema_decay=0.98, loss=0.0556, lr=3.7e-5, step=185]\n", + "Epoch 185: 6%|โ–‹ | 1/16 [00:00<00:02, 5.36it/s, ema_decay=0.98, loss=0.0817, lr=3.72e-5, step=186]\n", + "Epoch 186: 6%|โ–‹ | 1/16 [00:00<00:03, 4.85it/s, ema_decay=0.98, loss=0.0511, lr=3.74e-5, step=187]\n", + "Epoch 187: 6%|โ–‹ | 1/16 [00:00<00:02, 5.40it/s, ema_decay=0.98, loss=0.0849, lr=3.76e-5, step=188]\n", + "Epoch 188: 6%|โ–‹ | 1/16 [00:00<00:03, 4.50it/s, ema_decay=0.98, loss=0.0575, lr=3.78e-5, 
step=189]\n", + "Epoch 189: 6%|โ–‹ | 1/16 [00:00<00:02, 5.36it/s, ema_decay=0.98, loss=0.0804, lr=3.8e-5, step=190]\n", + "Epoch 190: 6%|โ–‹ | 1/16 [00:00<00:02, 5.02it/s, ema_decay=0.98, loss=0.0405, lr=3.82e-5, step=191]\n", + "Epoch 191: 6%|โ–‹ | 1/16 [00:00<00:02, 5.39it/s, ema_decay=0.981, loss=0.0437, lr=3.84e-5, step=192]\n", + "Epoch 192: 6%|โ–‹ | 1/16 [00:00<00:03, 4.53it/s, ema_decay=0.981, loss=0.0564, lr=3.86e-5, step=193]\n", + "Epoch 193: 6%|โ–‹ | 1/16 [00:00<00:03, 4.62it/s, ema_decay=0.981, loss=0.0744, lr=3.88e-5, step=194]\n", + "Epoch 194: 6%|โ–‹ | 1/16 [00:00<00:02, 5.32it/s, ema_decay=0.981, loss=0.0554, lr=3.9e-5, step=195]\n", + "Epoch 195: 6%|โ–‹ | 1/16 [00:00<00:03, 4.45it/s, ema_decay=0.981, loss=0.0421, lr=3.92e-5, step=196]\n", + "Epoch 196: 6%|โ–‹ | 1/16 [00:00<00:02, 5.29it/s, ema_decay=0.981, loss=0.0355, lr=3.94e-5, step=197]\n", + "Epoch 197: 6%|โ–‹ | 1/16 [00:00<00:03, 5.00it/s, ema_decay=0.981, loss=0.0396, lr=3.96e-5, step=198]\n", + "Epoch 198: 6%|โ–‹ | 1/16 [00:00<00:02, 5.40it/s, ema_decay=0.981, loss=0.0349, lr=3.98e-5, step=199]\n", + "Epoch 199: 6%|โ–‹ | 1/16 [00:00<00:03, 4.60it/s, ema_decay=0.981, loss=0.0342, lr=4e-5, step=200]\n", + "Epoch 200: 6%|โ–‹ | 1/16 [00:00<00:02, 5.19it/s, ema_decay=0.981, loss=0.0403, lr=4.02e-5, step=201]\n", + "Epoch 201: 6%|โ–‹ | 1/16 [00:00<00:02, 5.30it/s, ema_decay=0.981, loss=0.0589, lr=4.04e-5, step=202]\n", + "Epoch 202: 6%|โ–‹ | 1/16 [00:00<00:02, 5.17it/s, ema_decay=0.981, loss=0.0379, lr=4.06e-5, step=203]\n", + "Epoch 203: 6%|โ–‹ | 1/16 [00:00<00:03, 4.50it/s, ema_decay=0.981, loss=0.0777, lr=4.08e-5, step=204]\n", + "Epoch 204: 6%|โ–‹ | 1/16 [00:00<00:02, 5.12it/s, ema_decay=0.981, loss=0.051, lr=4.1e-5, step=205]\n", + "Epoch 205: 6%|โ–‹ | 1/16 [00:00<00:02, 5.24it/s, ema_decay=0.982, loss=0.0575, lr=4.12e-5, step=206]\n", + "Epoch 206: 6%|โ–‹ | 1/16 [00:00<00:02, 5.04it/s, ema_decay=0.982, loss=0.0911, lr=4.14e-5, step=207]\n", + "Epoch 207: 6%|โ–‹ | 1/16 [00:00<00:02, 5.19it/s, ema_decay=0.982, loss=0.0485, lr=4.16e-5, step=208]\n", + "Epoch 208: 6%|โ–‹ | 1/16 [00:00<00:03, 4.62it/s, ema_decay=0.982, loss=0.053, lr=4.18e-5, step=209]\n", + "Epoch 209: 6%|โ–‹ | 1/16 [00:00<00:02, 5.16it/s, ema_decay=0.982, loss=0.0463, lr=4.2e-5, step=210]\n", + "Epoch 210: 6%|โ–‹ | 1/16 [00:00<00:03, 4.41it/s, ema_decay=0.982, loss=0.0362, lr=4.22e-5, step=211]\n", + "Epoch 211: 6%|โ–‹ | 1/16 [00:00<00:03, 4.57it/s, ema_decay=0.982, loss=0.0275, lr=4.24e-5, step=212]\n", + "Epoch 212: 6%|โ–‹ | 1/16 [00:00<00:02, 5.06it/s, ema_decay=0.982, loss=0.0321, lr=4.26e-5, step=213]\n", + "Epoch 213: 6%|โ–‹ | 1/16 [00:00<00:02, 5.43it/s, ema_decay=0.982, loss=0.0679, lr=4.28e-5, step=214]\n", + "Epoch 214: 6%|โ–‹ | 1/16 [00:00<00:03, 4.37it/s, ema_decay=0.982, loss=0.0298, lr=4.3e-5, step=215]\n", + "Epoch 215: 6%|โ–‹ | 1/16 [00:00<00:03, 4.75it/s, ema_decay=0.982, loss=0.0524, lr=4.32e-5, step=216]\n", + "Epoch 216: 6%|โ–‹ | 1/16 [00:00<00:02, 5.20it/s, ema_decay=0.982, loss=0.0373, lr=4.34e-5, step=217]\n", + "Epoch 217: 6%|โ–‹ | 1/16 [00:00<00:02, 5.21it/s, ema_decay=0.982, loss=0.0427, lr=4.36e-5, step=218]\n", + "Epoch 218: 6%|โ–‹ | 1/16 [00:00<00:02, 5.32it/s, ema_decay=0.982, loss=0.0529, lr=4.38e-5, step=219]\n", + "Epoch 219: 6%|โ–‹ | 1/16 [00:00<00:03, 4.84it/s, ema_decay=0.982, loss=0.0412, lr=4.4e-5, step=220]\n", + "Epoch 220: 6%|โ–‹ | 1/16 [00:00<00:02, 5.31it/s, ema_decay=0.982, loss=0.0447, lr=4.42e-5, step=221]\n", + "Epoch 221: 6%|โ–‹ | 1/16 [00:00<00:02, 5.34it/s, ema_decay=0.983, loss=0.0393, lr=4.44e-5, 
step=222]\n", + "Epoch 222: 6%|โ–‹ | 1/16 [00:00<00:02, 5.34it/s, ema_decay=0.983, loss=0.0429, lr=4.46e-5, step=223]\n", + "Epoch 223: 6%|โ–‹ | 1/16 [00:00<00:03, 4.83it/s, ema_decay=0.983, loss=0.0298, lr=4.48e-5, step=224]\n", + "Epoch 224: 6%|โ–‹ | 1/16 [00:00<00:02, 5.28it/s, ema_decay=0.983, loss=0.0266, lr=4.5e-5, step=225]\n", + "Epoch 225: 6%|โ–‹ | 1/16 [00:00<00:02, 5.35it/s, ema_decay=0.983, loss=0.0324, lr=4.52e-5, step=226]\n", + "Epoch 226: 6%|โ–‹ | 1/16 [00:00<00:02, 5.02it/s, ema_decay=0.983, loss=0.0555, lr=4.54e-5, step=227]\n", + "Epoch 227: 6%|โ–‹ | 1/16 [00:00<00:03, 4.44it/s, ema_decay=0.983, loss=0.0532, lr=4.56e-5, step=228]\n", + "Epoch 228: 6%|โ–‹ | 1/16 [00:00<00:02, 5.20it/s, ema_decay=0.983, loss=0.0436, lr=4.58e-5, step=229]\n", + "Epoch 229: 6%|โ–‹ | 1/16 [00:00<00:03, 4.42it/s, ema_decay=0.983, loss=0.0832, lr=4.6e-5, step=230]\n", + "Epoch 230: 6%|โ–‹ | 1/16 [00:00<00:03, 4.60it/s, ema_decay=0.983, loss=0.0621, lr=4.62e-5, step=231]\n", + "Epoch 231: 6%|โ–‹ | 1/16 [00:00<00:02, 5.15it/s, ema_decay=0.983, loss=0.0317, lr=4.64e-5, step=232]\n", + "Epoch 232: 6%|โ–‹ | 1/16 [00:00<00:02, 5.03it/s, ema_decay=0.983, loss=0.0362, lr=4.66e-5, step=233]\n", + "Epoch 233: 6%|โ–‹ | 1/16 [00:00<00:02, 5.49it/s, ema_decay=0.983, loss=0.0332, lr=4.68e-5, step=234]\n", + "Epoch 234: 6%|โ–‹ | 1/16 [00:00<00:02, 5.50it/s, ema_decay=0.983, loss=0.0459, lr=4.7e-5, step=235]\n", + "Epoch 235: 6%|โ–‹ | 1/16 [00:00<00:02, 5.21it/s, ema_decay=0.983, loss=0.0318, lr=4.72e-5, step=236]\n", + "Epoch 236: 6%|โ–‹ | 1/16 [00:00<00:02, 5.52it/s, ema_decay=0.983, loss=0.0502, lr=4.74e-5, step=237]\n", + "Epoch 237: 6%|โ–‹ | 1/16 [00:00<00:02, 5.20it/s, ema_decay=0.983, loss=0.0797, lr=4.76e-5, step=238]\n", + "Epoch 238: 6%|โ–‹ | 1/16 [00:00<00:02, 5.13it/s, ema_decay=0.983, loss=0.0414, lr=4.78e-5, step=239]\n", + "Epoch 239: 6%|โ–‹ | 1/16 [00:00<00:02, 5.25it/s, ema_decay=0.984, loss=0.0456, lr=4.8e-5, step=240]\n", + "Epoch 240: 6%|โ–‹ | 1/16 [00:00<00:02, 5.37it/s, ema_decay=0.984, loss=0.0711, lr=4.82e-5, step=241]\n", + "Epoch 241: 6%|โ–‹ | 1/16 [00:00<00:02, 5.35it/s, ema_decay=0.984, loss=0.0717, lr=4.84e-5, step=242]\n", + "Epoch 242: 6%|โ–‹ | 1/16 [00:00<00:02, 5.31it/s, ema_decay=0.984, loss=0.0476, lr=4.86e-5, step=243]\n", + "Epoch 243: 6%|โ–‹ | 1/16 [00:00<00:02, 5.46it/s, ema_decay=0.984, loss=0.0361, lr=4.88e-5, step=244]\n", + "Epoch 244: 6%|โ–‹ | 1/16 [00:00<00:02, 5.39it/s, ema_decay=0.984, loss=0.0406, lr=4.9e-5, step=245]\n", + "Epoch 245: 6%|โ–‹ | 1/16 [00:00<00:02, 5.29it/s, ema_decay=0.984, loss=0.0471, lr=4.92e-5, step=246]\n", + "Epoch 246: 6%|โ–‹ | 1/16 [00:00<00:02, 5.42it/s, ema_decay=0.984, loss=0.038, lr=4.94e-5, step=247]\n", + "Epoch 247: 6%|โ–‹ | 1/16 [00:00<00:02, 5.16it/s, ema_decay=0.984, loss=0.0326, lr=4.96e-5, step=248]\n", + "Epoch 248: 6%|โ–‹ | 1/16 [00:00<00:02, 5.39it/s, ema_decay=0.984, loss=0.0465, lr=4.98e-5, step=249]\n", + "Epoch 249: 6%|โ–‹ | 1/16 [00:00<00:02, 5.43it/s, ema_decay=0.984, loss=0.0313, lr=5e-5, step=250]\n", + "Epoch 250: 6%|โ–‹ | 1/16 [00:00<00:02, 5.25it/s, ema_decay=0.984, loss=0.0336, lr=5.02e-5, step=251]\n", + "Epoch 251: 6%|โ–‹ | 1/16 [00:00<00:02, 5.35it/s, ema_decay=0.984, loss=0.0318, lr=5.04e-5, step=252]\n", + "Epoch 252: 6%|โ–‹ | 1/16 [00:00<00:02, 5.30it/s, ema_decay=0.984, loss=0.0603, lr=5.06e-5, step=253]\n", + "Epoch 253: 6%|โ–‹ | 1/16 [00:00<00:02, 5.21it/s, ema_decay=0.984, loss=0.0307, lr=5.08e-5, step=254]\n", + "Epoch 254: 6%|โ–‹ | 1/16 [00:00<00:02, 5.19it/s, ema_decay=0.984, loss=0.0588, lr=5.1e-5, 
step=255]\n", + "Epoch 255: 6%|โ–‹ | 1/16 [00:00<00:02, 5.44it/s, ema_decay=0.984, loss=0.0265, lr=5.12e-5, step=256]\n", + "Epoch 256: 6%|โ–‹ | 1/16 [00:00<00:02, 5.18it/s, ema_decay=0.984, loss=0.064, lr=5.14e-5, step=257]\n", + "Epoch 257: 6%|โ–‹ | 1/16 [00:00<00:03, 4.64it/s, ema_decay=0.984, loss=0.0204, lr=5.16e-5, step=258]\n", + "Epoch 258: 6%|โ–‹ | 1/16 [00:00<00:02, 5.46it/s, ema_decay=0.984, loss=0.0218, lr=5.18e-5, step=259]\n", + "Epoch 259: 6%|โ–‹ | 1/16 [00:00<00:02, 5.24it/s, ema_decay=0.985, loss=0.0659, lr=5.2e-5, step=260]\n", + "Epoch 260: 6%|โ–‹ | 1/16 [00:00<00:04, 3.67it/s, ema_decay=0.985, loss=0.0257, lr=5.22e-5, step=261]\n", + "Epoch 261: 6%|โ–‹ | 1/16 [00:00<00:02, 5.31it/s, ema_decay=0.985, loss=0.0423, lr=5.24e-5, step=262]\n", + "Epoch 262: 6%|โ–‹ | 1/16 [00:00<00:02, 5.32it/s, ema_decay=0.985, loss=0.0379, lr=5.26e-5, step=263]\n", + "Epoch 263: 6%|โ–‹ | 1/16 [00:00<00:02, 5.34it/s, ema_decay=0.985, loss=0.0305, lr=5.28e-5, step=264]\n", + "Epoch 264: 6%|โ–‹ | 1/16 [00:00<00:02, 5.46it/s, ema_decay=0.985, loss=0.11, lr=5.3e-5, step=265]\n", + "Epoch 265: 6%|โ–‹ | 1/16 [00:00<00:02, 5.35it/s, ema_decay=0.985, loss=0.0569, lr=5.32e-5, step=266]\n", + "Epoch 266: 6%|โ–‹ | 1/16 [00:00<00:02, 5.28it/s, ema_decay=0.985, loss=0.0296, lr=5.34e-5, step=267]\n", + "Epoch 267: 6%|โ–‹ | 1/16 [00:00<00:02, 5.23it/s, ema_decay=0.985, loss=0.0328, lr=5.36e-5, step=268]\n", + "Epoch 268: 6%|โ–‹ | 1/16 [00:00<00:02, 5.44it/s, ema_decay=0.985, loss=0.091, lr=5.38e-5, step=269]\n", + "Epoch 269: 6%|โ–‹ | 1/16 [00:00<00:02, 5.49it/s, ema_decay=0.985, loss=0.0364, lr=5.4e-5, step=270]\n", + "Epoch 270: 6%|โ–‹ | 1/16 [00:00<00:03, 4.44it/s, ema_decay=0.985, loss=0.0322, lr=5.42e-5, step=271]\n", + "Epoch 271: 6%|โ–‹ | 1/16 [00:00<00:02, 5.33it/s, ema_decay=0.985, loss=0.0336, lr=5.44e-5, step=272]\n", + "Epoch 272: 6%|โ–‹ | 1/16 [00:00<00:02, 5.28it/s, ema_decay=0.985, loss=0.0281, lr=5.46e-5, step=273]\n", + "Epoch 273: 6%|โ–‹ | 1/16 [00:00<00:03, 4.45it/s, ema_decay=0.985, loss=0.039, lr=5.48e-5, step=274]\n", + "Epoch 274: 6%|โ–‹ | 1/16 [00:00<00:03, 4.95it/s, ema_decay=0.985, loss=0.0823, lr=5.5e-5, step=275]\n", + "Epoch 275: 6%|โ–‹ | 1/16 [00:00<00:02, 5.35it/s, ema_decay=0.985, loss=0.0316, lr=5.52e-5, step=276]\n", + "Epoch 276: 6%|โ–‹ | 1/16 [00:00<00:03, 4.58it/s, ema_decay=0.985, loss=0.0237, lr=5.54e-5, step=277]\n", + "Epoch 277: 6%|โ–‹ | 1/16 [00:00<00:03, 4.36it/s, ema_decay=0.985, loss=0.0213, lr=5.56e-5, step=278]\n", + "Epoch 278: 6%|โ–‹ | 1/16 [00:00<00:03, 4.06it/s, ema_decay=0.985, loss=0.0293, lr=5.58e-5, step=279]\n", + "Epoch 279: 6%|โ–‹ | 1/16 [00:00<00:03, 4.62it/s, ema_decay=0.985, loss=0.0354, lr=5.6e-5, step=280]\n", + "Epoch 280: 6%|โ–‹ | 1/16 [00:00<00:03, 4.37it/s, ema_decay=0.985, loss=0.0734, lr=5.62e-5, step=281]\n", + "Epoch 281: 6%|โ–‹ | 1/16 [00:00<00:05, 2.66it/s, ema_decay=0.985, loss=0.0315, lr=5.64e-5, step=282]\n", + "Epoch 282: 6%|โ–‹ | 1/16 [00:00<00:03, 3.94it/s, ema_decay=0.985, loss=0.0352, lr=5.66e-5, step=283]\n", + "Epoch 283: 6%|โ–‹ | 1/16 [00:00<00:03, 4.92it/s, ema_decay=0.986, loss=0.0697, lr=5.68e-5, step=284]\n", + "Epoch 284: 6%|โ–‹ | 1/16 [00:00<00:02, 5.36it/s, ema_decay=0.986, loss=0.0479, lr=5.7e-5, step=285]\n", + "Epoch 285: 6%|โ–‹ | 1/16 [00:00<00:03, 4.80it/s, ema_decay=0.986, loss=0.0246, lr=5.72e-5, step=286]\n", + "Epoch 286: 6%|โ–‹ | 1/16 [00:00<00:02, 5.38it/s, ema_decay=0.986, loss=0.0248, lr=5.74e-5, step=287]\n", + "Epoch 287: 6%|โ–‹ | 1/16 [00:00<00:02, 5.42it/s, ema_decay=0.986, loss=0.0469, lr=5.76e-5, 
step=288]\n",
+        "Epoch 288:   6%|▋         | 1/16 [00:00<00:02,  5.33it/s, ema_decay=0.986, loss=0.0251, lr=5.78e-5, step=289]\n",
+        "[... per-epoch progress lines for epochs 289-485 omitted: throughput holds around 5 it/s while the loss drifts from roughly 0.03 down to 0.01 and ema_decay climbs from 0.986 to 0.99 ...]\n",
+        "Epoch 486:   6%|▋         | 1/16 
[00:00<00:02, 5.28it/s, ema_decay=0.99, loss=0.0104, lr=9.74e-5, step=487]\n", + "Epoch 487: 6%|โ–‹ | 1/16 [00:00<00:02, 5.27it/s, ema_decay=0.99, loss=0.00854, lr=9.76e-5, step=488]\n", + "Epoch 488: 6%|โ–‹ | 1/16 [00:00<00:03, 4.60it/s, ema_decay=0.99, loss=0.0157, lr=9.78e-5, step=489]\n", + "Epoch 489: 6%|โ–‹ | 1/16 [00:00<00:02, 5.19it/s, ema_decay=0.99, loss=0.00803, lr=9.8e-5, step=490]\n", + "Epoch 490: 6%|โ–‹ | 1/16 [00:00<00:02, 5.01it/s, ema_decay=0.99, loss=0.0204, lr=9.82e-5, step=491]\n", + "Epoch 491: 6%|โ–‹ | 1/16 [00:00<00:02, 5.19it/s, ema_decay=0.99, loss=0.0265, lr=9.84e-5, step=492]\n", + "Epoch 492: 6%|โ–‹ | 1/16 [00:00<00:02, 5.28it/s, ema_decay=0.99, loss=0.0207, lr=9.86e-5, step=493]\n", + "Epoch 493: 6%|โ–‹ | 1/16 [00:00<00:02, 5.34it/s, ema_decay=0.99, loss=0.00809, lr=9.88e-5, step=494]\n", + "Epoch 494: 6%|โ–‹ | 1/16 [00:00<00:02, 5.42it/s, ema_decay=0.99, loss=0.0113, lr=9.9e-5, step=495]\n", + "Epoch 495: 6%|โ–‹ | 1/16 [00:00<00:02, 5.45it/s, ema_decay=0.99, loss=0.034, lr=9.92e-5, step=496]\n", + "Epoch 496: 6%|โ–‹ | 1/16 [00:00<00:03, 4.62it/s, ema_decay=0.99, loss=0.00972, lr=9.94e-5, step=497]\n", + "Epoch 497: 6%|โ–‹ | 1/16 [00:00<00:03, 4.83it/s, ema_decay=0.99, loss=0.0225, lr=9.96e-5, step=498]\n", + "Epoch 498: 6%|โ–‹ | 1/16 [00:00<00:03, 4.87it/s, ema_decay=0.991, loss=0.0155, lr=9.98e-5, step=499]\n", + "Epoch 499: 6%|โ–‹ | 1/16 [00:00<00:03, 4.74it/s, ema_decay=0.991, loss=0.0567, lr=0.0001, step=500]\n" + ] + } + ], + "source": [ + "global_step = 0\n", + "for epoch in range(training_config.num_epochs):\n", + " model.train()\n", + " progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)\n", + " progress_bar.set_description(f\"Epoch {epoch}\")\n", + " batch = train_image.unsqueeze(0).repeat(\n", + " 16, 1, 1, 1\n", + " ).to(accelerator.device)\n", + " noise = torch.randn(batch.shape).to(accelerator.device)\n", + " bsz = batch.shape[0]\n", + " # Sample a random timestep for each image\n", + " timesteps = torch.randint(\n", + " 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=batch.device\n", + " ).long()\n", + " noisy_images = noise_scheduler.add_noise(batch, noise, timesteps)\n", + " with accelerator.accumulate(model):\n", + " # Predict the noise residual\n", + " noise_pred = model(noisy_images, timesteps).sample\n", + " loss = F.mse_loss(noise_pred, noise)\n", + " accelerator.backward(loss)\n", + "\n", + " if accelerator.sync_gradients:\n", + " accelerator.clip_grad_norm_(model.parameters(), 1.0)\n", + " optimizer.step()\n", + " lr_scheduler.step()\n", + " if training_config.use_ema:\n", + " ema_model.step(model)\n", + " optimizer.zero_grad()\n", + "\n", + " # Checks if the accelerator has performed an optimization step behind the scenes\n", + " if accelerator.sync_gradients:\n", + " progress_bar.update(1)\n", + " global_step += 1\n", + "\n", + " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n", + " if training_config.use_ema:\n", + " logs[\"ema_decay\"] = ema_model.decay\n", + " progress_bar.set_postfix(**logs)\n", + " accelerator.log(logs, step=global_step)\n", + " progress_bar.close()\n", + "\n", + " accelerator.wait_for_everyone()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0be92e11858f43a3984716ac1e9de667", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + 
}, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(Image.fromarray(images_processed[0]))" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "json.dump(model.config, open(\"teacher_config.json\", \"w\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(model.state_dict(), \"minnie-diffusion/diffusion_pytorch_model.bin\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.10 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/progressive_distillation/utils.py b/examples/progressive_distillation/utils.py new file mode 100644 index 000000000000..d8cc92356f25 --- /dev/null +++ b/examples/progressive_distillation/utils.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass + +from diffusers import UNet2DModel +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomHorizontalFlip, + Resize, + ToTensor, +) + + +@dataclass +class DiffusionTrainingArgs: + resolution: int = 64 + mixed_precision: str = "fp16" + gradient_accumulation_steps: int = 1 + learning_rate: float = 1e-4 + lr_scheduler: str = "cosine" + lr_warmup_steps: int = 500 + adam_beta1: float = 0.95 + adam_beta2: float = 0.999 + adam_weight_decay: float = 1e-6 + adam_epsilon: float = 1e-08 + use_ema: bool = True + ema_inv_gamma: float = 1.0 + ema_power: float = 3 / 4 + ema_max_decay: float = 0.9999 + batch_size: int = 64 + num_epochs: int = 500 + + +def get_train_transforms(training_config): + # Get standard image transforms + return Compose( + [ + Resize(training_config.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(training_config.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) + + +def get_unet(training_config): + # Initialize a generic UNet model to use in our example + return UNet2DModel( + sample_size=training_config.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + block_out_channels=(128, 128, 256, 256, 512, 512), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) diff --git a/examples/v_prediction/train_butterflies.py b/examples/v_prediction/train_butterflies.py new file mode 100644 index 000000000000..5074ece86a98 --- /dev/null +++ b/examples/v_prediction/train_butterflies.py @@ -0,0 +1,227 @@ +import glob +import os +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from accelerate import Accelerator +from datasets import load_dataset +from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel +from diffusers.hub_utils import init_git_repo, 
push_to_hub
+from diffusers.optimization import get_cosine_schedule_with_warmup
+from PIL import Image
+from torchvision import transforms
+from tqdm.auto import tqdm
+
+
+@dataclass
+class TrainingConfig:
+    image_size = 128  # the generated image resolution
+    train_batch_size = 16
+    eval_batch_size = 16  # how many images to sample during evaluation
+    num_epochs = 50
+    gradient_accumulation_steps = 1
+    learning_rate = 5e-5
+    lr_warmup_steps = 500
+    save_image_epochs = 10
+    save_model_epochs = 30
+    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
+    output_dir = "ddim-butterflies-128-v-diffusion"  # the model name locally and on the HF Hub
+
+    push_to_hub = False  # whether to upload the saved model to the HF Hub
+    hub_private_repo = False
+    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
+    seed = 0
+
+
+config = TrainingConfig()
+
+
+config.dataset_name = "huggan/smithsonian_butterflies_subset"
+dataset = load_dataset(config.dataset_name, split="train")
+
+
+preprocess = transforms.Compose(
+    [
+        transforms.Resize((config.image_size, config.image_size)),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+
+
+def transform(examples):
+    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
+    return {"images": images}
+
+
+dataset.set_transform(transform)
+
+
+train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
+
+
+model = UNet2DModel(
+    sample_size=config.image_size,  # the target image resolution
+    in_channels=3,  # the number of input channels, 3 for RGB images
+    out_channels=3,  # the number of output channels
+    layers_per_block=2,  # how many ResNet layers to use per UNet block
+    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
+    down_block_types=(
+        "DownBlock2D",  # a regular ResNet downsampling block
+        "DownBlock2D",
+        "DownBlock2D",
+        "DownBlock2D",
+        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
+        "DownBlock2D",
+    ),
+    up_block_types=(
+        "UpBlock2D",  # a regular ResNet upsampling block
+        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
+        "UpBlock2D",
+        "UpBlock2D",
+        "UpBlock2D",
+        "UpBlock2D",
+    ),
+)
+
+
+if config.output_dir.startswith("ddpm"):
+    noise_scheduler = DDPMScheduler(
+        num_train_timesteps=1000,
+        beta_schedule="squaredcos_cap_v2",
+        variance_type="v_diffusion",
+        prediction_type="v",
+    )
+else:
+    noise_scheduler = DDIMScheduler(
+        num_train_timesteps=1000,
+        beta_schedule="squaredcos_cap_v2",
+        variance_type="v_diffusion",
+        prediction_type="v",
+    )
+
+
+optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
+
+
+lr_scheduler = get_cosine_schedule_with_warmup(
+    optimizer=optimizer,
+    num_warmup_steps=config.lr_warmup_steps,
+    num_training_steps=(len(train_dataloader) * config.num_epochs),
+)
+
+
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols * w, i // cols * h))
+    return grid
+
+
+def evaluate(config, epoch, pipeline):
+    # Sample some images from random noise (this is the backward diffusion process).
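+    # Note: the 4x4 grid built below assumes config.eval_batch_size == 16;
+    # change both together if you adjust the evaluation batch size.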
+    # The default pipeline output type is `List[PIL.Image]`
+    images = pipeline(
+        batch_size=config.eval_batch_size,
+        generator=torch.manual_seed(config.seed),
+    ).images
+
+    # Make a grid out of the images
+    image_grid = make_grid(images, rows=4, cols=4)
+
+    # Save the images
+    test_dir = os.path.join(config.output_dir, "samples")
+    os.makedirs(test_dir, exist_ok=True)
+    image_grid.save(f"{test_dir}/{epoch:04d}.png")
+
+
+def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
+    # Initialize accelerator and tensorboard logging
+    accelerator = Accelerator(
+        mixed_precision=config.mixed_precision,
+        gradient_accumulation_steps=config.gradient_accumulation_steps,
+        log_with="tensorboard",
+        logging_dir=os.path.join(config.output_dir, "logs"),
+    )
+    if accelerator.is_main_process:
+        if config.push_to_hub:
+            repo = init_git_repo(config, at_init=True)
+        accelerator.init_trackers("train_example")
+
+    # Prepare everything
+    # There is no specific order to remember, you just need to unpack the
+    # objects in the same order you gave them to the prepare method.
+    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        model, optimizer, train_dataloader, lr_scheduler
+    )
+
+    global_step = 0
+
+    if config.output_dir.startswith("ddpm"):
+        pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
+    else:
+        pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
+
+    evaluate(config, 0, pipeline)
+
+    # Now you train the model
+    for epoch in range(config.num_epochs):
+        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
+        progress_bar.set_description(f"Epoch {epoch}")
+
+        for step, batch in enumerate(train_dataloader):
+            clean_images = batch["images"]
+            # Sample noise to add to the images
+            noise = torch.randn(clean_images.shape).to(clean_images.device)
+            bs = clean_images.shape[0]
+
+            # Sample a random timestep for each image
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device).long()
+
+            with accelerator.accumulate(model):
+                # Predict v for the noised images; the regression target is the
+                # analytic v = alpha_t * noise - sigma_t * x_0
+                alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
+                z_t = alpha_t * clean_images + sigma_t * noise
+                noise_pred = model(z_t, timesteps).sample
+                v = alpha_t * noise - sigma_t * clean_images
+                loss = F.mse_loss(noise_pred, v)
+                accelerator.backward(loss)
+
+                accelerator.clip_grad_norm_(model.parameters(), 1.0)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+            progress_bar.update(1)
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
+            global_step += 1
+
+        # After each epoch you optionally sample some demo images with evaluate() and save the model
+        if accelerator.is_main_process:
+            if config.output_dir.startswith("ddpm"):
+                pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
+            else:
+                pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
+
+            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
+                evaluate(config, epoch, pipeline)
+
+            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
+                if config.push_to_hub:
+                    push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
+                else:
+                    pipeline.save_pretrained(config.output_dir)
+
+
+args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
+
+train_loop(*args)
+
+sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
+Image.open(sample_images[-1])
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 86eda7371fe9..4394b63ee813 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -33,6 +33,7 @@
     DanceDiffusionPipeline,
     DDIMPipeline,
     DDPMPipeline,
+    DistillationPipeline,
     KarrasVePipeline,
     LDMPipeline,
     LDMSuperResolutionPipeline,
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index ef4d23e5e6d0..b1b293eccf90 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -8,6 +8,7 @@
     from .latent_diffusion import LDMSuperResolutionPipeline
     from .latent_diffusion_uncond import LDMPipeline
     from .pndm import PNDMPipeline
+    from .progressive_distillation import DistillationPipeline
     from .repaint import RePaintPipeline
     from .score_sde_ve import ScoreSdeVePipeline
     from .stochastic_karras_ve import KarrasVePipeline
diff --git a/src/diffusers/pipelines/progressive_distillation/__init__.py b/src/diffusers/pipelines/progressive_distillation/__init__.py
new file mode 100644
index 000000000000..e7031b1583a7
--- /dev/null
+++ b/src/diffusers/pipelines/progressive_distillation/__init__.py
@@ -0,0 +1 @@
+from .pipeline_progressive_distillation import DistillationPipeline
diff --git a/src/diffusers/pipelines/progressive_distillation/pipeline_progressive_distillation.py b/src/diffusers/pipelines/progressive_distillation/pipeline_progressive_distillation.py
new file mode 100644
index 000000000000..24b1021d338d
--- /dev/null
+++ b/src/diffusers/pipelines/progressive_distillation/pipeline_progressive_distillation.py
@@ -0,0 +1,276 @@
+import copy
+import os
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from PIL import Image
+from tqdm.auto import tqdm
+from accelerate import Accelerator
+from diffusers import DiffusionPipeline
+from diffusers.optimization import get_scheduler
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers.schedulers.scheduling_ddim import DDIMScheduler
+from diffusers.pipelines.ddpm import DDPMPipeline
+from diffusers.pipelines.ddim import DDIMPipeline
+from diffusers.training_utils import EMAModel
+
+
+def logsnr_schedule(t, logsnr_min=-20, logsnr_max=20):
+    logsnr_min = torch.tensor(logsnr_min, dtype=torch.float32)
+    logsnr_max = torch.tensor(logsnr_max, dtype=torch.float32)
+    b = torch.arctan(torch.exp(-0.5 * logsnr_max))
+    a = torch.arctan(torch.exp(-0.5 * logsnr_min)) - b
+    return -2.0 * torch.log(torch.tan(a * t + b))
+
+
+def continuous_to_discrete_time(u, num_timesteps):
+    return (u * (num_timesteps - 1)).float().round().long()
+
+
+def predict_x_from_v(*, z, v, logsnr):
+    # Broadcast the per-sample logsnr to match z's shape (batch, 1, 1, 1)
+    logsnr = logsnr.reshape(logsnr.shape + (1,) * (z.ndim - logsnr.ndim))
+    alpha_t = torch.sqrt(F.sigmoid(logsnr))
+    sigma_t = torch.sqrt(F.sigmoid(-logsnr))
+    return alpha_t * z - sigma_t * v
+
+
+def alpha_sigma_from_logsnr(logsnr):
+    alpha_t = torch.sqrt(F.sigmoid(logsnr))
+    sigma_t = torch.sqrt(F.sigmoid(-logsnr))
+    return alpha_t, sigma_t
+
+
+class DistillationPipeline(DiffusionPipeline):
+    def __init__(self):
+        pass
+
+    def __call__(
+        self,
+        teacher,
+        n_teacher_trainsteps,
+        train_data,
+        epochs=100,
+        lr=3e-4,
+        batch_size=64,
+        gamma=0,
+        gradient_accumulation_steps=1,
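+        # NOTE: the Adam hyper-parameters accepted below are not currently passed
+        # through to the AdamW optimizer (its betas/weight_decay/eps kwargs are
+        # commented out where the optimizer is built); only `lr` takes effect.
+        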
mixed_precision="no", + adam_beta1=0.95, + adam_beta2=0.999, + adam_weight_decay=0.001, + adam_epsilon=1e-08, + ema_inv_gamma=0.9999, + ema_power=3 / 4, + ema_max_decay=0.9999, + use_ema=True, + permute_samples=(0, 1, 2, 3), + generator=None, + accelerator=None, + sample_every: int = None, + sample_path: str = "distillation_samples", + ): + # Initialize our accelerator for training + os.makedirs(os.path.join(sample_path, f"{n_teacher_trainsteps}"), exist_ok=True) + if accelerator is None: + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision=mixed_precision, + ) + + if accelerator.is_main_process: + run = "distill" + accelerator.init_trackers(run) + + # Setup a dataloader with the provided train data + train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True) + + # Setup the noise schedulers for the teacher and student + teacher_scheduler = DDIMScheduler( + num_train_timesteps=n_teacher_trainsteps, + beta_schedule="squaredcos_cap_v2", + variance_type="v_diffusion", + prediction_type="v", + ) + student_scheduler = DDIMScheduler( + num_train_timesteps=n_teacher_trainsteps // 2, + beta_schedule="squaredcos_cap_v2", + variance_type="v_diffusion", + prediction_type="v", + ) + + # Initialize the student model as a direct copy of the teacher + student = copy.deepcopy(teacher) + student.load_state_dict(teacher.state_dict()) + student = accelerator.prepare(student) + student.train() + teacher.eval() + + # Setup the optimizer for the student + optimizer = torch.optim.AdamW( + student.parameters(), + lr=lr, + # betas=(adam_beta1, adam_beta2), + # weight_decay=adam_weight_decay, + # eps=adam_epsilon, + ) + lr_scheduler = get_scheduler( + "linear", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=np.ceil((epochs * len(train_dataloader)) // gradient_accumulation_steps), + ) + + # Let accelerate handle moving the model to the correct device + ( + teacher, + student, + optimizer, + lr_scheduler, + train_data, + teacher_scheduler, + student_scheduler, + ) = accelerator.prepare( + teacher, student, optimizer, lr_scheduler, train_data, teacher_scheduler, student_scheduler + ) + if not generator: + generator = torch.Generator().manual_seed(0) + + # generator = accelerator.prepare(generator) + ema_model = EMAModel( + student, + inv_gamma=ema_inv_gamma, + power=ema_power, + max_value=ema_max_decay, + ) + global_step = 0 + + # run pipeline in inference (sample random noise and denoise) on our teacher model as a baseline + pipeline = DDIMPipeline( + unet=teacher, + scheduler=teacher_scheduler, + ) + + images = pipeline(batch_size=4, generator=torch.manual_seed(0)).images + + # denormalize the images and save to tensorboard + # images_processed = (images * 255).round().astype("uint8") + for sample_number, img in enumerate(images): + + img.save(os.path.join(sample_path, f"{n_teacher_trainsteps}", f"baseline_sample_{sample_number}.png")) + + # Train the student + for epoch in range(epochs): + progress_bar = tqdm(total=len(train_data) // batch_size, disable=not accelerator.is_local_main_process) + progress_bar.set_description(f"Epoch {epoch}") + for batch in train_dataloader: + with accelerator.accumulate(student): + if isinstance(batch, dict): + batch = batch["images"] + batch = batch.to(accelerator.device) + noise = torch.randn(batch.shape, generator=generator).to(accelerator.device) + bsz = batch.shape[0] + # Sample a random timestep for each image + + u = torch.rand(size=(bsz,), generator=generator).to(accelerator.device) + 
+                    u_1 = u - (0.5 / (n_teacher_trainsteps // 2))
+                    u_2 = u - (1 / (n_teacher_trainsteps // 2))
+                    # logsnr = logsnr_schedule(u)
+                    # alpha_t, sigma_t = alpha_sigma_from_logsnr(logsnr)
+                    with torch.no_grad():
+                        # Add noise to the image based on the noise scheduler at t=timesteps
+                        timesteps = continuous_to_discrete_time(u, n_teacher_trainsteps)
+                        alpha_t, sigma_t = teacher_scheduler.get_alpha_sigma(batch, timesteps, accelerator.device)
+                        z_t = alpha_t * batch + sigma_t * noise
+                        # z_t = batch * torch.sqrt(F.sigmoid(logsnr)) + noise * torch.sqrt(F.sigmoid(-logsnr))
+
+                        # teach_out_start = teacher(z_t, continuous_to_discrete_time(u, n_teacher_trainsteps))
+                        # x_pred = predict_x_from_v(teach_out_start)
+                        # Take the first diffusion step with the teacher
+                        v_pred_t = teacher(z_t.permute(*permute_samples), timesteps).sample.permute(*permute_samples)
+
+                        # reconstruct the image at timesteps using v diffusion
+                        x_teacher_z_t = alpha_t * z_t - sigma_t * v_pred_t
+                        # eps = (z - alpha*x)/sigma.
+                        eps_pred = (z_t - alpha_t * x_teacher_z_t) / sigma_t
+
+                        # Add noise to the image based on the noise scheduler at t=timesteps-1,
+                        # to prepare for the next diffusion step
+                        timesteps = continuous_to_discrete_time(u_1, n_teacher_trainsteps)
+                        alpha_t_prime, sigma_t_prime = teacher_scheduler.get_alpha_sigma(
+                            batch, timesteps, accelerator.device
+                        )
+                        z_mid = alpha_t_prime * x_teacher_z_t + sigma_t_prime * eps_pred
+                        # Take the second diffusion step with the teacher
+                        v_pred_mid = teacher(z_mid.permute(*permute_samples), timesteps).sample.permute(
+                            *permute_samples
+                        )
+                        x_pred_mid = alpha_t_prime * z_mid - sigma_t_prime * v_pred_mid
+
+                        eps_pred = (z_mid - alpha_t_prime * x_pred_mid) / sigma_t_prime
+
+                        timesteps = continuous_to_discrete_time(u_2, n_teacher_trainsteps)
+                        alpha_t_prime2, sigma_t_prime2 = teacher_scheduler.get_alpha_sigma(
+                            batch, timesteps, accelerator.device
+                        )
+                        z_teacher = alpha_t_prime2 * x_pred_mid + sigma_t_prime2 * eps_pred
+                        sigma_frac = sigma_t / sigma_t_prime2
+
+                        # Solve for the x/eps/v that make a single student step from z_t land on z_teacher
+                        x_target = (z_teacher - sigma_frac * z_t) / (alpha_t_prime2 - sigma_frac * alpha_t)
+                        eps_target = (z_teacher - alpha_t_prime2 * x_target) / sigma_t_prime2
+                        v_target = alpha_t * eps_target - sigma_t * x_target
+
+                    timesteps = continuous_to_discrete_time(u_2, n_teacher_trainsteps // 2)
+                    # The student predicts v (despite the `noise_pred` name)
+                    noise_pred = student(z_t.permute(*permute_samples), timesteps).sample.permute(*permute_samples)
+                    w = torch.pow(1 + alpha_t_prime2 / sigma_t_prime2, gamma)
+                    loss = F.mse_loss(noise_pred * w, v_target * w)
+                    accelerator.backward(loss)
+
+                    if accelerator.sync_gradients:
+                        accelerator.clip_grad_norm_(student.parameters(), 1.0)
+                    optimizer.step()
+                    lr_scheduler.step()
+                    if use_ema:
+                        ema_model.step(student)
+                    optimizer.zero_grad()
+
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    progress_bar.update(1)
+                    global_step += 1
+
+                logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+                if use_ema:
+                    logs["ema_decay"] = ema_model.decay
+                progress_bar.set_postfix(**logs)
+                accelerator.log(logs, step=global_step)
+            progress_bar.close()
+            if sample_every is not None:
+                if (epoch + 1) % sample_every == 0:
+                    new_scheduler = DDIMScheduler(
+                        num_train_timesteps=n_teacher_trainsteps // 2,
+                        beta_schedule="squaredcos_cap_v2",
+                        variance_type="v_diffusion",
+                        prediction_type="v",
+                    )
+                    pipeline = DDIMPipeline(
+                        unet=accelerator.unwrap_model(ema_model.averaged_model if use_ema else student),
+                        scheduler=new_scheduler,
+                    )
+
+                    # run pipeline in inference (sample 
random noise and denoise) + images = pipeline( + batch_size=4, + generator=torch.manual_seed(0), + num_inference_steps=n_teacher_trainsteps // 2, + ).images + + # denormalize the images and save to tensorboard + for sample_number, img in enumerate(images): + img.save( + os.path.join( + sample_path, f"{n_teacher_trainsteps}", f"epoch_{epoch}_sample_{sample_number}.png" + ) + ) + accelerator.wait_for_everyone() + return student, ema_model, accelerator diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 75cef635d063..8fdbe44cdcf6 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -27,13 +27,23 @@ from .scheduling_utils import SchedulerMixin +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. + """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out + + @dataclass # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM class DDIMSchedulerOutput(BaseOutput): """ - Output class for the scheduler's step function output. - Args: + Output class for the scheduler's step function output. prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the denoising loop. @@ -48,18 +58,13 @@ class DDIMSchedulerOutput(BaseOutput): def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. Contains a function alpha_bar that takes an argument t and transforms it to the + cumulative product of (1-beta) up to that part of the diffusion process. + num_diffusion_timesteps (`int`): the number of betas to produce. max_beta (`float`): the maximum beta to use; + use values lower than 1 to prevent singularities. - Returns: betas (`np.ndarray`): the betas used by the scheduler to step the model outputs """ @@ -75,23 +80,40 @@ def alpha_bar(time_step): return torch.tensor(betas) -class DDIMScheduler(SchedulerMixin, ConfigMixin): - """ - Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising - diffusion probabilistic models (DDPMs) with non-Markovian guidance. 
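+# The cosine log-SNR schedule used for v-diffusion: for t in [0, 1] it returns
+# logsnr(t) = -2 * log(tan(a * t + b)), i.e. log(alpha_t**2 / sigma_t**2), with
+# a and b chosen so the endpoints land on logsnr_max and logsnr_min. Because
+# alpha_t = sqrt(sigmoid(logsnr)) and sigma_t = sqrt(sigmoid(-logsnr)), the
+# schedule keeps alpha_t**2 + sigma_t**2 == 1 at every t.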
+def _logsnr_schedule_cosine(t, logsnr_min=-20, logsnr_max=20): + logsnr_min = torch.tensor(logsnr_min, dtype=torch.float32) + logsnr_max = torch.tensor(logsnr_max, dtype=torch.float32) + b = torch.arctan(torch.exp(-0.5 * logsnr_max)) + a = torch.arctan(torch.exp(-0.5 * logsnr_min)) - b + return -2.0 * torch.log(torch.tan(a * t + b)) - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. - For more details, see the original paper: https://arxiv.org/abs/2010.02502 +def t_to_alpha_sigma(num_diffusion_timesteps): + """Returns the scaling factors for the clean image and for the noise, given + a timestep.""" + out = torch.FloatTensor([_logsnr_schedule_cosine(t) for t in torch.linspace(0, 1, 1000)]) + alphas = torch.sqrt(torch.sigmoid(out)) + sigmas = torch.sqrt(torch.sigmoid(-out)) + # alphas = torch.cos( + # torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)]) + # ) + # sigmas = torch.sin( + # torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)]) + # ) + return alphas, sigmas + +class DDIMScheduler(SchedulerMixin, ConfigMixin): + """ Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. [`~ConfigMixin`] takes care of storing all + config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can + be accessed via `scheduler.config.num_train_timesteps`. [`~ConfigMixin`] also provides general loading and saving + functionality via the [`~ConfigMixin.save_config`] and [`~ConfigMixin.from_config`] functions. For more details, + see the original paper: https://arxiv.org/abs/2010.02502 + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from `linear`, `scaled_linear`, or `squaredcos_cap_v2`. trained_betas (`np.ndarray`, optional): @@ -106,7 +128,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. 
- """ _compatible_classes = [ @@ -128,7 +149,10 @@ def __init__( trained_betas: Optional[np.ndarray] = None, clip_sample: bool = True, set_alpha_to_one: bool = True, + variance_type: str = "fixed", steps_offset: int = 0, + prediction_type: str = "epsilon", + **kwargs, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -145,14 +169,18 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + self.variance_type = variance_type self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + if prediction_type == "v": + self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps) # At every step in ddim, we are looking into the previous alphas_cumprod # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -160,36 +188,49 @@ def __init__( # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.variance_type = variance_type + self.prediction_type = prediction_type def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ + Args: Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`int`, optional): current timestep - + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep Returns: `torch.FloatTensor`: scaled input sample """ return sample - def _get_variance(self, timestep, prev_timestep): + def _get_variance(self, timestep, prev_timestep, eta=0): alpha_prod_t = self.alphas_cumprod[timestep] alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev - variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) - + if self.variance_type == "fixed": + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + elif self.variance_type == "v_diffusion": + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + # variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) + alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + sigma_prev = self.sigmas[prev_timestep] if prev_timestep >= 0 else self.final_sigma + if eta: + numerator = eta * (sigma_prev**2 / self.sigmas[timestep] ** 2).clamp(min=1.0e-7).sqrt() + else: + numerator = 0 + denominator = (1 - self.alphas[timestep] ** 2 / alpha_prev**2).clamp(min=1.0e-7).sqrt() + ddim_sigma = (numerator * denominator).clamp(min=1.0e-7) + variance = (sigma_prev**2 - ddim_sigma**2).sqrt() + if torch.isnan(variance): + variance = 0 return variance def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ - Sets the discrete timesteps used for the diffusion chain. 
Supporting function to be run before inference. - Args: + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. """ @@ -213,30 +254,30 @@ def step( return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: """ + Args: Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): + model_output (`torch.FloatTensor`): direct output from learned diffusion model. timestep (`int`): current + discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): current instance of sample being created by diffusion process. - eta (`float`): weight of noise for added noise in diffusion step. - use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + prediction_type (`str`): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample), or `v` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + eta (`float`): weight of noise for added noise in diffusion step. use_clipped_model_output (`bool`): if + `True`, compute "corrected" `model_output` from the clipped predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would coincide with the one provided as input and `use_clipped_model_output` will have not effect. - generator: random number generator. - variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + generator: random number generator. variance_noise (`torch.FloatTensor`): instead of generating noise for + the variance using `generator`, we can directly provide the noise for the variance itself. This is useful for methods such as CycleDiffusion. (https://arxiv.org/abs/2210.05559) return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class - Returns: [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - """ if self.num_inference_steps is None: raise ValueError( @@ -247,14 +288,14 @@ def step( # Ideally, read DDIM paper in-detail understanding # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - pred_noise_t -> e_theta(x_t, timestep) + # - pred_original_sample -> f_theta(x_t, timestep) or x_0 # - std_dev_t -> sigma_t # - eta -> ฮท # - pred_sample_direction -> "direction pointing to x_t" # - pred_prev_sample -> "x_t-1" - # 1. get previous step value (=t-1) + # 1. get previous step value (=timestep-1) prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps # 2. compute alphas, betas @@ -265,7 +306,21 @@ def step( # 3. 
compute predicted original sample from predicted noise also called # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + if self.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + eps = torch.tensor(1) + elif self.prediction_type == "sample": + pred_original_sample = model_output + eps = torch.tensor(1) + elif self.prediction_type == "v": + # v_t = alpha_t * epsilon - sigma_t * x + # need to merge the PRs for sigma to be available in DDPM + pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep] + eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep] + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`" + ) # 4. Clip "predicted x_0" if self.config.clip_sample: @@ -273,7 +328,7 @@ def step( # 5. compute variance: "sigma_t(ฮท)" -> see formula (16) # ฯƒ_t = sqrt((1 โˆ’ ฮฑ_tโˆ’1)/(1 โˆ’ ฮฑ_t)) * sqrt(1 โˆ’ ฮฑ_t/ฮฑ_tโˆ’1) - variance = self._get_variance(timestep, prev_timestep) + variance = self._get_variance(timestep, prev_timestep, eta) std_dev_t = eta * variance ** (0.5) if use_clipped_model_output: @@ -281,10 +336,14 @@ def step( model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + if self.prediction_type == "epsilon": + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output - # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + # 7. 
compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction + else: + alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + prev_sample = pred_original_sample * alpha_prev + eps * variance if eta > 0: # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 @@ -307,7 +366,6 @@ def step( variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise prev_sample = prev_sample + variance - if not return_dict: return (prev_sample,) @@ -319,6 +377,10 @@ def add_noise( noise: torch.FloatTensor, timesteps: torch.IntTensor, ) -> torch.FloatTensor: + if self.variance_type == "v_diffusion": + alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device) + z_t = alpha * original_samples + sigma * noise + return z_t # Make sure alphas_cumprod and timestep have same device and dtype as original_samples self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) timesteps = timesteps.to(original_samples.device) @@ -338,3 +400,13 @@ def add_noise( def __len__(self): return self.config.num_train_timesteps + + def get_alpha_sigma(self, sample, timesteps, device): + alpha = expand_to_shape(self.alphas, timesteps, sample.shape, device) + sigma = expand_to_shape(self.sigmas, timesteps, sample.shape, device) + return alpha, sigma + + def get_alpha_sigma_from_logsnr(self, sample, logsnr, device): + alpha = expand_to_shape(self.alphas, logsnr, sample.shape, device) + sigma = expand_to_shape(self.sigmas, logsnr, sample.shape, device) + return alpha, sigma diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 08a73119e5b7..b5602eb69c56 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -26,6 +26,17 @@ from .scheduling_utils import SchedulerMixin +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once. + """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out + + @dataclass class DDPMSchedulerOutput(BaseOutput): """ @@ -99,9 +110,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. clip_sample (`bool`, default `True`): option to clip predicted sample between -1 and 1 for numerical stability. - predict_epsilon (`bool`): - optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise. 
- + prediction_type (`Literal["epsilon", "sample", "v"]`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ _compatible_classes = [ @@ -123,7 +135,7 @@ def __init__( trained_betas: Optional[np.ndarray] = None, variance_type: str = "fixed_small", clip_sample: bool = True, - predict_epsilon: bool = True, + prediction_type: str = "epsilon", ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -146,7 +158,8 @@ def __init__( self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.one = torch.tensor(1.0) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod) # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -156,6 +169,7 @@ def __init__( self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) self.variance_type = variance_type + self.prediction_type = prediction_type def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ @@ -186,14 +200,14 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic )[::-1].copy() self.timesteps = torch.from_numpy(timesteps).to(device) - def _get_variance(self, t, predicted_variance=None, variance_type=None): - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + def _get_variance(self, timestep, predicted_variance=None, variance_type=None): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) - # For t > 0, compute predicted variance ฮฒt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) + # For timestep > 0, compute predicted variance ฮฒt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) # and sample from it to get previous sample - # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample - variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t] + # x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep] if variance_type is None: variance_type = self.config.variance_type @@ -205,17 +219,19 @@ def _get_variance(self, t, predicted_variance=None, variance_type=None): elif variance_type == "fixed_small_log": variance = torch.log(torch.clamp(variance, min=1e-20)) elif variance_type == "fixed_large": - variance = self.betas[t] + variance = self.betas[timestep] elif variance_type == "fixed_large_log": # Glide max_log - variance = torch.log(self.betas[t]) + variance = torch.log(self.betas[timestep]) elif variance_type == "learned": return predicted_variance elif variance_type == "learned_range": min_log = variance - max_log = self.betas[t] + max_log = self.betas[timestep] frac = (predicted_variance + 1) / 2 variance = frac * max_log + (1 - frac) * min_log + elif variance_type == "v_diffusion": + variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) return variance @@ -246,6 +262,8 @@ def step( returning a tuple, the first element is the sample tensor. 
""" + if self.variance_type == "v_diffusion": + assert self.prediction_type == "v", "Need to use v prediction with v_diffusion" message = ( "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" " DDPMScheduler.from_config(, predict_epsilon=True)`." @@ -256,25 +274,39 @@ def step( new_config["predict_epsilon"] = predict_epsilon self._internal_dict = FrozenDict(new_config) - t = timestep - if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) else: predicted_variance = None # 1. compute alphas, betas - alpha_prod_t = self.alphas_cumprod[t] - alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if self.config.predict_epsilon: + if self.prediction_type == "v": + # x_recon in p_mean_variance + pred_original_sample = ( + sample * self.sqrt_alphas_cumprod[timestep] + - model_output * self.sqrt_one_minus_alphas_cumprod[timestep] + ) + elif self.prediction_type == "epsilon" or self.config.predict_epsilon: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - else: + + elif self.prediction_type == "sample": pred_original_sample = model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`" + ) + + # pred_original_sample = ( + # sample * self.sqrt_alphas_cumprod[timestep] - model_output * self.sqrt_one_minus_alphas_cumprod[timestep] + # ) + # eps = model_output * self.sqrt_alphas_cumprod[timestep] - sample * self.sqrt_one_minus_alphas_cumprod[timestep] # 3. Clip "predicted x_0" if self.config.clip_sample: @@ -282,8 +314,8 @@ def step( # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf - pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t - current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t + pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t + current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t # 5. Compute predicted previous sample ยต_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -291,11 +323,16 @@ def step( # 6. 
@@ -310,6 +347,11 @@ def add_noise(
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
+        if self.variance_type == "v_diffusion":
+            alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
+            z_t = alpha * original_samples + sigma * noise
+            return z_t
+
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
@@ -329,3 +371,8 @@ def add_noise(

     def __len__(self):
         return self.config.num_train_timesteps
+
+    def get_alpha_sigma(self, sample, timesteps, device):
+        alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device)
+        sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device)
+        return alpha, sigma
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index af2e0c7c61d6..7790a707a609 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -197,6 +197,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class DistillationPipeline(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class KarrasVePipeline(metaclass=DummyObject):
     _backends = ["torch"]
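Editor's note, back on the scheduler changes: `get_alpha_sigma` above relies on `expand_to_shape` to gather one schedule entry per batch element and reshape it for broadcasting. A standalone sketch of that mechanic (the schedule tensor is a stand-in, and the function body mirrors the helper added in this patch):

import torch

def expand_to_shape(input, timesteps, shape, device):
    # gather one entry per batch element, then reshape to (shape[0], 1, ..., 1)
    out = torch.gather(input.to(device), 0, timesteps.to(device))
    return out.reshape(shape[0], *([1] * (len(shape) - 1)))

sqrt_alphas_cumprod = torch.rand(1000).sort(descending=True).values  # stand-in schedule
sample = torch.randn(4, 3, 64, 64)
timesteps = torch.tensor([10, 250, 500, 999])

alpha = expand_to_shape(sqrt_alphas_cumprod, timesteps, sample.shape, sample.device)
assert alpha.shape == (4, 1, 1, 1)  # broadcasts cleanly against `sample`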
diff --git a/tests/pipelines/progressive_distillation/test_progressive_distillation.py b/tests/pipelines/progressive_distillation/test_progressive_distillation.py
new file mode 100644
index 000000000000..03547a9cd6db
--- /dev/null
+++ b/tests/pipelines/progressive_distillation/test_progressive_distillation.py
@@ -0,0 +1,88 @@
+import gc
+import unittest
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+from diffusers import DDPMPipeline, DDPMScheduler, DistillationPipeline, UNet2DModel
+from diffusers.utils import slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu
+
+
+torch.backends.cuda.matmul.allow_tf32 = False
+
+
+class SingleImageDataset(Dataset):
+    def __init__(self, image, batch_size):
+        self.image = image
+        self.batch_size = batch_size
+
+    def __len__(self):
+        return self.batch_size
+
+    def __getitem__(self, idx):
+        return self.image
+
+
+class PipelineFastTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    @property
+    def dummy_unet(self):
+        torch.manual_seed(0)
+        model = UNet2DModel(
+            sample_size=64,
+            in_channels=3,
+            out_channels=3,
+            layers_per_block=2,
+            block_out_channels=(128, 128, 256, 256, 512, 512),
+            down_block_types=(
+                "DownBlock2D",
+                "DownBlock2D",
+                "DownBlock2D",
+                "DownBlock2D",
+                "AttnDownBlock2D",
+                "DownBlock2D",
+            ),
+            up_block_types=(
+                "UpBlock2D",
+                "AttnUpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+            ),
+        )
+        return model
+
+    def test_progressive_distillation(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        pipe = DistillationPipeline()
+        pipe.set_progress_bar_config(disable=None)
+        generator = torch.Generator(device=device).manual_seed(0)
+        # create a dummy dataset with a random image
+        image = torch.rand(3, 64, 64, device=device, generator=generator)
+        dataset = SingleImageDataset(image, batch_size=2)
+        teacher, distilled_ema, distill_accelerator = pipe(
+            teacher=self.dummy_unet, train_data=dataset, n_teacher_trainsteps=100, epochs=1, generator=generator
+        )
+        new_scheduler = DDPMScheduler(num_train_timesteps=50, beta_schedule="squaredcos_cap_v2")
+        pipeline = DDPMPipeline(
+            unet=distill_accelerator.unwrap_model(distilled_ema.averaged_model),
+            scheduler=new_scheduler,
+        )
+
+        # run pipeline in inference (sample random noise and denoise)
+        images = pipeline(generator=generator, batch_size=2, output_type="numpy").images
+        image_slice = images[0, -3:, -3:].flatten()[:10]
+        assert images.shape == (2, 64, 64, 3)
+        expected_slice = np.array(
+            [0.11791468, 0.04737437, 0.0, 0.74979293, 0.3200513, 0.43817604, 0.83634996, 0.10667279, 0.0, 0.29753304]
+        )
+        assert np.abs(image_slice - expected_slice).max() < 1e-2
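Editor's note: `SingleImageDataset` simply repeats one tensor `batch_size` times, so the test's data loading is fully deterministic. A usage sketch (the class body mirrors the test helper above; shapes are taken from the test):

import torch
from torch.utils.data import DataLoader, Dataset

class SingleImageDataset(Dataset):  # mirrors the test helper above
    def __init__(self, image, batch_size):
        self.image = image
        self.batch_size = batch_size

    def __len__(self):
        return self.batch_size

    def __getitem__(self, idx):
        return self.image

image = torch.rand(3, 64, 64)
loader = DataLoader(SingleImageDataset(image, batch_size=2), batch_size=2)
batch = next(iter(loader))
assert batch.shape == (2, 3, 64, 64)  # one deterministic batch, same image twice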