diff --git a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution-lightning.ipynb b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution-lightning.ipynb new file mode 100644 index 00000000..dfc9d932 --- /dev/null +++ b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution-lightning.ipynb @@ -0,0 +1,1704 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "bc11fdc9", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) MONAI Consortium\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "95c08725", + "metadata": {}, + "source": [ + "# Super-resolution using Stable Diffusion v2 Upscalers\n", + "\n", + "Tutorial to illustrate the super-resolution task on medical images using Latent Diffusion Models (LDMs) [1]. For that, we will use an autoencoder to obtain a latent representation of the high-resolution images. Then, we train a diffusion model to infer this latent representation when conditioned on a low-resolution image.\n", + "\n", + "To improve the performance of our models, we will use a method called \"noise conditioning augmentation\" (introduced in [2] and used in Stable Diffusion v2.0 and Imagen Video [3]). 
During the training, we add noise to the low-resolution images using a random signal-to-noise ratio, and we condition the diffusion models on the amount of noise added. At sampling time, we use a fixed signal-to-noise ratio, representing a small amount of augmentation that aids in removing artefacts in the samples.\n", + "\n", + "\n", + "[1] - Rombach et al. \"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n", + "\n", + "[2] - Ho et al. \"Cascaded diffusion models for high fidelity image generation\" https://arxiv.org/abs/2106.15282\n", + "\n", + "[3] - Ho et al. \"High Definition Video Generation with Diffusion Models\" https://arxiv.org/abs/2210.02303" + ] + }, + { + "cell_type": "markdown", + "id": "b839bf2d", + "metadata": {}, + "source": [ + "## Set up environment using Colab\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "77f7e633", + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\"\n", + "!python -c \"import pytorch_lightning\" || pip install pytorch-lightning\n", + "!python -c \"import matplotlib\" || pip install -q matplotlib\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "id": "214066de", + "metadata": {}, + "source": [ + "## Set up imports" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "de71fe08", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:22:32,886 - WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. 
xFormers was built for:\n", + " PyTorch 1.13.1+cu117 with CUDA 1107 (you have 1.12.1)\n", + " Python 3.8.16 (you have 3.8.16)\n", + " Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)\n", + " Memory-efficient attention, SwiGLU, sparse and more won't be available.\n", + " Set XFORMERS_MORE_DETAILS=1 for more details\n", + "2023-05-12 17:22:35,069 - Created a temporary directory at /tmp/tmph4_v9gin\n", + "2023-05-12 17:22:35,071 - Writing /tmp/tmph4_v9gin/_remote_module_non_scriptable.py\n", + "MONAI version: 1.2.dev2304\n", + "Numpy version: 1.23.5\n", + "Pytorch version: 1.12.1\n", + "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", + "MONAI rev id: 9a57be5aab9f2c2a134768c0c146399150e247a0\n", + "MONAI __file__: /home/ol18/miniconda3/envs/monai_generative/lib/python3.8/site-packages/monai_weekly-1.2.dev2304-py3.8.egg/monai/__init__.py\n", + "\n", + "Optional dependencies:\n", + "Pytorch Ignite version: 0.4.10\n", + "ITK version: 5.3.0\n", + "Nibabel version: 5.1.0\n", + "scikit-image version: 0.20.0\n", + "Pillow version: 9.4.0\n", + "Tensorboard version: 2.12.1\n", + "gdown version: 4.7.1\n", + "TorchVision version: 0.13.1\n", + "tqdm version: 4.65.0\n", + "lmdb version: 1.4.0\n", + "psutil version: 5.9.4\n", + "pandas version: 2.0.0\n", + "einops version: 0.6.0\n", + "transformers version: 4.21.3\n", + "mlflow version: 2.2.2\n", + "pynrrd version: 1.0.0\n", + "\n", + "For details about installing the optional dependencies, please visit:\n", + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", + "\n" + ] + } + ], + "source": [ + "import os\n", + "import shutil\n", + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from monai import transforms\n", + "from monai.apps import MedNISTDataset\n", + "from 
monai.config import print_config\n", + "from monai.data import CacheDataset, ThreadDataLoader\n", + "from monai.utils import first, set_determinism\n", + "from torch.cuda.amp import GradScaler, autocast\n", + "from torch import nn\n", + "from tqdm.notebook import tqdm\n", + "\n", + "from generative.losses import PatchAdversarialLoss, PerceptualLoss\n", + "from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator\n", + "from generative.networks.schedulers import DDPMScheduler\n", + "\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9f0a17bc", + "metadata": {}, + "outputs": [], + "source": [ + "# for reproducibility purposes set a seed\n", + "set_determinism(42)" + ] + }, + { + "cell_type": "markdown", + "id": "c0dde922", + "metadata": {}, + "source": [ + "## Setup a data directory and download dataset\n", + "Specify a MONAI_DATA_DIRECTORY variable, where the data will be downloaded. If not specified a temporary directory will be used." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ded618a7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/tmphf3tvpfi\n" + ] + } + ], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "print(root_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "298d964a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "MedNIST.tar.gz: 59.0MB [00:01, 36.4MB/s] " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:22:36,915 - INFO - Downloaded: /tmp/tmphf3tvpfi/MedNIST.tar.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:22:37,020 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n", + "2023-05-12 17:22:37,022 - INFO - Writing into directory: /tmp/tmphf3tvpfi.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 47164/47164 [00:35<00:00, 1323.43it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:23:21,457 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n", + "2023-05-12 17:23:21,459 - INFO - File exists: /tmp/tmphf3tvpfi/MedNIST.tar.gz, skipped downloading.\n", + "2023-05-12 17:23:21,460 - INFO - Non-empty folder exists in /tmp/tmphf3tvpfi/MedNIST, skipped extracting.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 5895/5895 [00:04<00:00, 1319.58it/s]\n" + ] + } + 
], + "source": [ + "train_data = MedNISTDataset(root_dir=root_dir, section=\"training\",\n", + " download=True, seed=0)\n", + "train_datalist = [{\"image\": item[\"image\"]} for item in train_data.data if item[\"class_name\"] == \"HeadCT\"]\n", + "val_data = MedNISTDataset(root_dir=root_dir, section=\"validation\",\n", + " download=True, seed=0)\n", + "val_datalist = [{\"image\": item[\"image\"]} for item in val_data.data if item[\"class_name\"] == \"HeadCT\"]" + ] + }, + { + "cell_type": "markdown", + "id": "46bafb78", + "metadata": {}, + "source": [ + "## Setup utils functions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4f8eff03", + "metadata": {}, + "outputs": [], + "source": [ + "def get_train_transforms():\n", + " image_size = 64\n", + " train_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n", + " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0,\n", + " b_max=1.0, clip=True),\n", + " transforms.RandAffined(\n", + " keys=[\"image\"],\n", + " rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],\n", + " translate_range=[(-1, 1), (-1, 1)],\n", + " scale_range=[(-0.05, 0.05), (-0.05, 0.05)],\n", + " spatial_size=[image_size, image_size],\n", + " padding_mode=\"zeros\",\n", + " prob=0.5,\n", + " ),\n", + " transforms.CopyItemsd(keys=[\"image\"], times=1, names=[\"low_res_image\"]),\n", + " transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n", + " ]\n", + " )\n", + " return train_transforms\n", + "\n", + "def get_val_transforms():\n", + " val_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n", + " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, \n", + " b_max=1.0, clip=True),\n", + " transforms.CopyItemsd(keys=[\"image\"], times=1, 
names=[\"low_res_image\"]),\n", + " transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n", + " ]\n", + " )\n", + " return val_transforms\n", + "\n", + " \n", + "def get_datasets():\n", + " train_transforms = get_train_transforms()\n", + " val_transforms = get_val_transforms()\n", + " train_ds = CacheDataset(data=train_datalist[:320], transform=train_transforms)\n", + " val_ds = CacheDataset(data=val_datalist[:32], transform=val_transforms)\n", + " return train_ds, val_ds\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "d80e045b", + "metadata": {}, + "source": [ + "## Define the LightningModule for AutoEncoder (transforms, network, loaders, etc.)\n", + "The LightningModule contains a refactoring of your training code. The following module is a reformatting of the code in 2d_stable_diffusion_v2_super_resolution.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d5d1caff", + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "class AutoEncoder(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.data_dir = root_dir\n", + " self.autoencoderkl = AutoencoderKL(spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " num_channels=(256, 512, 512),\n", + " latent_channels=3,\n", + " num_res_blocks=2,\n", + " norm_num_groups=32,\n", + " attention_levels=(False, False, True))\n", + " self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1,\n", + " num_layers_d=3, num_channels=64)\n", + " self.perceptual_loss = PerceptualLoss(spatial_dims=2, network_type=\"alex\")\n", + " self.perceptual_weight = 0.002\n", + " self.autoencoder_warm_up_n_epochs = 10\n", + " self.automatic_optimization = False\n", + " self.adv_loss = PatchAdversarialLoss(criterion=\"least_squares\")\n", + " self.adv_weight = 0.005\n", + " self.kl_weight = 1e-6\n", + " \n", + " def forward(self, z):\n", + " return self.autoencoderkl(z)\n", + "\n", + " def prepare_data(self):\n", +
" self.train_ds, self.val_ds = get_datasets()\n", + " \n", + " def train_dataloader(self):\n", + " return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True,\n", + " num_workers=4, persistent_workers=True)\n", + " \n", + " def val_dataloader(self):\n", + " return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False,\n", + " num_workers=4)\n", + " \n", + " def _compute_loss_generator(self, images, reconstruction, z_mu, z_sigma):\n", + " recons_loss = F.l1_loss(reconstruction.float(), images.float())\n", + " p_loss = self.perceptual_loss(reconstruction.float(), images.float())\n", + " kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])\n", + " kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n", + " loss_g = recons_loss + (self.kl_weight * kl_loss) + (self.perceptual_weight * p_loss)\n", + " return loss_g,recons_loss\n", + " \n", + " def _compute_loss_discriminator(self, images, reconstruction):\n", + " logits_fake = self.discriminator(reconstruction.contiguous().detach())[-1]\n", + " loss_d_fake = self.adv_loss(logits_fake, target_is_real=False, for_discriminator=True)\n", + " logits_real = self.discriminator(images.contiguous().detach())[-1]\n", + " loss_d_real = self.adv_loss(logits_real, target_is_real=True, for_discriminator=True)\n", + " discriminator_loss = (loss_d_fake + loss_d_real) * 0.5\n", + " loss_d = self.adv_weight * discriminator_loss\n", + " return loss_d, discriminator_loss\n", + " \n", + " def training_step(self, batch, batch_idx):\n", + " optimizer_g, optimizer_d = self.optimizers()\n", + " images = batch[\"image\"]\n", + " reconstruction, z_mu, z_sigma = self.forward(images)\n", + " loss_g, recons_loss = self._compute_loss_generator(images, reconstruction, z_mu, z_sigma)\n", + " self.log(\"recons_loss\", recons_loss, batch_size=16, prog_bar=True)\n", + "\n", + " if self.current_epoch > self.autoencoder_warm_up_n_epochs:\n", + " logits_fake = 
self.discriminator(reconstruction.contiguous().float())[-1]\n", + " generator_loss = self.adv_loss(logits_fake, target_is_real=True, for_discriminator=False)\n", + " loss_g += self.adv_weight * generator_loss\n", + " self.log(\"gen_loss\", generator_loss, batch_size=16, prog_bar=True)\n", + " \n", + " \n", + "\n", + " self.log(\"loss_g\", loss_g, batch_size=16, prog_bar=True)\n", + " self.manual_backward(loss_g)\n", + " optimizer_g.step()\n", + " optimizer_g.zero_grad()\n", + " self.untoggle_optimizer(optimizer_g)\n", + "\n", + " if self.current_epoch > self.autoencoder_warm_up_n_epochs:\n", + " loss_d, discriminator_loss = self._compute_loss_discriminator(images, reconstruction)\n", + " self.log(\"disc_loss\", loss_d, batch_size=16, prog_bar=True)\n", + " self.log(\"train_loss_d\", loss_d, batch_size=16, prog_bar=True)\n", + " self.manual_backward(loss_d)\n", + " optimizer_d.step()\n", + " optimizer_d.zero_grad()\n", + " self.untoggle_optimizer(optimizer_d)\n", + "\n", + " \n", + " \n", + " def validation_step(self, batch, batch_idx):\n", + " images = batch[\"image\"]\n", + " reconstruction, z_mu, z_sigma = self.autoencoderkl(images)\n", + " recons_loss = F.l1_loss(images.float(), reconstruction.float())\n", + " self.log(\"val_loss_d\", recons_loss, batch_size=1, prog_bar=True)\n", + " self.images = images\n", + " self.reconstruction = reconstruction\n", + "\n", + "\n", + " def on_validation_epoch_end(self):\n", + " # ploting reconstruction\n", + " plt.figure(figsize=(2, 2))\n", + " plt.imshow(torch.cat([self.images[0, 0].cpu(), \n", + " self.reconstruction[0, 0].cpu()],\n", + " dim=1), vmin=0, vmax=1, cmap=\"gray\")\n", + " plt.tight_layout()\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + " \n", + "\n", + " def configure_optimizers(self):\n", + " optimizer_g = torch.optim.Adam(self.autoencoderkl.parameters(), lr=5e-5)\n", + " optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)\n", + " return [optimizer_g, optimizer_d], []\n", + " \n" + ] 
+ }, + { + "cell_type": "markdown", + "id": "c16de505", + "metadata": {}, + "source": [ + "## Train Autoencoder" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9d903aaa", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.\n", + "Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:23:29,608 - GPU available: True (cuda), used: True\n", + "2023-05-12 17:23:29,609 - TPU available: False, using: 0 TPU cores\n", + "2023-05-12 17:23:29,610 - IPU available: False, using: 0 IPUs\n", + "2023-05-12 17:23:29,611 - HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 320/320 [00:00<00:00, 858.75it/s]\n", + "Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 284.05it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:23:30,118 - Missing logger folder: /tmp/tmphf3tvpfi/lightning_logs\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Checkpoint directory /tmp/tmphf3tvpfi exists and is not empty.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:23:31,637 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "2023-05-12 17:23:31,652 - \n", + " | Name 
| Type | Params\n", + "---------------------------------------------------------\n", + "0 | autoencoderkl | AutoencoderKL | 75.1 M\n", + "1 | discriminator | PatchDiscriminator | 2.8 M \n", + "2 | perceptual_loss | PerceptualLoss | 2.5 M \n", + "3 | adv_loss | PatchAdversarialLoss | 0 \n", + "---------------------------------------------------------\n", + "77.8 M Trainable params\n", + "2.5 M Non-trainable params\n", + "80.3 M Total params\n", + "321.225 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The number of training batches (20) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e345b86c597145adad868d37e0416041", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:41:32,347 - `Trainer.fit` stopped: `max_epochs=75` reached.\n" + ] + } + ], + "source": [ + "n_epochs = 75 \n", + "val_interval = 10\n", + "\n", + " \n", + "# initialise the LightningModule\n", + "ae_net = AutoEncoder()\n", + "\n", + "# set up checkpoints\n", + "\n", + "checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename=\"best_metric_model\")\n", + "\n", + " \n", + "# initialise Lightning's trainer.\n", + "trainer = pl.Trainer(devices=1,\n", + " max_epochs=n_epochs,\n", + " check_val_every_n_epoch=val_interval,\n", + " num_sanity_val_steps=0,\n", + " callbacks=checkpoint_callback,\n", + " default_root_dir=root_dir)\n", + "\n", + "# train\n", + "trainer.fit(ae_net)" + ] + }, + { + "cell_type": "markdown", + "id": "c7108b87", + "metadata": {}, + "source": [ + "## Rescaling factor\n", + "\n", + "As mentioned in Rombach et al. [1] Section 4.3.2 and D.1, the signal-to-noise ratio (induced by the scale of the latent space) became crucial in image-to-image translation models (such as the ones used for super-resolution). For this reason, we will compute the component-wise standard deviation to be used as scaling factor." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ccb6ba9f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scaling factor set to 0.5302040576934814\n" + ] + } + ], + "source": [ + "def get_scale_factor():\n", + " ae_net.eval()\n", + " device = torch.device(\"cuda:0\")\n", + " ae_net.to(device)\n", + "\n", + " train_loader = ae_net.train_dataloader()\n", + " check_data = first(train_loader)\n", + " z = ae_net.autoencoderkl.encode_stage_2_inputs(check_data[\"image\"].to(ae_net.device))\n", + " print(f\"Scaling factor set to {1/torch.std(z)}\")\n", + " scale_factor = 1 / torch.std(z)\n", + " return scale_factor\n", + "\n", + "scale_factor = get_scale_factor()" + ] + }, + { + "cell_type": "markdown", + "id": "3baa2b0f", + "metadata": {}, + "source": [ + "## Define the LightningModule for DiffusionModelUNet (transforms, network, loaders, etc.)\n", + "The LightningModule contains a refactoring of your training code. The following module is a reformatting of the code in 2d_stable_diffusion_v2_super_resolution."
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "731034ec", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "class DiffusionUNET(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.data_dir = root_dir\n", + " self.unet = DiffusionModelUNet(\n", + " spatial_dims=2,\n", + " in_channels=4,\n", + " out_channels=3,\n", + " num_res_blocks=2,\n", + " num_channels=(256, 256, 512, 1024),\n", + " attention_levels=(False, False, True, True),\n", + " num_head_channels=(0, 0, 64, 64),\n", + " )\n", + " self.max_noise_level = 350\n", + " self.scheduler = DDPMScheduler(num_train_timesteps=1000, \n", + " beta_schedule=\"linear\",\n", + " beta_start=0.0015,\n", + " beta_end=0.0195)\n", + " self.z = ae_net.autoencoderkl.eval()\n", + "\n", + "\n", + " def forward(self, x, timesteps, low_res_timesteps):\n", + " return self.unet(x=x, \n", + " timesteps=timesteps,\n", + " class_labels=low_res_timesteps)\n", + " \n", + " \n", + " def prepare_data(self):\n", + " self.train_ds, self.val_ds = get_datasets()\n", + " \n", + " def train_dataloader(self):\n", + " return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True,\n", + " num_workers=4, persistent_workers=True)\n", + " \n", + " def val_dataloader(self):\n", + " return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False,\n", + " num_workers=4)\n", + " \n", + " def _calculate_loss(self, batch, batch_idx, plt_image=False):\n", + " images = batch[\"image\"]\n", + " low_res_image = batch[\"low_res_image\"] \n", + " with autocast(enabled=True):\n", + " with torch.no_grad():\n", + " latent = self.z.encode_stage_2_inputs(images) * scale_factor\n", + " \n", + " # Noise augmentation\n", + " noise = torch.randn_like(latent)\n", + " low_res_noise = torch.randn_like(low_res_image)\n", + " timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (latent.shape[0],),\n", + " device=latent.device).long()\n", + " low_res_timesteps = 
torch.randint(\n", + " 0, self.max_noise_level, (low_res_image.shape[0],), device=latent.device\n", + " ).long()\n", + "\n", + " noisy_latent = self.scheduler.add_noise(original_samples=latent, \n", + " noise=noise, timesteps=timesteps)\n", + " noisy_low_res_image = self.scheduler.add_noise(\n", + " original_samples=low_res_image, noise=low_res_noise, \n", + " timesteps=low_res_timesteps\n", + " )\n", + "\n", + " latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)\n", + "\n", + " noise_pred = self.forward(latent_model_input, timesteps, low_res_timesteps)\n", + " loss = F.mse_loss(noise_pred.float(), noise.float())\n", + " \n", + " if plt_image:\n", + " # Sampling image during training\n", + " sampling_image = low_res_image[0].unsqueeze(0)\n", + " latents = torch.randn((1, 3, 16, 16)).to(sampling_image.device)\n", + " low_res_noise = torch.randn((1, 1, 16, 16)).to(sampling_image.device)\n", + " noise_level = 20\n", + " noise_level = torch.Tensor((noise_level,)).long().to(sampling_image.device)\n", + " \n", + " noisy_low_res_image = self.scheduler.add_noise(\n", + " original_samples=sampling_image,\n", + " noise=low_res_noise,\n", + " timesteps=noise_level,\n", + " )\n", + " self.scheduler.set_timesteps(num_inference_steps=1000)\n", + " for t in tqdm(self.scheduler.timesteps, ncols=110):\n", + " with autocast(enabled=True):\n", + " with torch.no_grad():\n", + " latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)\n", + " noise_pred = self.forward(latent_model_input, \n", + " torch.Tensor((t,)).to(sampling_image.device)\n", + " , noise_level)\n", + " latents, _ = self.scheduler.step(noise_pred, t, latents)\n", + " with torch.no_grad():\n", + " decoded = self.z.decode_stage_2_outputs(latents / scale_factor)\n", + " low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode=\"bicubic\")\n", + " # plot images\n", + " \n", + " self.images = images\n", + " self.low_res_bicubic = low_res_bicubic\n", + " self.decoded = 
decoded\n", + " \n", + " return loss\n", + " \n", + " def _plot_image(self, images, low_res_bicubic, decoded):\n", + " plt.figure(figsize=(2, 2))\n", + " plt.style.use(\"default\")\n", + " plt.imshow(\n", + " torch.cat([images[0, 0].cpu(), low_res_bicubic[0, 0].cpu(), decoded[0, 0].cpu()], dim=1),\n", + " vmin=0,\n", + " vmax=1,\n", + " cmap=\"gray\",\n", + " )\n", + " plt.tight_layout()\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " loss = self._calculate_loss(batch, batch_idx)\n", + " self.log(\"train_loss\", loss, batch_size=16, prog_bar=True)\n", + " return loss\n", + " \n", + " def validation_step(self, batch, batch_idx):\n", + " loss = self._calculate_loss(batch, batch_idx, plt_image=True)\n", + " self.log(\"val_loss\", loss, batch_size=16, prog_bar=True)\n", + " return loss\n", + " \n", + " def on_validation_epoch_end(self):\n", + " self._plot_image(self.images, self.low_res_bicubic, self.decoded)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.unet.parameters(), lr=5e-5)\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "id": "b386a0c2", + "metadata": {}, + "source": [ + "## Train Diffusion Model\n", + "\n", + "In order to train the diffusion model to perform super-resolution, we will need to concatenate the latent representation of the high-resolution image with the low-resolution image. For this, we create a Diffusion model with `in_channels=4`. Since we are only interested in the output latent representation, we set `out_channels=3`.\n", + "\n", + "As mentioned, we will use noise conditioning augmentation (introduced in [2] Section 3 and used in Stable Diffusion Upscalers and Imagen Video [3] Section 2.5), as it has been shown to be critical for cascaded diffusion models, as well as for super-resolution tasks. For this, we apply Gaussian noise augmentation to the low-resolution images.
We will use a scheduler low_res_scheduler to add this noise, with the t step defining the signal-to-noise ratio and use the t value to condition the diffusion model (inputted using class_labels argument)." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "936bbb9c", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:41:38,503 - GPU available: True (cuda), used: True\n", + "2023-05-12 17:41:38,503 - TPU available: False, using: 0 TPU cores\n", + "2023-05-12 17:41:38,504 - IPU available: False, using: 0 IPUs\n", + "2023-05-12 17:41:38,504 - HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 320/320 [00:00<00:00, 773.44it/s]\n", + "Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 495.25it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 17:41:39,194 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "2023-05-12 17:41:39,226 - \n", + " | Name | Type | Params\n", + "-------------------------------------------------\n", + "0 | unet | DiffusionModelUNet | 266 M \n", + "1 | scheduler | DDPMScheduler | 0 \n", + "2 | z | AutoencoderKL | 75.1 M\n", + "-------------------------------------------------\n", + "342 M Trainable params\n", + "0 Non-trainable params\n", + "342 M Total params\n", + "1,368.189 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "71e4d5d2e391477aac58bfb05ecefb85", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + 
"data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "be43a3c9e09a4405a067d9fc887e151a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5aa3448de1094d0f82f2808b8131bb9e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9b9fad73a311409082600b576378c2b7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "323be58fabcb4ab4a9a128d9f6cc9b52", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + 
}, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4b79fa6a38cd49f7bb65fc6ad327af73", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "02c9a151f246464da9eebd015e6beb9d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cbaf0368cb304e6bbfa8ceac20f5a247", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2a95535f104048b2ac1756ff1db454f8", + "version_major": 2, + 
"version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8959e5f59fb14792abea3737d90363ee", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "40b8269d8e05404da709e1f212edc47c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2023-05-12 18:22:06,851 - `Trainer.fit` stopped: `max_epochs=200` reached.\n" + ] + } + ], + "source": [ + "n_epochs = 200\n", + "val_interval = 20\n", + "\n", + " \n", + "# initialise the LightningModule\n", + "d_net = DiffusionUNET()\n", + "\n", + "# set up checkpoints\n", + "\n", + "checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename=\"best_metric_model_dunet\")\n", + "\n", + " \n", + "# initialise Lightning's trainer.\n", + "trainer = pl.Trainer(devices=1,\n", + " max_epochs=n_epochs,\n", + " check_val_every_n_epoch=val_interval,\n", + " num_sanity_val_steps=0,\n", + " callbacks=checkpoint_callback,\n", + " default_root_dir=root_dir)\n", + "\n", + "# train\n", + "trainer.fit(d_net)" + ] + }, + { + "cell_type": "markdown", + "id": "30f24595", + 
"metadata": {}, + "source": [ + "### Plotting sampling example" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "155be091", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "631635b665454a2884dc88d7d466c50b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00 x_t-1\n", + " latents, _ = scheduler.step(noise_pred, t, latents)\n", + "\n", + " with torch.no_grad():\n", + " decoded = ae_net.autoencoderkl.decode_stage_2_outputs(latents / scale_factor)\n", + " return sampling_image, images, decoded\n", + "\n", + "sampling_image, images, decoded = get_images_to_plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "32e16e69", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode=\"bicubic\")\n", + "fig, axs = plt.subplots(num_samples, 3, figsize=(8, 8))\n", + "axs[0, 0].set_title(\"Original image\")\n", + "axs[0, 1].set_title(\"Low-resolution Image\")\n", + "axs[0, 2].set_title(\"Outputted image\")\n", + "for i in range(0, num_samples):\n", + " axs[i, 0].imshow(images[i, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 0].axis(\"off\")\n", + " axs[i, 1].imshow(low_res_bicubic[i, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 1].axis(\"off\")\n", + " axs[i, 2].imshow(decoded[i, 0].cpu().detach().numpy(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 2].axis(\"off\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "7fa52acc", + "metadata": {}, + "source": [ + "### Clean-up data directory" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "3a6f6d5a", + "metadata": {}, + "outputs": [], + "source": [ + "if directory is None:\n", + " shutil.rmtree(root_dir)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,py:percent" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution-lightning.py b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution-lightning.py new file mode 100644 index 00000000..34c8c8f0 --- /dev/null +++ 
b/tutorials/generative/2d_super_resolution/2d_stable_diffusion_v2_super_resolution-lightning.py @@ -0,0 +1,540 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.14.5 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% [markdown] +# # Super-resolution using Stable Diffusion v2 Upscalers +# +# Tutorial to illustrate the super-resolution task on medical images using Latent Diffusion Models (LDMs) [1]. For that, we will use an autoencoder to obtain a latent representation of the high-resolution images. Then, we train a diffusion model to infer this latent representation when conditioned on a low-resolution image. +# +# To improve the performance of our models, we will use a method called "noise conditioning augmentation" (introduced in [2] and used in Stable Diffusion v2.0 and Imagen Video [3]). During the training, we add noise to the low-resolution images using a random signal-to-noise ratio, and we condition the diffusion models on the amount of noise added. At sampling time, we use a fixed signal-to-noise ratio, representing a small amount of augmentation that aids in removing artefacts in the samples. +# +# +# [1] - Rombach et al. 
"High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 +# +# [2] - Ho et al. "Cascaded diffusion models for high fidelity image generation" https://arxiv.org/abs/2106.15282 +# +# [3] - Ho et al. "High Definition Video Generation with Diffusion Models" https://arxiv.org/abs/2210.02303 + +# %% [markdown] +# ## Set up environment using Colab +# + +# %% +# !python -c "import monai" || pip install -q "monai-weekly[tqdm]" +# !python -c "import pytorch_lightning" || pip install pytorch-lightning +# !python -c "import matplotlib" || pip install -q matplotlib +# %matplotlib inline + +# %% [markdown] +# ## Set up imports + +# %% +import os +import shutil +import tempfile +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from monai import transforms +from monai.apps import MedNISTDataset +from monai.config import print_config +from monai.data import CacheDataset, ThreadDataLoader +from monai.utils import first, set_determinism +from torch.cuda.amp import GradScaler, autocast +from torch import nn +from tqdm.notebook import tqdm + +from generative.losses import PatchAdversarialLoss, PerceptualLoss +from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator +from generative.networks.schedulers import DDPMScheduler + +import pytorch_lightning as pl +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + +print_config() + +# %% +# for reproducibility purposes set a seed +set_determinism(42) + +# %% [markdown] +# ## Setup a data directory and download dataset +# Specify a MONAI_DATA_DIRECTORY variable, where the data will be downloaded. If not specified a temporary directory will be used. 
# %%
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

# %%
# Keep only the HeadCT class of MedNIST for both the training and validation splits.
train_data = MedNISTDataset(root_dir=root_dir, section="training", download=True, seed=0)
train_datalist = [{"image": item["image"]} for item in train_data.data if item["class_name"] == "HeadCT"]
val_data = MedNISTDataset(root_dir=root_dir, section="validation", download=True, seed=0)
val_datalist = [{"image": item["image"]} for item in val_data.data if item["class_name"] == "HeadCT"]


# %% [markdown]
# ## Setup utils functions


# %%
def get_train_transforms():
    """Return the training transform pipeline.

    Images are loaded, made channel-first, rescaled from [0, 255] to [0, 1]
    and randomly augmented with a small affine transform. A copy of each
    image is then resized to 16x16 and stored under ``"low_res_image"``.
    """
    image_size = 64
    train_transforms = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image"]),
            transforms.EnsureChannelFirstd(keys=["image"]),
            transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
            transforms.RandAffined(
                keys=["image"],
                rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],
                translate_range=[(-1, 1), (-1, 1)],
                scale_range=[(-0.05, 0.05), (-0.05, 0.05)],
                spatial_size=[image_size, image_size],
                padding_mode="zeros",
                prob=0.5,
            ),
            transforms.CopyItemsd(keys=["image"], times=1, names=["low_res_image"]),
            transforms.Resized(keys=["low_res_image"], spatial_size=(16, 16)),
        ]
    )
    return train_transforms


def get_val_transforms():
    """Return the validation transform pipeline (same as training, without augmentation)."""
    val_transforms = transforms.Compose(
        [
            transforms.LoadImaged(keys=["image"]),
            transforms.EnsureChannelFirstd(keys=["image"]),
            transforms.ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),
            transforms.CopyItemsd(keys=["image"], times=1, names=["low_res_image"]),
            transforms.Resized(keys=["low_res_image"], spatial_size=(16, 16)),
        ]
    )
    return val_transforms


def get_datasets():
    """Return ``(train_ds, val_ds)`` cached datasets built from the first 320 / 32 HeadCT items."""
    train_transforms = get_train_transforms()
    val_transforms = get_val_transforms()
    train_ds = CacheDataset(data=train_datalist[:320], transform=train_transforms)
    val_ds = CacheDataset(data=val_datalist[:32], transform=val_transforms)
    return train_ds, val_ds


# %% [markdown]
# ## Define the LightningModule for AutoEncoder (transforms, network, loaders, etc)
# The LightningModule contains a refactoring of your training code. The following module is a reformatting of the code in 2d_stable_diffusion_v2_super_resolution.
#


# %%
class AutoEncoder(pl.LightningModule):
    """KL autoencoder trained adversarially with manual optimisation.

    The generator (``autoencoderkl``) is trained with L1 + KL + perceptual
    losses; once ``current_epoch`` exceeds ``autoencoder_warm_up_n_epochs``
    an adversarial term from ``discriminator`` is added and the discriminator
    is trained as well.
    """

    def __init__(self):
        super().__init__()
        self.data_dir = root_dir
        self.autoencoderkl = AutoencoderKL(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            num_channels=(256, 512, 512),
            latent_channels=3,
            num_res_blocks=2,
            norm_num_groups=32,
            attention_levels=(False, False, True),
        )
        self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1, num_layers_d=3, num_channels=64)
        self.perceptual_loss = PerceptualLoss(spatial_dims=2, network_type="alex")
        self.perceptual_weight = 0.002
        self.autoencoder_warm_up_n_epochs = 10
        # Two optimizers are stepped by hand in training_step.
        self.automatic_optimization = False
        self.adv_loss = PatchAdversarialLoss(criterion="least_squares")
        self.adv_weight = 0.005
        self.kl_weight = 1e-6

    def forward(self, z):
        """Run the autoencoder; returns ``(reconstruction, z_mu, z_sigma)`` as unpacked by the callers."""
        return self.autoencoderkl(z)

    def prepare_data(self):
        """Build the cached training and validation datasets."""
        self.train_ds, self.val_ds = get_datasets()

    def train_dataloader(self):
        return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)

    def val_dataloader(self):
        return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)

    def _compute_loss_generator(self, images, reconstruction, z_mu, z_sigma):
        """Return ``(loss_g, recons_loss)``: weighted L1 + KL + perceptual loss and the raw L1 term."""
        recons_loss = F.l1_loss(reconstruction.float(), images.float())
        p_loss = self.perceptual_loss(reconstruction.float(), images.float())
        # KL divergence summed over the three non-batch dims, then averaged over the batch.
        kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        loss_g = recons_loss + (self.kl_weight * kl_loss) + (self.perceptual_weight * p_loss)
        return loss_g, recons_loss

    def _compute_loss_discriminator(self, images, reconstruction):
        """Return ``(loss_d, discriminator_loss)``: the ``adv_weight``-scaled and the raw discriminator loss."""
        # Both inputs are detached so no gradient reaches the autoencoder here.
        logits_fake = self.discriminator(reconstruction.contiguous().detach())[-1]
        loss_d_fake = self.adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
        logits_real = self.discriminator(images.contiguous().detach())[-1]
        loss_d_real = self.adv_loss(logits_real, target_is_real=True, for_discriminator=True)
        discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
        loss_d = self.adv_weight * discriminator_loss
        return loss_d, discriminator_loss

    def training_step(self, batch, batch_idx):
        """Run one manual-optimisation step: generator update, then (after warm-up) discriminator update."""
        optimizer_g, optimizer_d = self.optimizers()
        images = batch["image"]
        batch_size = images.shape[0]
        use_adversarial = self.current_epoch > self.autoencoder_warm_up_n_epochs

        # Generator update. Toggling disables gradients for the discriminator
        # parameters, so the adversarial term below cannot leave stale
        # gradients on them that would later leak into the discriminator step.
        self.toggle_optimizer(optimizer_g)
        reconstruction, z_mu, z_sigma = self.forward(images)
        loss_g, recons_loss = self._compute_loss_generator(images, reconstruction, z_mu, z_sigma)
        self.log("recons_loss", recons_loss, batch_size=batch_size, prog_bar=True)

        if use_adversarial:
            logits_fake = self.discriminator(reconstruction.contiguous().float())[-1]
            generator_loss = self.adv_loss(logits_fake, target_is_real=True, for_discriminator=False)
            loss_g += self.adv_weight * generator_loss
            self.log("gen_loss", generator_loss, batch_size=batch_size, prog_bar=True)

        self.log("loss_g", loss_g, batch_size=batch_size, prog_bar=True)
        self.manual_backward(loss_g)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)

        # Discriminator update (only once the warm-up is over).
        if use_adversarial:
            self.toggle_optimizer(optimizer_d)
            loss_d, discriminator_loss = self._compute_loss_discriminator(images, reconstruction)
            # "disc_loss" is the raw loss (mirrors "gen_loss"); "train_loss_d" is the weighted one.
            self.log("disc_loss", discriminator_loss, batch_size=batch_size, prog_bar=True)
            self.log("train_loss_d", loss_d, batch_size=batch_size, prog_bar=True)
            self.manual_backward(loss_d)
            optimizer_d.step()
            optimizer_d.zero_grad()
            self.untoggle_optimizer(optimizer_d)

    def validation_step(self, batch, batch_idx):
        """Log the L1 reconstruction loss and keep the batch for plotting at epoch end."""
        images = batch["image"]
        reconstruction, z_mu, z_sigma = self.autoencoderkl(images)
        recons_loss = F.l1_loss(images.float(), reconstruction.float())
        self.log("val_loss_d", recons_loss, batch_size=images.shape[0], prog_bar=True)
        self.images = images
        self.reconstruction = reconstruction

    def on_validation_epoch_end(self):
        """Plot the first image of the last validation batch next to its reconstruction."""
        # plotting reconstruction
        plt.figure(figsize=(2, 2))
        plt.imshow(
            torch.cat([self.images[0, 0].cpu(), self.reconstruction[0, 0].cpu()], dim=1), vmin=0, vmax=1, cmap="gray"
        )
        plt.tight_layout()
        plt.axis("off")
        plt.show()

    def configure_optimizers(self):
        """Return one Adam optimizer for the autoencoder and one for the discriminator (no schedulers)."""
        optimizer_g = torch.optim.Adam(self.autoencoderkl.parameters(), lr=5e-5)
        optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        return [optimizer_g, optimizer_d], []


# %% [markdown]
# ## Train Autoencoder

# %%
n_epochs = 75
val_interval = 10


# initialise the LightningModule
ae_net = AutoEncoder()

# set up checkpoints

checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename="best_metric_model")


# initialise Lightning's trainer.
trainer = pl.Trainer(
    devices=1,
    max_epochs=n_epochs,
    check_val_every_n_epoch=val_interval,
    num_sanity_val_steps=0,
    callbacks=checkpoint_callback,
    default_root_dir=root_dir,
)

# train
trainer.fit(ae_net)


# %% [markdown]
# ## Rescaling factor
#
# As mentioned in Rombach et al. [1] Section 4.3.2 and D.1, the signal-to-noise ratio (induced by the scale of the latent space) became crucial in image-to-image translation models (such as the ones used for super-resolution). For this reason, we will compute the component-wise standard deviation to be used as scaling factor.
# %%
def get_scale_factor():
    """Return ``1 / std`` of the latents of one training batch.

    The autoencoder is put in eval mode and moved to the GPU when one is
    available (CPU otherwise). The encoding runs under ``torch.no_grad()`` so
    the returned tensor does not keep an autograd graph alive.
    """
    ae_net.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ae_net.to(device)

    train_loader = ae_net.train_dataloader()
    check_data = first(train_loader)
    with torch.no_grad():
        z = ae_net.autoencoderkl.encode_stage_2_inputs(check_data["image"].to(ae_net.device))
        scale_factor = 1 / torch.std(z)
    print(f"Scaling factor set to {scale_factor}")
    return scale_factor


scale_factor = get_scale_factor()


# %% [markdown]
# ## Define the LightningModule for DiffusionModelUnet (transforms, network, loaders, etc)
# The LightningModule contains a refactoring of your training code. The following module is a reformatting of the code in 2d_stable_diffusion_v2_super_resolution.


# %%
class DiffusionUNET(pl.LightningModule):
    """Latent diffusion model for super-resolution.

    The UNet receives the noisy latent of the high-resolution image
    concatenated with a noise-augmented low-resolution image and is trained
    to predict the noise added to the latent. The autoencoder trained above
    is reused, frozen, to encode and decode latents.
    """

    def __init__(self):
        super().__init__()
        self.data_dir = root_dir
        # NOTE(review): `low_res_timesteps` is passed as `class_labels` in forward();
        # confirm DiffusionModelUNet uses it without `num_class_embeds` being set here.
        self.unet = DiffusionModelUNet(
            spatial_dims=2,
            in_channels=4,
            out_channels=3,
            num_res_blocks=2,
            num_channels=(256, 256, 512, 1024),
            attention_levels=(False, False, True, True),
            num_head_channels=(0, 0, 64, 64),
        )
        # Upper bound (exclusive) of the timestep used to noise the low-resolution image.
        self.max_noise_level = 350
        self.scheduler = DDPMScheduler(
            num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195
        )
        self.z = ae_net.autoencoderkl.eval()
        # Only self.unet is optimised (see configure_optimizers); freeze the
        # autoencoder so its weights are not reported or tracked as trainable.
        self.z.requires_grad_(False)

    def forward(self, x, timesteps, low_res_timesteps):
        """Run the UNet on ``x`` at ``timesteps``, passing ``low_res_timesteps`` as ``class_labels``."""
        return self.unet(x=x, timesteps=timesteps, class_labels=low_res_timesteps)

    def prepare_data(self):
        """Build the cached training and validation datasets."""
        self.train_ds, self.val_ds = get_datasets()

    def train_dataloader(self):
        return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)

    def val_dataloader(self):
        return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)

    def _calculate_loss(self, batch, batch_idx, plt_image=False):
        """Return the noise-prediction MSE loss for ``batch``.

        When ``plt_image`` is true, one image is additionally sampled from the
        first low-resolution item of the batch and stored on ``self`` for
        ``on_validation_epoch_end``.
        """
        images = batch["image"]
        low_res_image = batch["low_res_image"]
        with autocast(enabled=True):
            with torch.no_grad():
                latent = self.z.encode_stage_2_inputs(images) * scale_factor

            # Noise augmentation
            noise = torch.randn_like(latent)
            low_res_noise = torch.randn_like(low_res_image)
            timesteps = torch.randint(
                0, self.scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device
            ).long()
            low_res_timesteps = torch.randint(
                0, self.max_noise_level, (low_res_image.shape[0],), device=latent.device
            ).long()

            noisy_latent = self.scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)
            noisy_low_res_image = self.scheduler.add_noise(
                original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps
            )

            latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)

            noise_pred = self.forward(latent_model_input, timesteps, low_res_timesteps)
            loss = F.mse_loss(noise_pred.float(), noise.float())

        if plt_image:
            self._sample_for_plot(images, low_res_image)

        return loss

    def _sample_for_plot(self, images, low_res_image, noise_level=20):
        """Sample one super-resolved image and store the plotting inputs on ``self``."""
        # Sampling image during training
        sampling_image = low_res_image[0].unsqueeze(0)
        device = sampling_image.device
        latents = torch.randn((1, 3, 16, 16)).to(device)
        low_res_noise = torch.randn((1, 1, 16, 16)).to(device)
        noise_level = torch.Tensor((noise_level,)).long().to(device)

        noisy_low_res_image = self.scheduler.add_noise(
            original_samples=sampling_image, noise=low_res_noise, timesteps=noise_level
        )
        self.scheduler.set_timesteps(num_inference_steps=1000)
        for t in tqdm(self.scheduler.timesteps, ncols=110):
            with autocast(enabled=True):
                with torch.no_grad():
                    latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)
                    noise_pred = self.forward(latent_model_input, torch.Tensor((t,)).to(device), noise_level)
            latents, _ = self.scheduler.step(noise_pred, t, latents)
        with torch.no_grad():
            decoded = self.z.decode_stage_2_outputs(latents / scale_factor)
        low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode="bicubic")

        # Kept for on_validation_epoch_end, which does the actual plotting.
        self.images = images
        self.low_res_bicubic = low_res_bicubic
        self.decoded = decoded

    def _plot_image(self, images, low_res_bicubic, decoded):
        """Show original, bicubic-upsampled low-resolution and sampled image side by side."""
        plt.figure(figsize=(2, 2))
        plt.style.use("default")
        plt.imshow(
            torch.cat([images[0, 0].cpu(), low_res_bicubic[0, 0].cpu(), decoded[0, 0].cpu()], dim=1),
            vmin=0,
            vmax=1,
            cmap="gray",
        )
        plt.tight_layout()
        plt.axis("off")
        plt.show()

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, batch_idx)
        self.log("train_loss", loss, batch_size=batch["image"].shape[0], prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Only one sample is plotted per validation run, so the 1000-step
        # sampling loop is run for the first batch only.
        loss = self._calculate_loss(batch, batch_idx, plt_image=batch_idx == 0)
        self.log("val_loss", loss, batch_size=batch["image"].shape[0], prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        self._plot_image(self.images, self.low_res_bicubic, self.decoded)

    def configure_optimizers(self):
        """Return an Adam optimizer over the UNet parameters only."""
        optimizer = torch.optim.Adam(self.unet.parameters(), lr=5e-5)
        return optimizer


# %% [markdown]
# ## Train Diffusion Model
#
# In order to train the diffusion model to perform super-resolution, we will need to concatenate the latent representation of the high-resolution with the low-resolution image. For this, we create a Diffusion model with `in_channels=4`. Since only the outputted latent representation is interesting, we set `out_channels=3`.
#
# As mentioned, we will use the conditioned augmentation (introduced in [2] section 3 and used on Stable Diffusion Upscalers and Imagen Video [3] Section 2.5) as it has been shown critical for cascaded diffusion models, as well as for super-resolution tasks. For this, we apply Gaussian noise augmentation to the low-resolution images. We reuse the model's DDPM scheduler (`self.scheduler`) to add this noise, with the t step defining the signal-to-noise ratio, and use the t value to condition the diffusion model (passed through the `class_labels` argument).
# %%
n_epochs = 200
val_interval = 20


# initialise the LightningModule
d_net = DiffusionUNET()

# set up checkpoints

checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename="best_metric_model_dunet")


# initialise Lightning's trainer.
trainer = pl.Trainer(
    devices=1,
    max_epochs=n_epochs,
    check_val_every_n_epoch=val_interval,
    num_sanity_val_steps=0,
    callbacks=checkpoint_callback,
    default_root_dir=root_dir,
)

# train
trainer.fit(d_net)

# %% [markdown]
# ### Plotting sampling example

# %%
num_samples = 3


def get_images_to_plot():
    """Sample ``num_samples`` super-resolved images from the first validation batch.

    Returns:
        ``(sampling_image, images, decoded)``: the low-resolution inputs used
        for sampling, the full batch of original images, and the decoded
        samples.
    """
    d_net.eval()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    d_net.to(device)

    val_loader = d_net.val_dataloader()
    check_data = first(val_loader)
    images = check_data["image"].to(d_net.device)

    sampling_image = check_data["low_res_image"][:num_samples].to(d_net.device)
    latents = torch.randn((num_samples, 3, 16, 16)).to(d_net.device)
    low_res_noise = torch.randn((num_samples, 1, 16, 16)).to(d_net.device)
    noise_level = 10
    noise_level = torch.Tensor((noise_level,)).long().to(d_net.device)
    scheduler = d_net.scheduler

    # Nothing here is trained, so run the whole sampling without autograd.
    with torch.no_grad():
        # noise_level is already a long tensor on the right device; pass it
        # directly instead of wrapping it in another tensor.
        noisy_low_res_image = scheduler.add_noise(
            original_samples=sampling_image, noise=low_res_noise, timesteps=noise_level
        )

        scheduler.set_timesteps(num_inference_steps=1000)
        for t in tqdm(scheduler.timesteps, ncols=110):
            with autocast(enabled=True):
                # 1. predict the noise from the current latents and the noisy low-resolution image
                latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)
                noise_pred = d_net(
                    x=latent_model_input, timesteps=torch.Tensor((t,)).to(d_net.device), low_res_timesteps=noise_level
                )
            # 2. compute previous image: x_t -> x_t-1
            latents, _ = scheduler.step(noise_pred, t, latents)

        decoded = ae_net.autoencoderkl.decode_stage_2_outputs(latents / scale_factor)
    return sampling_image, images, decoded


sampling_image, images, decoded = get_images_to_plot()

# %%
# Bicubic upsampling of the low-resolution inputs, shown as a visual baseline.
low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode="bicubic")
fig, axs = plt.subplots(num_samples, 3, figsize=(8, 8))
axs[0, 0].set_title("Original image")
axs[0, 1].set_title("Low-resolution Image")
axs[0, 2].set_title("Outputted image")
for i in range(num_samples):
    axs[i, 0].imshow(images[i, 0].cpu(), vmin=0, vmax=1, cmap="gray")
    axs[i, 0].axis("off")
    axs[i, 1].imshow(low_res_bicubic[i, 0].cpu(), vmin=0, vmax=1, cmap="gray")
    axs[i, 1].axis("off")
    axs[i, 2].imshow(decoded[i, 0].cpu().detach().numpy(), vmin=0, vmax=1, cmap="gray")
    axs[i, 2].axis("off")
plt.tight_layout()

# %% [markdown]
# ### Clean-up data directory

# %%
# Only remove the directory when it is the temporary one created by this script.
if directory is None:
    shutil.rmtree(root_dir)