class DiffusionModelEncoder(nn.Module):
    """
    Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This
    network is based on Wolleb et al. "Diffusion Models for Medical Anomaly Detection"
    (https://arxiv.org/abs/2203.04306).

    Note:
        The classification head is built as ``nn.Linear(4096, 512)``, and ``forward`` flattens the output of the
        last down block before feeding it to that head. The flattened feature count of the last down block must
        therefore be exactly 4096 for the chosen ``num_channels`` and input spatial size.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level. A single int is repeated for
            every level in ``num_channels``.
        num_channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with
            `num_class_embeds` classes.
        upcast_attention: if True, upcast attention operations to full precision.

    Raises:
        ValueError: if the conditioning arguments are inconsistent, if any entry of ``num_channels`` is not a
            multiple of ``norm_num_groups``, or if ``attention_levels``, ``num_head_channels`` or
            ``num_res_blocks`` do not match the number of levels in ``num_channels``.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        num_channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        resblock_updown: bool = False,
        num_head_channels: int | Sequence[int] = 8,
        with_conditioning: bool = False,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        num_class_embeds: int | None = None,
        upcast_attention: bool = False,
    ) -> None:
        super().__init__()
        if with_conditioning is True and cross_attention_dim is None:
            raise ValueError(
                "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
                "when using with_conditioning."
            )
        if cross_attention_dim is not None and with_conditioning is False:
            raise ValueError(
                "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
            )

        # All number of channels should be multiple of num_groups
        if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
            raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups")
        if len(num_channels) != len(attention_levels):
            raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels")

        if isinstance(num_head_channels, int):
            num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))

        if len(num_head_channels) != len(attention_levels):
            raise ValueError(
                "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
                " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
            )

        # The signature accepts a plain int, but every level indexes into this value below,
        # so expand an int to one entry per level (same treatment as num_head_channels).
        if isinstance(num_res_blocks, int):
            num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))

        # A too-short sequence would otherwise fail with a bare IndexError halfway through construction.
        # Longer sequences stay accepted (extra entries are ignored) so that the default value keeps
        # working with fewer levels.
        if len(num_res_blocks) < len(num_channels):
            raise ValueError(
                "DiffusionModelEncoder expects num_res_blocks to be an int or to provide one entry per level "
                "of num_channels."
            )

        self.in_channels = in_channels
        self.block_out_channels = num_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_levels = attention_levels
        self.num_head_channels = num_head_channels
        self.with_conditioning = with_conditioning

        # input
        self.conv_in = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=num_channels[0],
            strides=1,
            kernel_size=3,
            padding=1,
            conv_only=True,
        )

        # time
        time_embed_dim = num_channels[0] * 4
        self.time_embed = nn.Sequential(
            nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
        )

        # class embedding
        self.num_class_embeds = num_class_embeds
        if num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)

        # down
        self.down_blocks = nn.ModuleList([])
        output_channel = num_channels[0]
        for i, level_channels in enumerate(num_channels):
            input_channel = output_channel
            output_channel = level_channels

            down_block = get_down_block(
                spatial_dims=spatial_dims,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                num_res_blocks=num_res_blocks[i],
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                # Every level downsamples, including the last one. The previous
                # `i == len(num_channels)` "final block" test could never be true inside this loop,
                # and the flattened feature size expected by `self.out` relies on that behaviour.
                add_downsample=True,
                resblock_updown=resblock_updown,
                with_attn=(attention_levels[i] and not with_conditioning),
                with_cross_attn=(attention_levels[i] and with_conditioning),
                num_head_channels=num_head_channels[i],
                transformer_num_layers=transformer_num_layers,
                cross_attention_dim=cross_attention_dim,
                upcast_attention=upcast_attention,
            )

            self.down_blocks.append(down_block)

        # Classification head; its input size is fixed to 4096 flattened features (see class docstring).
        self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Args:
            x: input tensor (N, C, SpatialDims).
            timesteps: timestep tensor (N,).
            context: context tensor (N, 1, ContextDim).
            class_labels: class label tensor (N, ).

        Returns:
            Output of the fully connected head applied to the flattened encoder features; the last
            dimension has size `out_channels`.

        Raises:
            ValueError: if the model is class-conditional and `class_labels` is missing, or if `context`
                is provided while the model was built with `with_conditioning=False`.
        """
        # 1. time
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        # 2. class
        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb = emb + class_emb

        # 3. initial convolution
        h = self.conv_in(x)

        # 4. down
        if context is not None and self.with_conditioning is False:
            raise ValueError("model should have with_conditioning = True if context is provided")
        for downsample_block in self.down_blocks:
            # The second return value of each down block is not needed for classification and is discarded.
            h, _ = downsample_block(hidden_states=h, temb=emb, context=context)

        # 5. flatten everything except the batch dimension and classify
        h = h.reshape(h.shape[0], -1)
        output = self.out(h)

        return output
+ "# http://www.apache.org/licenses/LICENSE-2.0\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "63d95da6", + "metadata": {}, + "source": [ + "# Diffusion Models for Medical Anomaly Detection with Classifier Guidance\n", + "\n", + "This tutorial illustrates how to use MONAI for training a 2D gradient-guided anomaly detection using DDIMs [1].\n", + "\n", + "We train a diffusion model on 2D slices of brain MR images. A classification model is trained to predict whether the given slice shows a tumor or not.\\\n", + "We then translate an input slice to its healthy reconstruction using DDIMs.\\\n", + "Anomaly detection is performed by taking the difference between input and output, as proposed in [1].\n", + "\n", + "[1] - Wolleb et al. \"Diffusion Models for Medical Anomaly Detection\" https://arxiv.org/abs/2203.04306\n", + "\n", + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": "75f2d5f3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "running install\n", + "/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.\n", + " warnings.warn(\n", + "/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. 
Use build and pip and other standards-based tools.\n", + " warnings.warn(\n", + "running bdist_egg\n", + "running egg_info\n", + "writing generative.egg-info/PKG-INFO\n", + "writing dependency_links to generative.egg-info/dependency_links.txt\n", + "writing requirements to generative.egg-info/requires.txt\n", + "writing top-level names to generative.egg-info/top_level.txt\n", + "reading manifest file 'generative.egg-info/SOURCES.txt'\n", + "writing manifest file 'generative.egg-info/SOURCES.txt'\n", + "installing library code to build/bdist.linux-x86_64/egg\n", + "running install_lib\n", + "warning: install_lib: 'build/lib' does not exist -- no Python modules to install\n", + "\n", + "creating build/bdist.linux-x86_64/egg\n", + "creating build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying generative.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying generative.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying generative.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying generative.egg-info/requires.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n", + "copying generative.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO\n", + "zip_safe flag not set; analyzing archive contents...\n", + "creating 'dist/generative-0.1.0-py3.10.egg' and adding 'build/bdist.linux-x86_64/egg' to it\n", + "removing 'build/bdist.linux-x86_64/egg' (and everything under it)\n", + "Processing generative-0.1.0-py3.10.egg\n", + "Removing /home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/generative-0.1.0-py3.10.egg\n", + "Copying generative-0.1.0-py3.10.egg to /home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages\n", + "generative 0.1.0 is already the active version in easy-install.pth\n", + "\n", + "Installed /home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/generative-0.1.0-py3.10.egg\n", + "Processing dependencies for 
generative==0.1.0\n", + "Searching for monai-weekly==1.2.dev2304\n", + "Best match: monai-weekly 1.2.dev2304\n", + "Adding monai-weekly 1.2.dev2304 to easy-install.pth file\n", + "\n", + "Using /home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages\n", + "Searching for numpy==1.23.2\n", + "Best match: numpy 1.23.2\n", + "Adding numpy 1.23.2 to easy-install.pth file\n", + "Installing f2py script to /home/juliawolleb/anaconda3/envs/experiment/bin\n", + "Installing f2py3 script to /home/juliawolleb/anaconda3/envs/experiment/bin\n", + "Installing f2py3.10 script to /home/juliawolleb/anaconda3/envs/experiment/bin\n", + "\n", + "Using /home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages\n", + "Searching for torch==1.12.1\n", + "Best match: torch 1.12.1\n", + "Adding torch 1.12.1 to easy-install.pth file\n", + "Installing convert-caffe2-to-onnx script to /home/juliawolleb/anaconda3/envs/experiment/bin\n", + "Installing convert-onnx-to-caffe2 script to /home/juliawolleb/anaconda3/envs/experiment/bin\n", + "Installing torchrun script to /home/juliawolleb/anaconda3/envs/experiment/bin\n", + "\n", + "Using /home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages\n", + "Searching for typing-extensions==4.3.0\n", + "Best match: typing-extensions 4.3.0\n", + "Adding typing-extensions 4.3.0 to easy-install.pth file\n", + "\n", + "Using /home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages\n", + "Finished processing dependencies for generative==0.1.0\n" + ] + } + ], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[pillow, tqdm]\"\n", + "!python -c \"import matplotlib\" || pip install -q matplotlib\n", + "!python -c \"import seaborn\" || pip install -q seaborn" + ] + }, + { + "cell_type": "markdown", + "id": "6b766027", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "id": "972ed3f3", + "metadata": { + "jupyter": { 
+ "outputs_hidden": false + }, + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "path ['/home/juliawolleb/PycharmProjects/MONAI/GenerativeModels/tutorials/generative/anomaly_detection/classifier_guidance_anomalydetection', '/home/juliawolleb/anaconda3/envs/experiment/lib/python310.zip', '/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10', '/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/lib-dynload', '', '/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages', '/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/PyYAML-6.0-py3.10-linux-x86_64.egg', '/home/juliawolleb/PycharmProjects/Python_Tutorials/Calgary_Infants/calgary/HD-BET', '/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/lpips-0.1.4-py3.10.egg', '/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/tqdm-4.64.1-py3.10.egg', '/home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/generative-0.1.0-py3.10.egg', '/home/juliawolleb/PycharmProjects/MONAI/GenerativeModels/', '/home/juliawolleb/PycharmProjects/MONAI/GenerativeModels/', '/home/juliawolleb/PycharmProjects/MONAI/GenerativeModels/']\n", + "MONAI version: 1.2.dev2304\n", + "Numpy version: 1.23.2\n", + "Pytorch version: 1.12.1\n", + "MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n", + "MONAI rev id: 9a57be5aab9f2c2a134768c0c146399150e247a0\n", + "MONAI __file__: /home/juliawolleb/anaconda3/envs/experiment/lib/python3.10/site-packages/monai/__init__.py\n", + "\n", + "Optional dependencies:\n", + "Pytorch Ignite version: 0.4.10\n", + "ITK version: 5.3.0\n", + "Nibabel version: 4.0.1\n", + "scikit-image version: 0.19.3\n", + "Pillow version: 9.2.0\n", + "Tensorboard version: 2.12.0\n", + "gdown version: 4.6.4\n", + "TorchVision version: 0.13.1\n", + "tqdm version: 4.64.1\n", + "lmdb version: 1.4.0\n", + "psutil version: 5.9.4\n", + "pandas version: 
1.5.3\n", + "einops version: 0.6.0\n", + "transformers version: 4.21.3\n", + "mlflow version: 2.1.1\n", + "pynrrd version: 1.0.0\n", + "\n", + "For details about installing the optional dependencies, please visit:\n", + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", + "\n" + ] + } + ], + "source": [ + "import os\n", + "import time\n", + "import tempfile\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from monai import transforms\n", + "from monai.apps import DecathlonDataset\n", + "from monai.config import print_config\n", + "from monai.data import DataLoader\n", + "from monai.utils import set_determinism\n", + "from torch.cuda.amp import GradScaler, autocast\n", + "from tqdm import tqdm\n", + "\n", + "from generative.inferers import DiffusionInferer\n", + "from generative.networks.nets.diffusion_model_unet import DiffusionModelEncoder, DiffusionModelUNet\n", + "from generative.networks.schedulers.ddim import DDIMScheduler\n", + "\n", + "torch.multiprocessing.set_sharing_strategy(\"file_system\")\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "id": "7d4ff515", + "metadata": {}, + "source": [ + "## Setup data directory" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "8b4323e7", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory" + ] + }, + { + "cell_type": "markdown", + "id": "99175d50", + "metadata": {}, + "source": [ + "## Set deterministic training for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "34ea510f", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [], + "source": [ + "set_determinism(42)" + ] + }, + { + "cell_type": "markdown", + "id": 
"c3f70dd1-236a-47ff-a244-575729ad92ba", + "metadata": { + "tags": [] + }, + "source": [ + "## Preprocessing of the BRATS Dataset in 2D slices for training\n", + "We download the BRATS training dataset from the Decathlon dataset. \\\n", + "We slice the volumes into axial 2D slices, and assign slice-wise labels (0 for healthy, 1 for diseased) to all slices.\n", + "Here we use transforms to preprocess the training dataset:\n", + "\n", + "1. `LoadImaged` loads the brain MR images from files.\n", + "1. `EnsureChannelFirstd` ensures the loaded data has a \"channel first\" shape.\n", + "1. `ScaleIntensityRangePercentilesd` takes the lower and upper intensity percentiles and scales them to [0, 1].\n" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "c68d2d91-9a0b-4ac1-ae49-f4a64edbd82a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ": Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.\n" + ] + } + ], + "source": [ + "channel = 0 # 0 = Flair\n", + "assert channel in [0, 1, 2, 3], \"Choose a valid channel\"\n", + "\n", + "train_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\", \"label\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", + " transforms.Lambdad(keys=[\"image\"], func=lambda x: x[channel, :, :, :]),\n", + " transforms.AddChanneld(keys=[\"image\"]),\n", + " transforms.EnsureTyped(keys=[\"image\", \"label\"]),\n", + " transforms.Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", + " transforms.Spacingd(keys=[\"image\", \"label\"], pixdim=(3.0, 3.0, 2.0), mode=(\"bilinear\", \"nearest\")),\n", + " transforms.CenterSpatialCropd(keys=[\"image\", \"label\"], roi_size=(64, 64, 44)),\n", + " transforms.ScaleIntensityRangePercentilesd(keys=\"image\", lower=0, upper=99.5, b_min=0, b_max=1),\n", + " 
transforms.RandSpatialCropd(keys=[\"image\", \"label\"], roi_size=(64, 64, 1), random_size=False),\n", + " transforms.Lambdad(keys=[\"image\", \"label\"], func=lambda x: x.squeeze(-1)),\n", + " transforms.CopyItemsd(keys=[\"label\"], times=1, names=[\"slice_label\"]),\n", + " transforms.Lambdad(keys=[\"slice_label\"], func=lambda x: 0.0 if x.sum() > 0 else 1.0),\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 107, + "id": "da1927b0", + "metadata": { + "jupyter": { + "outputs_hidden": false + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|████████████████████████| 388/388 [03:02<00:00, 2.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of training data: 388\n", + "Train image shape torch.Size([1, 64, 64])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "batch_size = 64\n", + "\n", + "train_ds = DecathlonDataset(\n", + " root_dir=root_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"training\", # validation\n", + " cache_rate=1.0, # you may need a few Gb of RAM... 
Set to 0 otherwise\n", + " num_workers=4,\n", + " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", + " seed=0,\n", + " transform=train_transforms,\n", + ")\n", + "\n", + "print(f\"Length of training data: {len(train_ds)}\") # this gives the number of patients in the training set\n", + "print(f'Train image shape {train_ds[0][\"image\"].shape}')\n", + "\n", + "train_loader = DataLoader(\n", + " train_ds, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, persistent_workers=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fac55e9d", + "metadata": { + "tags": [] + }, + "source": [ + "## Preprocessing of the BRATS Dataset in 2D slices for validation\n", + "We download the BRATS validation dataset from the Decathlon dataset, and define the dataloader to load 2D slices for validation.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "73d72110-a8b3-4e03-91cc-1dab4d5a7b87", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|██████████████████████████| 96/96 [00:48<00:00, 2.00it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Length of training data: 96\n", + "Validation Image shape torch.Size([1, 64, 64])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "val_ds = DecathlonDataset(\n", + " root_dir=root_dir,\n", + " task=\"Task01_BrainTumour\",\n", + " section=\"validation\",\n", + " cache_rate=1.0, # you may need a few Gb of RAM... 
Set to 0 otherwise\n", + " num_workers=4,\n", + " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", + " seed=0,\n", + " transform=train_transforms,\n", + ")\n", + "print(f\"Length of training data: {len(val_ds)}\")\n", + "print(f'Validation Image shape {val_ds[0][\"image\"].shape}')\n", + "\n", + "val_loader = DataLoader(\n", + " val_ds, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=True, persistent_workers=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "08428bc6", + "metadata": {}, + "source": [ + "## Define network, scheduler, optimizer, and inferer\n", + "At this step, we instantiate the MONAI components to create a DDIM, the UNET, the noise scheduler, and the inferer used for training and sampling. We are using\n", + "the deterministic DDIM scheduler containing 1000 timesteps, and a 2D UNET with attention mechanisms\n", + "in the 3rd level (`num_head_channels=64`).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 108, + "id": "bee5913e", + "metadata": { + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\")\n", + "\n", + "model = DiffusionModelUNet(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " num_channels=(64, 64, 64),\n", + " attention_levels=(False, False, True),\n", + " num_res_blocks=1,\n", + " num_head_channels=64,\n", + " with_conditioning=False,\n", + ")\n", + "model.to(device)\n", + "\n", + "scheduler = DDIMScheduler(num_train_timesteps=1000)\n", + "\n", + "optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)\n", + "\n", + "inferer = DiffusionInferer(scheduler)" + ] + }, + { + "cell_type": "markdown", + "id": "2a4d3ab2", + "metadata": { + "tags": [] + }, + "source": [ + "## Model training of the diffusion model\n", + "We train our diffusion model for 2000 epochs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 109, + "id": "6c0ed909", + "metadata": { + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 Validation loss 0.9828271865844727\n", + "Epoch 20 Validation loss 0.45277565717697144\n", + "Epoch 40 Validation loss 0.16044068336486816\n", + "Epoch 60 Validation loss 0.06908729672431946\n", + "Epoch 80 Validation loss 0.037922561168670654\n", + "Epoch 100 Validation loss 0.024700244888663292\n", + "Epoch 120 Validation loss 0.02825773134827614\n", + "Epoch 140 Validation loss 0.01575350947678089\n", + "Epoch 160 Validation loss 0.02807718887925148\n", + "Epoch 180 Validation loss 0.03635002672672272\n", + "Epoch 200 Validation loss 0.018522320315241814\n", + "Epoch 220 Validation loss 0.020984284579753876\n", + "Epoch 240 Validation loss 0.02985953912138939\n", + "Epoch 260 Validation loss 0.018604595214128494\n", + "Epoch 280 Validation loss 0.02505004033446312\n", + "Epoch 300 Validation loss 0.018166495487093925\n", + "Epoch 320 Validation loss 0.012706207111477852\n", + "Epoch 340 Validation loss 0.03222103416919708\n", + "Epoch 360 Validation loss 0.010545151308178902\n", + "Epoch 380 Validation loss 0.017768580466508865\n", + "Epoch 400 Validation loss 0.023036960512399673\n", + "Epoch 420 Validation loss 0.023991823196411133\n", + "Epoch 440 Validation loss 0.014143284410238266\n", + "Epoch 460 Validation loss 0.010133783333003521\n", + "Epoch 480 Validation loss 0.019768211990594864\n", + "Epoch 500 Validation loss 0.016018100082874298\n", + "Epoch 520 Validation loss 0.016411196440458298\n", + "Epoch 540 Validation loss 0.012067019008100033\n", + "Epoch 560 Validation loss 0.017793692648410797\n", + "Epoch 580 Validation loss 0.015390219166874886\n", + "Epoch 600 Validation loss 0.015438873320817947\n", + "Epoch 620 Validation loss 0.019228052347898483\n", + "Epoch 640 Validation loss 
0.022589124739170074\n", + "Epoch 660 Validation loss 0.022526469081640244\n", + "Epoch 680 Validation loss 0.0310574471950531\n", + "Epoch 700 Validation loss 0.016018839552998543\n", + "Epoch 720 Validation loss 0.018153013661503792\n", + "Epoch 740 Validation loss 0.01506253331899643\n", + "Epoch 760 Validation loss 0.00914084818214178\n", + "Epoch 780 Validation loss 0.017407484352588654\n", + "Epoch 800 Validation loss 0.013946758583188057\n", + "Epoch 820 Validation loss 0.013289306312799454\n", + "Epoch 840 Validation loss 0.007855996489524841\n", + "Epoch 860 Validation loss 0.01187637448310852\n", + "Epoch 880 Validation loss 0.018494905903935432\n", + "Epoch 900 Validation loss 0.009516816586256027\n", + "Epoch 920 Validation loss 0.030950400978326797\n", + "Epoch 940 Validation loss 0.017931077629327774\n", + "Epoch 960 Validation loss 0.017525378614664078\n", + "Epoch 980 Validation loss 0.016576599329710007\n", + "Epoch 1000 Validation loss 0.007525463588535786\n", + "Epoch 1020 Validation loss 0.008745957165956497\n", + "Epoch 1040 Validation loss 0.023068588227033615\n", + "Epoch 1060 Validation loss 0.023049402981996536\n", + "Epoch 1080 Validation loss 0.020367465913295746\n", + "Epoch 1100 Validation loss 0.026941468939185143\n", + "Epoch 1120 Validation loss 0.019598377868533134\n", + "Epoch 1140 Validation loss 0.023052945733070374\n", + "Epoch 1160 Validation loss 0.020239276811480522\n", + "Epoch 1180 Validation loss 0.009076420217752457\n", + "Epoch 1200 Validation loss 0.011559909209609032\n", + "Epoch 1220 Validation loss 0.023455770686268806\n", + "Epoch 1240 Validation loss 0.015224231407046318\n", + "Epoch 1260 Validation loss 0.020417172461748123\n", + "Epoch 1280 Validation loss 0.025817634537816048\n", + "Epoch 1300 Validation loss 0.012675277888774872\n", + "Epoch 1320 Validation loss 0.014165625907480717\n", + "Epoch 1340 Validation loss 0.021743204444646835\n", + "Epoch 1360 Validation loss 0.00959782674908638\n", + "Epoch 1380 
Validation loss 0.014942880719900131\n", + "Epoch 1400 Validation loss 0.033313099294900894\n", + "Epoch 1420 Validation loss 0.025836177170276642\n", + "Epoch 1440 Validation loss 0.015067282132804394\n", + "Epoch 1460 Validation loss 0.01235564611852169\n", + "Epoch 1480 Validation loss 0.012111244723200798\n", + "Epoch 1500 Validation loss 0.00833088904619217\n", + "Epoch 1520 Validation loss 0.01528056338429451\n", + "Epoch 1540 Validation loss 0.017444560304284096\n", + "Epoch 1560 Validation loss 0.014621825888752937\n", + "Epoch 1580 Validation loss 0.019431518390774727\n", + "Epoch 1600 Validation loss 0.016186822205781937\n", + "Epoch 1620 Validation loss 0.02027059532701969\n", + "Epoch 1640 Validation loss 0.01720491796731949\n", + "Epoch 1660 Validation loss 0.011756360530853271\n", + "Epoch 1680 Validation loss 0.02627478912472725\n", + "Epoch 1700 Validation loss 0.023451916873455048\n", + "Epoch 1720 Validation loss 0.011613328941166401\n", + "Epoch 1740 Validation loss 0.026256393641233444\n", + "Epoch 1760 Validation loss 0.008156227879226208\n", + "Epoch 1780 Validation loss 0.01597723178565502\n", + "Epoch 1800 Validation loss 0.013070507906377316\n", + "Epoch 1820 Validation loss 0.01726200059056282\n", + "Epoch 1840 Validation loss 0.009824991226196289\n", + "Epoch 1860 Validation loss 0.014878236688673496\n", + "Epoch 1880 Validation loss 0.017673484981060028\n", + "Epoch 1900 Validation loss 0.016455603763461113\n", + "Epoch 1920 Validation loss 0.02442217618227005\n", + "Epoch 1940 Validation loss 0.026278261095285416\n", + "Epoch 1960 Validation loss 0.02376818098127842\n", + "Epoch 1980 Validation loss 0.016214493662118912\n", + "train diffusion completed, total time: 6097.77689909935.\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "n_epochs = 2000\n", + "val_interval = 20\n", + "epoch_loss_list = []\n", + "val_epoch_loss_list = []\n", + "\n", + "scaler = GradScaler()\n", + "total_start = time.time()\n", + "\n", + "for epoch in range(n_epochs):\n", + " model.train()\n", + " epoch_loss = 0\n", + "\n", + " for step, data in enumerate(train_loader):\n", + " images = data[\"image\"].to(device)\n", + " classes = data[\"slice_label\"].to(device)\n", + " optimizer.zero_grad(set_to_none=True)\n", + " timesteps = torch.randint(0, 1000, (len(images),)).to(device) # pick a random time step t\n", + "\n", + " with autocast(enabled=True):\n", + " # Generate random noise\n", + " noise = torch.randn_like(images).to(device)\n", + "\n", + " # Get model prediction\n", + " noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n", + " loss = F.mse_loss(noise_pred.float(), noise.float())\n", + "\n", + " scaler.scale(loss).backward()\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " epoch_loss += loss.item()\n", + " epoch_loss_list.append(epoch_loss / (step + 1))\n", + "\n", + " if (epoch) % val_interval == 0:\n", + " model.eval()\n", + " val_epoch_loss = 0\n", + "\n", + " for step, data in enumerate(val_loader):\n", + " images = data[\"image\"].to(device)\n", + " classes = data[\"slice_label\"].to(device)\n", + " timesteps = torch.randint(0, 1000, (len(images),)).to(device)\n", + " with torch.no_grad():\n", + " with autocast(enabled=True):\n", + " noise = torch.randn_like(images).to(device)\n", + " noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)\n", + " val_loss = F.mse_loss(noise_pred.float(), noise.float())\n", + "\n", + " val_epoch_loss += val_loss.item()\n", + " val_epoch_loss_list.append(val_epoch_loss / (step + 1))\n", + " print(\"Epoch\", epoch, \"Validation loss\", val_epoch_loss / (step + 1))\n", + "\n", + "total_time = 
time.time() - total_start\n", + "print(f\"train diffusion completed, total time: {total_time}.\")\n", + "\n", + "plt.style.use(\"seaborn-bright\")\n", + "plt.title(\"Learning Curves Diffusion Model\", fontsize=20)\n", + "plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n", + "plt.plot(\n", + " np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),\n", + " val_epoch_loss_list,\n", + " color=\"C1\",\n", + " linewidth=2.0,\n", + " label=\"Validation\",\n", + ")\n", + "plt.yticks(fontsize=12)\n", + "plt.xticks(fontsize=12)\n", + "plt.xlabel(\"Epochs\", fontsize=16)\n", + "plt.ylabel(\"Loss\", fontsize=16)\n", + "plt.legend(prop={\"size\": 14})\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "326101ed-333b-44a9-933f-55760b5d93a4", + "metadata": {}, + "source": [ + "## Check the performance of the diffusion model\n", + "\n", + "We generate a random image from noise to check whether our diffusion model works properly for an image generation task.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 161, + "id": "8f7a9e99-a8a4-4c8f-a42f-17ef91b18585", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████| 1000/1000 [00:23<00:00, 42.86it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.eval()\n", + "noise = torch.randn((1, 1, 64, 64))\n", + "noise = noise.to(device)\n", + "scheduler.set_timesteps(num_inference_steps=1000)\n", + "with autocast(enabled=True):\n", + " image, intermediates = inferer.sample(\n", + " input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=100\n", + " )\n", + "\n", + "chain = torch.cat(intermediates, dim=-1)\n", + "\n", + "plt.style.use(\"default\")\n", + "plt.imshow(chain[0, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + "plt.tight_layout()\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "546f9983-c2e2-4c24-b03a-ebe34627638a", + "metadata": {}, + "source": [ + "## Define the classification model\n", + "First, we define the classification model. It follows the encoder architecture of the diffusion model, combined with linear layers for binary classification between healthy and diseased slices.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 174, + "id": "44cc6928-2525-4e61-8805-15b409097bbb", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "DiffusionModelEncoder(\n", + " (conv_in): Convolution(\n", + " (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (time_embed): Sequential(\n", + " (0): Linear(in_features=32, out_features=128, bias=True)\n", + " (1): SiLU()\n", + " (2): Linear(in_features=128, out_features=128, bias=True)\n", + " )\n", + " (down_blocks): ModuleList(\n", + " (0): DownBlock(\n", + " (resnets): ModuleList(\n", + " (0): ResnetBlock(\n", + " (norm1): GroupNorm(32, 32, eps=1e-06, affine=True)\n", + " (nonlinearity): SiLU()\n", + " (conv1): Convolution(\n", + " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (time_emb_proj): Linear(in_features=128, out_features=32, 
bias=True)\n", + " (norm2): GroupNorm(32, 32, eps=1e-06, affine=True)\n", + " (conv2): Convolution(\n", + " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (skip_connection): Identity()\n", + " )\n", + " )\n", + " (downsampler): Downsample(\n", + " (op): Convolution(\n", + " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (1): AttnDownBlock(\n", + " (attentions): ModuleList(\n", + " (0): AttentionBlock(\n", + " (norm): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (to_q): Linear(in_features=64, out_features=64, bias=True)\n", + " (to_k): Linear(in_features=64, out_features=64, bias=True)\n", + " (to_v): Linear(in_features=64, out_features=64, bias=True)\n", + " (proj_attn): Linear(in_features=64, out_features=64, bias=True)\n", + " )\n", + " )\n", + " (resnets): ModuleList(\n", + " (0): ResnetBlock(\n", + " (norm1): GroupNorm(32, 32, eps=1e-06, affine=True)\n", + " (nonlinearity): SiLU()\n", + " (conv1): Convolution(\n", + " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)\n", + " (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (conv2): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (skip_connection): Convolution(\n", + " (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (downsampler): Downsample(\n", + " (op): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (2): AttnDownBlock(\n", + " (attentions): ModuleList(\n", + " (0): AttentionBlock(\n", + " (norm): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (to_q): Linear(in_features=64, out_features=64, bias=True)\n", + " (to_k): Linear(in_features=64, out_features=64, bias=True)\n", + " 
(to_v): Linear(in_features=64, out_features=64, bias=True)\n", + " (proj_attn): Linear(in_features=64, out_features=64, bias=True)\n", + " )\n", + " )\n", + " (resnets): ModuleList(\n", + " (0): ResnetBlock(\n", + " (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (nonlinearity): SiLU()\n", + " (conv1): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)\n", + " (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (conv2): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (skip_connection): Identity()\n", + " )\n", + " )\n", + " (downsampler): Downsample(\n", + " (op): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (out): Sequential(\n", + " (0): Linear(in_features=4096, out_features=512, bias=True)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=2, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 174, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = torch.device(\"cuda\")\n", + "classifier = DiffusionModelEncoder(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=2,\n", + " num_channels=(32, 64, 64),\n", + " attention_levels=(False, True, True),\n", + " num_res_blocks=(1, 1, 1),\n", + " num_head_channels=64,\n", + " with_conditioning=False,\n", + ")\n", + "\n", + "classifier.to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "45fab83a-b4c8-42cb-96c9-4e9f1e191111", + "metadata": {}, + "source": [ + "## Model training of the classification model\n", + "We train our classification model for 1000 epochs.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "de18d5cb-68e7-407c-afe9-8efd7a5a904a", + "metadata": { 
+ "lines_to_next_cell": 0 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 9 Validation loss 0.2536351333061854\n", + "Epoch 19 Validation loss 0.3019549027085304\n", + "Epoch 29 Validation loss 0.34552596261103946\n", + "Epoch 39 Validation loss 0.2783070926864942\n", + "Epoch 49 Validation loss 0.28460513055324554\n", + "Epoch 59 Validation loss 0.25296298414468765\n", + "Epoch 69 Validation loss 0.3343521902958552\n", + "Epoch 79 Validation loss 0.2634535978237788\n", + "Epoch 89 Validation loss 0.2862999041875203\n", + "Epoch 99 Validation loss 0.22700381030639014\n", + "Epoch 109 Validation loss 0.27035540093978244\n", + "Epoch 119 Validation loss 0.2451721504330635\n", + "Epoch 129 Validation loss 0.2890484283367793\n", + "Epoch 139 Validation loss 0.27566688507795334\n", + "Epoch 149 Validation loss 0.28788923223813373\n", + "Epoch 159 Validation loss 0.2524748469392459\n", + "Epoch 169 Validation loss 0.3107323000828425\n", + "Epoch 179 Validation loss 0.21660694728295007\n", + "Epoch 189 Validation loss 0.2702282816171646\n", + "Epoch 199 Validation loss 0.2677164326111476\n", + "Epoch 209 Validation loss 0.33349836121002835\n", + "Epoch 219 Validation loss 0.2969249188899994\n", + "Epoch 229 Validation loss 0.268981905033191\n", + "Epoch 239 Validation loss 0.29199230174223584\n", + "Epoch 249 Validation loss 0.2806356226404508\n", + "Epoch 259 Validation loss 0.301661084095637\n", + "Epoch 269 Validation loss 0.25811708470185596\n", + "Epoch 279 Validation loss 0.2599738910794258\n", + "Epoch 289 Validation loss 0.23392533014218012\n", + "Epoch 299 Validation loss 0.2580989971756935\n", + "Epoch 309 Validation loss 0.22807281464338303\n", + "Epoch 319 Validation loss 0.2510971352458\n", + "Epoch 329 Validation loss 0.25221700221300125\n", + "Epoch 339 Validation loss 0.25722870975732803\n", + "Epoch 349 Validation loss 0.2516109471519788\n", + "Epoch 359 Validation loss 0.22627043972412744\n", + "Epoch 369 
Validation loss 0.28725822021563846\n", + "Epoch 379 Validation loss 0.2712069054444631\n", + "Epoch 389 Validation loss 0.29460274676481885\n", + "Epoch 399 Validation loss 0.2599460730950038\n", + "Epoch 409 Validation loss 0.22882529348134995\n", + "Epoch 419 Validation loss 0.24265126883983612\n", + "Epoch 429 Validation loss 0.23436561226844788\n", + "Epoch 439 Validation loss 0.25520699471235275\n", + "Epoch 449 Validation loss 0.22466829667488733\n", + "Epoch 459 Validation loss 0.26379595696926117\n", + "Epoch 469 Validation loss 0.23318989326556525\n", + "Epoch 479 Validation loss 0.264743114511172\n", + "Epoch 489 Validation loss 0.25179669509331387\n", + "Epoch 499 Validation loss 0.20064709583918253\n", + "Epoch 509 Validation loss 0.2527008851369222\n", + "Epoch 519 Validation loss 0.24675505111614862\n", + "Epoch 529 Validation loss 0.2267578070362409\n", + "Epoch 539 Validation loss 0.2342942381898562\n", + "Epoch 549 Validation loss 0.2587633654475212\n", + "Epoch 559 Validation loss 0.21963710337877274\n", + "Epoch 569 Validation loss 0.2676527574658394\n", + "Epoch 579 Validation loss 0.25124627848466236\n", + "Epoch 589 Validation loss 0.22307553887367249\n", + "Epoch 599 Validation loss 0.28288981815179187\n", + "Epoch 609 Validation loss 0.2745586136976878\n", + "Epoch 619 Validation loss 0.2356488679846128\n", + "Epoch 629 Validation loss 0.191768117249012\n", + "Epoch 639 Validation loss 0.23102722316980362\n", + "Epoch 649 Validation loss 0.2544248104095459\n", + "Epoch 659 Validation loss 0.23119398951530457\n", + "Epoch 669 Validation loss 0.20733060439427695\n", + "Epoch 679 Validation loss 0.22538802524407706\n", + "Epoch 689 Validation loss 0.216872605184714\n", + "Epoch 699 Validation loss 0.22977381944656372\n", + "Epoch 709 Validation loss 0.21891566862662634\n", + "Epoch 719 Validation loss 0.223398727675279\n", + "Epoch 729 Validation loss 0.24623310069243112\n", + "Epoch 739 Validation loss 0.23960118989149728\n", + "Epoch 749 
Validation loss 0.21641289939483008\n", + "Epoch 759 Validation loss 0.21971949686606726\n", + "Epoch 769 Validation loss 0.22835112363100052\n", + "Epoch 779 Validation loss 0.2273434673746427\n", + "Epoch 789 Validation loss 0.18299358462293944\n", + "Epoch 799 Validation loss 0.1827801006535689\n", + "Epoch 809 Validation loss 0.21519174302617708\n", + "Epoch 819 Validation loss 0.1936649220685164\n", + "Epoch 829 Validation loss 0.23625890165567398\n", + "Epoch 839 Validation loss 0.2425163264075915\n", + "Epoch 849 Validation loss 0.16746311262249947\n", + "Epoch 859 Validation loss 0.20408761004606882\n", + "Epoch 869 Validation loss 0.2144848903020223\n", + "Epoch 879 Validation loss 0.23374033719301224\n", + "Epoch 889 Validation loss 0.23659739891688028\n", + "Epoch 899 Validation loss 0.24609535684188208\n", + "Epoch 909 Validation loss 0.2324757898847262\n", + "Epoch 919 Validation loss 0.24446949362754822\n", + "Epoch 929 Validation loss 0.19177630295356116\n", + "Epoch 939 Validation loss 0.2438896174232165\n", + "Epoch 949 Validation loss 0.2519366617004077\n", + "Epoch 959 Validation loss 0.20046784232060114\n", + "Epoch 969 Validation loss 0.21268909921248755\n", + "Epoch 979 Validation loss 0.2184151684244474\n", + "Epoch 989 Validation loss 0.21281357357899347\n", + "Epoch 999 Validation loss 0.21612912913163504\n", + "train completed, total time: 1351.5848128795624.\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "n_epochs = 1000\n", + "val_interval = 10\n", + "epoch_loss_list = []\n", + "val_epoch_loss_list = []\n", + "optimizer_cls = torch.optim.Adam(params=classifier.parameters(), lr=2.5e-5)\n", + "\n", + "\n", + "scaler = GradScaler()\n", + "total_start = time.time()\n", + "for epoch in range(n_epochs):\n", + " classifier.train()\n", + " epoch_loss = 0\n", + "\n", + " for step, data in enumerate(train_loader):\n", + " images = data[\"image\"].to(device)\n", + " classes = data[\"slice_label\"].to(device)\n", + " # classes[classes==2]=0\n", + "\n", + " optimizer_cls.zero_grad(set_to_none=True)\n", + " timesteps = torch.randint(0, 1000, (len(images),)).to(device)\n", + "\n", + " with autocast(enabled=False):\n", + " # Generate random noise\n", + " noise = torch.randn_like(images).to(device)\n", + "\n", + " # Get model prediction\n", + " noisy_img = scheduler.add_noise(images, noise, timesteps) # add t steps of noise to the input image\n", + " pred = classifier(noisy_img, timesteps)\n", + "\n", + " loss = F.cross_entropy(pred, classes.long())\n", + "\n", + " loss.backward()\n", + " optimizer_cls.step()\n", + "\n", + " epoch_loss += loss.item()\n", + " epoch_loss_list.append(epoch_loss / (step + 1))\n", + "\n", + " if (epoch + 1) % val_interval == 0:\n", + " classifier.eval()\n", + " val_epoch_loss = 0\n", + "\n", + " for step, data_val in enumerate(val_loader):\n", + " images = data_val[\"image\"].to(device)\n", + " classes = data_val[\"slice_label\"].to(device)\n", + " timesteps = torch.randint(0, 1, (len(images),)).to(\n", + " device\n", + " ) # check validation accuracy on the original images, i.e., do not add noise\n", + "\n", + " with torch.no_grad():\n", + " with autocast(enabled=False):\n", + " noise = torch.randn_like(images).to(device)\n", + " pred = classifier(images, timesteps)\n", + " val_loss = F.cross_entropy(pred, classes.long(), reduction=\"mean\")\n", + "\n", + " 
val_epoch_loss += val_loss.item()\n", + " _, predicted = torch.max(pred, 1)\n", + " val_epoch_loss_list.append(val_epoch_loss / (step + 1))\n", + " print(\"Epoch\", epoch, \"Validation loss\", val_epoch_loss / (step + 1))\n", + "\n", + "total_time = time.time() - total_start\n", + "print(f\"train completed, total time: {total_time}.\")\n", + "\n", + "## Learning curves for the Classifier\n", + "\n", + "plt.style.use(\"seaborn-bright\")\n", + "plt.title(\"Learning Curves\", fontsize=20)\n", + "plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color=\"C0\", linewidth=2.0, label=\"Train\")\n", + "plt.plot(\n", + " np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),\n", + " val_epoch_loss_list,\n", + " color=\"C1\",\n", + " linewidth=2.0,\n", + " label=\"Validation\",\n", + ")\n", + "plt.yticks(fontsize=12)\n", + "plt.xticks(fontsize=12)\n", + "plt.xlabel(\"Epochs\", fontsize=16)\n", + "plt.ylabel(\"Loss\", fontsize=16)\n", + "plt.legend(prop={\"size\": 14})\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a676b3fe", + "metadata": {}, + "source": [ + "# Image-to-image translation to a healthy subject\n", + "We pick a diseased subject of the validation set as input image. We want to translate it to its healthy reconstruction." + ] + }, + { + "cell_type": "code", + "execution_count": 162, + "id": "fe0d9eac-1477-4d6d-a885-d3c4acb4a781", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "minmax tensor(0.) 
tensor(1.3396)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAdUAAAHWCAYAAAAhLRNZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVUklEQVR4nO3dW2wW9P3H8QIFCthyEsTDRJgMhOF504ERFzzhPEzN5jHRaXSL0SzTuMULs8VdLNHtSp2LGi+2Oc1OThZP0WhkxgSmqBBEBZUtFFCrLadSDqX/2yX///cre/5fKoXX6/ZN26dPH/rxSfj5G9TX19fXBAD8vw3+oh8AAOwvjCoAFDGqAFDEqAJAEaMKAEWMKgAUMaoAUMSoAkARowoARZr39A8OGjRobz4OANin7cn/gNA7VQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAo0vxFPwDYH/X19YVt2bJlYVuxYkXY3n///bCdc845YRszZkzYvvKVr4QN+O95pwoARYwqABQxqgBQxKgCQBGjCgBFjCoAFBnUl/3b///8g4MG7e3Hwn4mOwKybt26sH3wwQdhGzduXNg6OzvD9vvf/z5sEydODNt5550XtsMPPzxsgwfH/726Zs2asL388sthu/XWW8PW09PTUGttbQ3b7NmzwwYHoj2ZS+9UAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAijtTwue6+++6wzZ07N2xdXV1h2759e9iyIyBDhgwJ2wsvvBC2BQsWhG3Xrl0Nfb0RI0aELfv+fvvb34btkksuCVt2o0z2WEaNGhW2f/zjH2HbsWNH2MaPHx+21157LWy/+MUvwgb7OkdqAKAfGVUAKGJUAaCIUQWAIkYVAIoYVQAo4kjNAWLTpk1pb2tra+jzPvnkk2H797//HbZnn302bDfddFPYtmzZErZf//rXYbviiivCduihh4YtO47S3d0dti996Uthu+WWW8J23333hW3btm1ha25uDtvWrVvDNnr06IY+Z/a8ZEdxsiNK2Wv0pJNOChv0F0dqAKAfGVUAKGJUAaCIUQWAIkYVAIoYVQAo4kjNfiQ74rF27dr0Y9evXx+266+/Pmw//OEPwzZy5Mj0a0buuOOOsJ166qlhy16j2ff/k5/8JGzZ0ZHs6/X29obtxRdfDNt3vvOdsDV6a0x2W1Cjsl8bu3fvDlt2fGnz5s1hy15LLS0tYZs6dWrY4
L/lSA0A9COjCgBFjCoAFDGqAFDEqAJAEaMKAEUcqdmP3HnnnWG78cYb04995ZVXwtbR0RG2559/PmwLFiwI2xFHHBG27DjKxo0bw7Zz586wjRkzJmzZ7S/ZkZpZs2aFbeXKlWFrb28P24wZM8KW/VW9//77w5YdK8l+RtlRquz3wV/+8pewHXTQQWHLfg7Zz6+1tTVsCxcuDNv3vve9sMH/xZEaAOhHRhUAihhVAChiVAGgiFEFgCJGFQCKOFIzwDz88MNhy45HrFq1Kv28y5cvD1t2JOOjjz4K25AhQ8KWHalZt25d2AYPjv87sK2tLWzZDT6ZI488MmyTJk0KW3YrTnZrzPbt28OW/YyyY09PPvlk2Hp6esL2zDPPhO3iiy8O2/Dhw8OW/Ryy5yW7pSb73ZS9JqZNmxa25557LmwcuBypAYB+ZFQBoIhRBYAiRhUAihhVAChiVAGgiCM1+6CnnnoqbMOGDQtbdnxgxYoV6dfMjodkRx2y18WuXbvClh23+etf/xq27373u2E7+OCDw5bdYNPc3By27Baeiy66KGzZbTpDhw4NW/Z8Zsdtsr/Gs2fPDtuIESPCtmbNmrBlLrzwwrC1tLSELfv+5s+fH7bFixeHbcuWLWHLbsx59913w8aBy5EaAOhHRhUAihhVAChiVAGgiFEFgCJGFQCKxGcJ+MJ861vfCtsLL7wQtk8//TRs2dGQpqampkWLFoUtO3YxZsyYsN18880NfdzWrVvDlv2T9h07doQtu90mO3Zx1VVXha29vT1svb29YRs1alTYsmMl2c8h+95ff/31sGVHXLLPmR1HyY64ZD+/n/70p2G75557wpY91w888EDYbrvttrAtXbo0bPPmzQsbeKcKAEWMKgAUMaoAUMSoAkARowoARYwqABRxS80Ak93g8vjjj4fthhtuSD/vhg0bwtbZ2Rm27GacTEdHR9hmzpwZtp6enrBlR1Uyra2tYXv66afDdsYZZ4St0eMv2fGQ7Gaf3bt3N9Sy23SGDx8etuxmn+wWns2bN4ctu2Wou7s7bNmvsOyxZK+l7Ial7OjPm2++GTYGPrfUAEA/MqoAUMSoAkARowoARYwqABQxqgBQxJGaASa7+eXss88O27Zt29LPm92Okt0ok33cpEmTwtbV1RW27DhDduRk9OjRYcte5tnxl+wIyB7+1flfRo4cGbbsuc6+948//jhs2fM5bdq0sGWvmU2bNoUtO6qS/R7JjvBkstdgdjtRdhys0RuPsp/R6aefHjYGBkdqAKAfGVUAKGJUAaCIUQWAIkYVAIoYVQAoEl81wT5pwoQJYcuOamTHKpqampqWLFkStm9/+9thW7duXdheeumlsM2aNStsjR45yW5cyT5ndrwnk92AkrXsca5duzZs2T/nP+WUU8L2/vvvhy07OpIdccmOxmSPM/u47Jah7KakQw45JGzZzyE7hpQdJ8puNWr0tcT+wztVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIIzUDzFe/+tWwZTeHfN4NIHPnzg1bdmxm2LBhYZsxY0bYWlpawpYd/1m0aFHYpk+fHrbsaEV2q0p2hCe75SS7rWT16tVhy25cWbZsWdimTJkStux73717d9juu+++sF166aVhy56z7Ov94Q9/CNvll18etsmTJ4ft4osvDtsDDzwQtoMPPjhsZ555Zti2bt0aNg4M3qkCQBGjCgBFjCoAFDGqAFDEqAJAEaMKAEUG9WVXSvznH0xul2DfcM0114QtO8rweT07PpEdPciO4mQflx0N+vGPfxy2p59+OmzZUYcf/ehHYctuJDnttNPC9tRTTzX0Of/1r3+F7aCDDgpbdtvM7bffHra//e1vYctucTnxxBPDdvzxx4dt6dKlYVu+fHnYsmNWY8eODVt2g032cdmRqOy53rBhQ9gef/zxsDEw7MlceqcKAEWMKgAUMaoAUMSoAkARowoARYwqABRxpGY/ctxxx4Vt9OjR6cdmxwSOO
eaYsB122GFhe+aZZ8KWHZHIbr7JXq7f//73wzZq1KiwZbfGnHfeeWHLbmPJjvdkN6dkN+a88847YfvlL38ZtrPOOits2dGY7DjKo48+Gra///3vYfvjH/8YtnHjxoXt6quvDlt2C88ll1wStsGD4/cUHR0dYcv+rrS3t4dt1qxZYWtqyo8+sW9wpAYA+pFRBYAiRhUAihhVAChiVAGgiFEFgCLNX/QDoM5dd90Vtp07d6Yf293dHbbm5vhlkt02k92qcvLJJ4ctu1FmxIgRYdu+fXtDj2XixIlhy56X7FhQduvPqlWrwpb9k/2urq6wzZ8/P2yTJk0K29y5c8OWHcNav3592B577LGwbdy4MWwnnHBC2P785z+HLbtNZ/bs2WFbs2ZN2LKf+6233hq27HWWvV7Yf3inCgBFjCoAFDGqAFDEqAJAEaMKAEWMKgAUcUvNPuif//xn2D744IOwjR07NmzZcZOmpqambdu2hS27NSY75pHd4tLS0hK24cOHh623tzds2TGW7IhE9rxlfz2yG2Wyoz+dnZ0NfdzSpUvD9o1vfCNsQ4YMCdtpp50Wtka9+uqrYctuxcleE43KjvC89957YbvuuuvClj2fZ555Zth+9atfhY2BwS01ANCPjCoAFDGqAFDEqAJAEaMKAEWMKgAUcUvNPii7waW9vT1s2bGn7FhMU1N+E03WBg+O/7vso48+Cttxxx0Xtux4T3aMZceOHWHLblxZt25d2I466qiwZc93ditQdnQk+yf7s2bNClt2E8306dPDtjfMmTOnX79eJjtKlv2Msuds5MiRYXNsBu9UAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAibqnZBz322GNhy25U2bRpU9iGDh2afs3seMGSJUvCNnfu3LBlt9SsXbs2bIccckjYsu8jeynv2rUrbNlre8OGDWHLbpTJjhpljyX7nOeff37YDmSLFi0KW/ZzePDBB8N2ww03hO3tt98O2w9+8IOwMfC5pQYA+pFRBYAiRhUAihhVAChiVAGgiFEFgCJuqdkHXXHFFWG7++67w/bGG2+E7ZZbbkm/5ubNm8M2bdq0sC1btixs2a0qhx12WNi6u7vDNnz48LB1dnaGLbtZZOPGjWFra2sL24QJE8KW3dCTHZv5vKNP+7P169eHrbW1NWyvvvpq2LKfbXbkq6urK2yOzZDxThUAihhVAChiVAGgiFEFgCJGFQCKGFUAKOJIzQBz+umnh23y5Mlhu+uuu9LPe+ihh4ZtypQpYTvjjDPC1t7eHrbs+Et2C8iMGTPC1twcv5y3b98etq1bt4YtuxXoZz/7WdimTp0atnnz5oUtO1Lz3HPPhS272WfLli1hy46OZK+J7FjQu+++G7bx48eH7cQTTwzb6tWrw5bdRLNixYqw3XzzzWHLjktBxjtVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIIzUDTHa8Zdy4cWHLjrc0NTU1Pfvss2H72te+Fra1a9eGLTvm8eGHH4Ytu5Fk5cqVYdu5c2fYTjnllIa+XnZ0JDuOctZZZ4Xtk08+CdvEiRPDlh2NefTRRxt6LK+//nrYent7w3bZZZeFra+vL2yLFy8OW3a0KXvOVq1aFbbs2MySJUvCduedd4YNMt6pAkARowoARYwqABQxqgBQxKgCQBGjCgBFHKkZYLJjKsuXLw/bp59+2vDnfeSRR8J29NFHh23Dhg1hu/7668PW1tYWtpkzZ4btd7/7XdimT58etuwoUkdHR9iyYzrZjTlHHnlk2N54442wvfbaa2HLjo4MGTIkbCeccELYdu/eHbZjjz02bNktNdmtMU888UTYenp6wjZnzpywnXPOOWGDvcE7VQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiCM1+5HZs2eH7dprr00/dvLkyWHLbkfJbvq47bbb0q8Z6e7uDtsrr7wStnnz5jX09TZt2hS27DjKW2+9Fbazzz47bNnxpuy4zfjx48OWHbeZM
WNG2LJbeL7+9a+HLTu+9ac//Sls2evsqKOOClt2y9KwYcPCBv3NO1UAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoMigvr6+vj36g4MG7e3Hwhdo4cKFYbvwwgvDdu+994bty1/+ctjuueeesF100UVhe/jhh8N24403hq2lpSVsra2tYdu1a1fYJkyYELalS5eGbdq0aQ19vewIT3Z7T3YcpbOzM2zZEa0333wzbL29vWHLbgT67LPPwnbBBReEbcqUKWGbOnVq2OC/tSdz6Z0qABQxqgBQxKgCQBGjCgBFjCoAFDGqAFDEkRr2moceeihsJ510Utiy4ygPPvhg2LLjKN/85jfDdvTRR4dt6NChYctueHnppZfClh0Z2rx5c9h+85vfhG3UqFFhy47NnHrqqWHLvr/sVpwrr7wybGPHjg3bggULwpbdRON3E/3FkRoA6EdGFQCKGFUAKGJUAaCIUQWAIkYVAIo4UsM+J3tJvvfee2FbvHhx2FavXh22c889N2yDB8f/3bly5cqGPu7nP/952O64446wdXR0hK2rqytsPT09YTv88MPDlt0oc+mll4atra0tbJkXX3wxbPPnz2/oc0IlR2oAoB8ZVQAoYlQBoIhRBYAiRhUAihhVACjiSA0HhOxlvnDhwrB98sknYZs5c2bY5syZE7Zrr702bNnfs6uvvjpsRxxxRNimT58etkZlz6ffFeyvHKkBgH5kVAGgiFEFgCJGFQCKGFUAKGJUAaCIIzUAsAccqQGAfmRUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAoYlQBoIhRBYAiRhUAihhVAChiVAGgiFEFgCLNe/oH+/r69ubjAIABzztVAChiVAGgiFEFgCJGFQCKGFUAKGJUAaCIUQWAIkYVAIoYVQAo8j/bDQpRm1Wv1wAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "DiffusionModelEncoder(\n", + " (conv_in): Convolution(\n", + " (conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (time_embed): Sequential(\n", + " (0): Linear(in_features=32, out_features=128, bias=True)\n", + " (1): SiLU()\n", + " (2): Linear(in_features=128, out_features=128, bias=True)\n", + " )\n", + " (down_blocks): ModuleList(\n", + " (0): DownBlock(\n", + " (resnets): ModuleList(\n", + " (0): ResnetBlock(\n", + " (norm1): GroupNorm(32, 32, eps=1e-06, affine=True)\n", + " (nonlinearity): SiLU()\n", + " (conv1): Convolution(\n", + " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)\n", + " (norm2): GroupNorm(32, 32, eps=1e-06, affine=True)\n", + " (conv2): Convolution(\n", + " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (skip_connection): Identity()\n", + " )\n", + " )\n", + " (downsampler): Downsample(\n", + " (op): Convolution(\n", + " (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (1): AttnDownBlock(\n", + " (attentions): ModuleList(\n", + " (0): AttentionBlock(\n", + " (norm): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (to_q): Linear(in_features=64, out_features=64, bias=True)\n", + " (to_k): Linear(in_features=64, out_features=64, bias=True)\n", + " (to_v): Linear(in_features=64, out_features=64, bias=True)\n", + " (proj_attn): Linear(in_features=64, out_features=64, bias=True)\n", + " )\n", + " )\n", + " (resnets): ModuleList(\n", + " (0): ResnetBlock(\n", + " (norm1): GroupNorm(32, 32, eps=1e-06, affine=True)\n", + " (nonlinearity): SiLU()\n", + " (conv1): Convolution(\n", + " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (time_emb_proj): 
Linear(in_features=128, out_features=64, bias=True)\n", + " (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (conv2): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (skip_connection): Convolution(\n", + " (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (downsampler): Downsample(\n", + " (op): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " (2): AttnDownBlock(\n", + " (attentions): ModuleList(\n", + " (0): AttentionBlock(\n", + " (norm): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (to_q): Linear(in_features=64, out_features=64, bias=True)\n", + " (to_k): Linear(in_features=64, out_features=64, bias=True)\n", + " (to_v): Linear(in_features=64, out_features=64, bias=True)\n", + " (proj_attn): Linear(in_features=64, out_features=64, bias=True)\n", + " )\n", + " )\n", + " (resnets): ModuleList(\n", + " (0): ResnetBlock(\n", + " (norm1): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (nonlinearity): SiLU()\n", + " (conv1): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (time_emb_proj): Linear(in_features=128, out_features=64, bias=True)\n", + " (norm2): GroupNorm(32, 64, eps=1e-06, affine=True)\n", + " (conv2): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", + " )\n", + " (skip_connection): Identity()\n", + " )\n", + " )\n", + " (downsampler): Downsample(\n", + " (op): Convolution(\n", + " (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (out): Sequential(\n", + " (0): Linear(in_features=4096, out_features=512, bias=True)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=2, bias=True)\n", + " )\n", + ")" + ] 
+ }, + "execution_count": 162, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idx_unhealthy = np.argwhere(data_val[\"slice_label\"].numpy() == 0).squeeze()\n", + "idx = idx_unhealthy[4] # Pick a random slice of the validation set to be transformed\n", + "inputimg = data_val[\"image\"][idx] # Pick an input slice of the validation set to be transformed\n", + "inputlabel = data_val[\"slice_label\"][idx] # Check whether it is healthy or diseased\n", + "print(\"minmax\", inputimg.min(), inputimg.max())\n", + "\n", + "plt.figure(\"input\" + str(inputlabel))\n", + "plt.imshow(inputimg[0, ...], vmin=0, vmax=1, cmap=\"gray\")\n", + "plt.axis(\"off\")\n", + "plt.tight_layout()\n", + "plt.show()\n", + "\n", + "model.eval()\n", + "classifier.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "0cd48c2d", + "metadata": {}, + "source": [ + "### Encoding the input image in noise with the reversed DDIM sampling scheme\n", + "In order to sample using gradient guidance, we first need to encode the input image in noise by using the reversed DDIM sampling scheme.\\\n", + "We define the number of steps in the noising and denoising process by L.\\\n", + "The encoding process is presented in Equation 6 of the paper \"Diffusion Models for Medical Anomaly Detection\" (https://arxiv.org/pdf/2203.04306.pdf).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 176, + "id": "f71e4924", + "metadata": { + "jupyter": { + "outputs_hidden": false + }, + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████| 200/200 [00:05<00:00, 33.36it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "L = 200\n", + "current_img = inputimg[None, ...].to(device)\n", + "scheduler.set_timesteps(num_inference_steps=1000)\n", + "\n", + "progress_bar = tqdm(range(L)) # go back and forth L timesteps\n", + "for t in progress_bar: # go through the noising process\n", + " with autocast(enabled=False):\n", + " with torch.no_grad():\n", + " model_output = model(current_img, timesteps=torch.Tensor((t,)).to(current_img.device))\n", + " current_img, _ = scheduler.reversed_step(model_output, t, current_img)\n", + "\n", + "plt.style.use(\"default\")\n", + "plt.imshow(current_img[0, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + "plt.tight_layout()\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "a7c8346a-6296-4800-b978-c10fcdf09779", + "metadata": {}, + "source": [ + "### Denoising process using gradient guidance\n", + "From the noisy image, we apply DDIM sampling scheme for denoising for L steps.\\\n", + "Additionally, we apply gradient guidance using the classifier network towards the desired class label y=0 (healthy). This is presented in Algorithm 2 of https://arxiv.org/pdf/2105.05233.pdf, and in Algorithm 1 of https://arxiv.org/pdf/2203.04306.pdf. \\\n", + "The scale s is used to amplify the gradient." + ] + }, + { + "cell_type": "code", + "execution_count": 173, + "id": "7ab274bd-ea60-4674-b59b-d41de98fee5b", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████| 200/200 [00:15<00:00, 12.79it/s]\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "y = torch.tensor(0) # define the desired class label\n", + "scale = 6 # define the desired gradient scale s\n", + "progress_bar = tqdm(range(L)) # go back and forth L timesteps\n", + "\n", + "for i in progress_bar: # go through the denoising process\n", + " t = L - i\n", + " with autocast(enabled=True):\n", + " with torch.no_grad():\n", + " model_output = model(\n", + " current_img, timesteps=torch.Tensor((t,)).to(current_img.device)\n", + " ).detach() # this is supposed to be epsilon\n", + "\n", + " with torch.enable_grad():\n", + " x_in = current_img.detach().requires_grad_(True)\n", + " logits = classifier(x_in, timesteps=torch.Tensor((t,)).to(current_img.device))\n", + " log_probs = F.log_softmax(logits, dim=-1)\n", + " selected = log_probs[range(len(logits)), y.view(-1)]\n", + " a = torch.autograd.grad(selected.sum(), x_in)[0]\n", + " alpha_prod_t = scheduler.alphas_cumprod[t]\n", + " updated_noise = (\n", + " model_output - (1 - alpha_prod_t).sqrt() * scale * a\n", + " ) # update the predicted noise epsilon with the gradient of the classifier\n", + "\n", + " current_img, _ = scheduler.step(updated_noise, t, current_img)\n", + " torch.cuda.empty_cache()\n", + "\n", + "plt.style.use(\"default\")\n", + "plt.imshow(current_img[0, 0].cpu().detach().numpy(), vmin=0, vmax=1, cmap=\"gray\")\n", + "plt.tight_layout()\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "d2e343f8-c6f3-4071-a5e6-771e2343c3bc", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "# Anomaly Detection\n", + "To get the anomaly map, we compute the difference between the input image and the output of our image-to-image translation model towards the healthy reconstruction." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 175, + "id": "ecffaaf3-a7df-453e-81a9-757113d85084", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbsAAAG7CAYAAABaaTseAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAASLElEQVR4nO3dQaht11kH8HUlD9pAK7SSDJLSN4iYQNNCJ9WUYilRFBQcFAXn4qRKB8UMOujrIFAHltJJESdOdNBBKbSgg4hPpNB0UGodtJI3eAEzyIMWNBCFBI5DQc/3f7nrnXvvy//+fsO93tpnnb3PfX82fOvbZ4fD4bAAoNgvXPUCAOCiCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6j3yTv/h2dmLYfStEywFACY3xpHD4Yv3ne3JDoB6wg6AesIOgHrCDoB6wg6AesIOgHrveOuB7QUAXJ0HyyBPdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1HvkqhcAXKQbYeytS1sFXDVPdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANSz9YB3gcsqn0+fk/5U0rzJzrp3rsN/bXwO9PFkB0A9YQdAPWEHQD1hB0A9YQdAPdWY18pO1eCO927M2a0anL5TWsP0s/9AmPPzMPb2cHz3O02VldPnrLXW+zbmpD9/VZx08WQHQD1hB0A9YQdAPWEHQD1hB0A9YQdAPVsPrpVT3u60jWGnbD01Rk7bCE75Welcvx3GvnfOz7mfaRtBkrZGTNL63j8c/8+Nz4Gr58kOgHrCDoB6wg6AesIOgHrCDoB6qjEvxW4D5qlabqqUWytX8u1U7D2xMeeNMDZVPKZ1p2bG0/nST3tq+Jzm/DiMTVIVaWo6fW84vlM9meak6zpVXe7+lqdru1NNm9awWwFLO092ANQTdgDUE3YA1BN2ANQTdgDUE3YA1LP14FKkcuhURj2N7W4veHxjzlQankrGd8vTz7uGtdZ6amPOVHKf1p3K9Kfrt7NtY635O6X1vTIcT7+VnRL+dF13xybT1o10L+A4T3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUs/XgyqXO+Dud59MbEXZMZd47byJYa687fzK9IeCxMGfagvFamJO2Ebw+HN99k8NO2f/NMDZJv5VpDbtl/9NveTq+1rylw5sNOD9PdgDUE3YA1BN2ANQTdgDUE3YA1FONeSlS9WS6BTuVi6kCcKrm2zlf+pxkp7H0TvVduubTZ+1cu7X2Gj6n831gY84kVfvuNChPFZypofiOne8Lx3myA6CesAOgnrADoJ6wA6CesAOgnrADoJ6tByc1lWvvlHivN
Zdy75b9T581lbqvNZd/p3Wn9U3zUon8ZTWW/ngYu3NJa1hrrR8Ox9M1mppbp3WnP/+PDsfTvXgzjO1sOTnl73Wt3HSadp7sAKgn7ACoJ+wAqCfsAKgn7ACoJ+wAqGfrwbmlkvup1DyVSqdbMM2byszXyqXXbwzHU9f+afvDznaAteby77Rd4dEwNpW7p/uU1jfZKXdPWw923qLwWpgz3aenwpyfhLFXwtiO6beXTH8br4c5aXsG15knOwDqCTsA6gk7AOoJOwDqCTsA6qnGPKmdSrCpim6tuQIwVfKlsel8qcpv+k7pp5O+007V4GNhbKryS+ubqlnTtUuVlTv3PVWETud7MsyZ1peaHz8TxqYG0ulepPv+9MacNDZ5O4xN1zXdd1p4sgOgnrADoJ6wA6CesAOgnrADoJ6wA6CerQfnlkrQp8uZSqhTw+I0Nkll8FPj3zRnKstOzaN3zpfmpHLyqYF0KidPjYQn6XzT2lPZ/05D8XS+U37OWvNWgSQ1y57uYbrv03aUnYbYa+Xvu2O6tqf+HE7Bkx0A9YQdAPWEHQD1hB0A9YQdAPVUY57UTkPZqZpwrbmqK1UTpua+U3VbqtibfiKpQjJdh5snXENaRzrfs8Pxfwpz0n3aaVj8sFfsTQ2207p3mlvvNNFOc1JF6NQAfKpSvt/5pmu0WwHLRfJkB0A9YQdAPWEHQD1hB0A9YQdAPWEHQD1bD85tp6w4lUpP5ctpLDVh3imD32ncPJVxpznJTqPlNC/9tF/aWMObYWz6TaRS/NTUebqHabvHVCK/cy/Wmn/Lu2X1O//VTGtP1zV93zsnXMNae9tyuCqe7ACoJ+wAqCfsAKgn7ACoJ+wAqCfsAKhn68G57XQtT6XIO7cgzdkpNX9043xpi0Na31S6nt4qsPOd0pshdsvxJ9O1eC3MSVsPpq0l9zbPd0rp2u28IWDnrQfpu6atEZO0hp23caRr5K0HV8WTHQD1hB0A9YQdAPWEHQD1hB0A9VRjXop0mVMF21TVlaryUlXjVFn272HOVKmWKuJSo95UJTlJTZgnqSJuGkuVd+l8U9XlTiXfWvn+Tqb17TYl3vmv4W4Ym37nqRLy1eH4Y2HOTrXjqZtl8zDyZAdAPWEHQD1hB0A9YQdAPWEHQD1hB0A9Ww8uxU7Z+lqnb5q8W4Z+TGoEnUrup3LtH4c5j4ex6fumLQ7T+tKfQ7p2U1n91Pz4fuebvlO6DtPad+/TtPadrTJrzb/znd9/Wncam67FTjPqtfa2LEzrs43honmyA6CesAOgnrADoJ6wA6CesAOgnrADoJ6tB5dit/v9Tif7nVua3mAwrSGVa6dy8qnEOl2jdL5JKtOfvm8qJb8Zxqb17ZaTT9c2XfOprD5tFfiDMPaN4fju1pYnNs432b2u028svaUj/W1M0n3afcMCD8qTHQD1hB0A9YQdAPWEHQD1hB0A9VRjntRUhbXTGDnNSxVdqUn0dLtTNdpU5ffrYc5LYWyqktz9TlPD4vTTTt93ks43jaV1pyrJTwzHXw5zJs+EsW+Fsamy8tPzlM9P615rfe3FYSBdh1M3t57Ot1Nxeb/POu8cjaAvmic7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6tl6cFJTWfFjYU5qcjzdnlSmP5XirzWvL5VyTw18fxDm7DTWTWXc6TtN1zZd12mLyPRd18rX6OZwPN2nJ8PYVI5/L8x5/ujRL62PjTO+vP4inO+p4fjtecrvhK0HT37x+PEv/FVYw3QPd0r+15q3U6Sy//Rbnu7vznYiLponOwDqCTsA6gk7AOoJOwDqCTsA6qnGPKmp0jA1u52q3tZa66fD8amacK254myttV4fjqcKsaki7ukwZ1p3kqre0vW7Mxz/ZJhzezierkO65pN0b385jH3z6NGPHD41znhu/ePR47fX341zbvzsuXHsVz74b0eP3/mP3xjnfOUX/3gc+/xf/+W0inHOfP3SvUhVuDvVk6lJ9E5TZw2fr
4onOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOqdHQ6Hwzv6h2e3Lngp7xY7jVzfF8ZSg+GpxDptL0i7SaZ5vxXmvDQcT02OnwljU2n4bkn2dI1S8+hp28TOvVhrrQ8Ox1MD8Pl39LHDm0eP/+iFXxvnPPfn/3D0+JfXl8Y5X19/Oo69sL5y9PinXvveOGd9+z3z2Oe+Ogz8/jxnLPv/+zAn/Y6m3+zuVh7bCB4Wh8Ot+/4bT3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDU89aDS5HKl9NWhulNADfDnJfvu5r/74dhbNpGcDfM+UkYe3w4nrrLp7cHTG89uBnmTNsf0taDdL7pbRLpOz0/jvzLh45vZTj7vbBL6E+OH/7Nz35mnHLjI/P6vvtLnz0+8IWzeQ1356F5+00q+//X4Xh6C0a65lxnnuwAqCfsAKgn7ACoJ+wAqCfsAKinEfS5perJncawU3XiWnPj2tRY+smNNaSi3KlSM1XETdWOa6316eH47TAnVexN1+jZMGeq4JwrJHPj6+kaTdW0a+XvNM1LDbunKsn0552qT6eqxqlCcq18jabPSvdpajqdrt3NMHZ3OJ6ahicaQT8sNIIGgCXsALgGhB0A9YQdAPWEHQD1hB0A9TSCPrdUbjxtS0iXOY1Nn5VKpX8WxqaS9qmRcZqTthekkvtXhuMf3zzf1Pj6p2HO5G4YS+ebtoKk893cGJu2TKw1X6O0TSV5dDiett6ksWnrwUthzrTF4IkwJ22NSFsWaOfJDoB6wg6AesIOgHrCDoB6wg6AesIOgHq2HpzUThf01Cn+7Y3zpfLqqTQ8vcFg+k67bz2YxtJ1+GgYm+btvCnh1TBnRyrFn94qsFYux59Mf8rhPv3qH81j3//6MLB7n6Z1pP+Cpu0KaRtIOt+0dm8vuA482QFQT9gBUE/YAVBP2AFQT9gBUE815pWbKs52pYq4qZFwqrCbqgbTnGSqUEzNnm9vfM7O+tKcVN053cOdyti11vrwcDxVuU7n++Q85fu3wvmmZstp3amh+PRfTbrv94bjqXpSZSXHebIDoJ6wA6CesAOgnrADoJ6wA6CesAOgnq0HJzWVZe+WQ++c72/C2NSMN21XuDscT2X6j4exae27WxkmqUR+krYK7MzbaUa91lZT57Fp+A/CnJ31TVsS0hrWmr/TG2HOJG1XSH8bp/6N8W7iyQ6AesIOgHrCDoB6wg6AesIOgHqqMU/q1E1oT32+qfLt5TBnqthLlYE7DYvTTzFVd06NqlNl4CRd71S5OK09VYSmysDpPqU50/2Yrs9ae9WYUzPxNGetvz386OjxPzz7TDjfVPmZGk7v3HeuA092ANQTdgDUE3YA1BN2ANQTdgDUE3YA1LP14FqZSutPvcUhNep9fjj+nTBnZ0tAKqufxlJJezrfVO6eth7sNJ1Oc6a176x7rXX3z44f/1z4Tt/96jg0bzFIW1juDcfT9oJT/5Zp4ckOgHrCDoB6wg6AesIOgHrCDoB6wg6AerYecB9TJ/vUgT9tPUhvWJiEcvf3vHD8+H+/uPE5ad0/D2PPDsfT1oO0zWG6tukNBtNnzffpO4d/Hsd+92wa+XBYQzKt49UwxzYCTseTHQD1hB0A9YQdAPWEHQD1hB0A9c4Oh8PhHf3Ds1sXvBQu3lSxt1P1lioN0/mmisdUabgjVVZOn5W+UzIVNZ+6YXH6TpPUCDpVmD49HL8T5qTvO1FxyYM7HG7d9994sgOgnrADoJ6wA6CesAOgnrADoJ6wA6CeRtDXylTmvbONIJWMp/NNDYF3tzJMTr2VIa1h57ruSM23T23aYnCZazjlVhmuO092ANQTdgDUE3YA1BN2ANQTdgDUE3YA1LP1gPXwl3KfelvCZZzrQc53WSX3O9spTu3U2zPgOE92ANQTdgDUE3YA1BN2ANQTdgDUU43JBXjYqzsfdg9zhempPezro4UnOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6
gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqCTsA6gk7AOoJOwDqPXLVCwCA/3VjOP7WA53Vkx0A9YQdAPWEHQD1hB0A9YQdAPXOUY353o3Tv70xBwD+r/c/0GxPdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQ7OxwOh6teBABcJE92ANQTdgDUE3YA1BN2ANQTdgDUE3YA1BN2ANQTdgDUE3YA1PsfHnc8lSExf2kAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "diff = abs(inputimg.cpu() - current_img[0, 0].cpu()).detach().numpy()\n", + "plt.style.use(\"default\")\n", + "plt.imshow(diff[0, ...], cmap=\"jet\")\n", + "plt.tight_layout()\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c459ab23-459d-4063-824e-39dac93abb43", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "jupytext": { + "formats": "py:percent,ipynb" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/generative/anomaly_detection/anomalydetection_tutorial_classifier_guidance.py b/tutorials/generative/anomaly_detection/anomalydetection_tutorial_classifier_guidance.py new file mode 100644 index 00000000..54fdb6f6 --- /dev/null +++ b/tutorials/generative/anomaly_detection/anomalydetection_tutorial_classifier_guidance.py @@ -0,0 +1,505 @@ +# --- +# jupyter: +# jupytext: +# formats: py:percent,ipynb +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.14.4 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% [markdown] +# # Diffusion Models for Medical Anomaly Detection with Classifier Guidance +# +# This tutorial illustrates how to use MONAI for training a 2D gradient-guided anomaly detection using DDIMs [1]. +# +# We train a diffusion model on 2D slices of brain MR images. A classification model is trained to predict whether the given slice shows a tumor or not.\ +# We then translate an input slice to its healthy reconstruction using DDIMs.\ +# Anomaly detection is performed by taking the difference between input and output, as proposed in [1]. +# +# [1] - Wolleb et al. "Diffusion Models for Medical Anomaly Detection" https://arxiv.org/abs/2203.04306 +# +# ## Setup environment + +# %% +# !python -c "import monai" || pip install -q "monai-weekly[pillow, tqdm]" +# !python -c "import matplotlib" || pip install -q matplotlib +# !python -c "import seaborn" || pip install -q seaborn + +# %% [markdown] +# ## Setup imports + +# %% jupyter={"outputs_hidden": false} +import os +import time +import tempfile +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +from monai import transforms +from monai.apps import DecathlonDataset +from monai.config import print_config +from monai.data import DataLoader +from monai.utils import set_determinism +from torch.cuda.amp import GradScaler, autocast +from tqdm import tqdm + +from generative.inferers import DiffusionInferer +from generative.networks.nets.diffusion_model_unet import DiffusionModelEncoder, DiffusionModelUNet +from generative.networks.schedulers.ddim import 
DDIMScheduler + +torch.multiprocessing.set_sharing_strategy("file_system") + +print_config() + + +# %% [markdown] +# ## Setup data directory + +# %% jupyter={"outputs_hidden": false} +directory = os.environ.get("MONAI_DATA_DIRECTORY") +root_dir = tempfile.mkdtemp() if directory is None else directory + +# %% [markdown] +# ## Set deterministic training for reproducibility + +# %% jupyter={"outputs_hidden": false} +set_determinism(42) + +# %% [markdown] tags=[] +# ## Preprocessing of the BRATS Dataset in 2D slices for training +# We download the BRATS training dataset from the Decathlon dataset. \ +# We slice the volumes into axial 2D slices, and assign slice-wise labels (0 for diseased, 1 for healthy) to all slices. +# Here we use transforms to augment the training dataset: +# +# 1. `LoadImaged` loads the brain MR images from files. +# 1. `EnsureChannelFirstd` rearranges the loaded data into a "channel first" shape. +# 1. `ScaleIntensityRangePercentilesd` takes the lower and upper intensity percentiles and scales them to [0, 1]. 
#

# %%
channel = 0  # 0 = Flair
if channel not in (0, 1, 2, 3):
    # explicit check instead of `assert`, which is stripped when Python runs with -O
    raise ValueError("Choose a valid channel")

# Preprocessing pipeline, used for both the training and the validation dataset below.
train_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        transforms.EnsureChannelFirstd(keys=["image", "label"]),
        # keep only the selected channel of the image ...
        transforms.Lambdad(keys=["image"], func=lambda x: x[channel, :, :, :]),
        # ... and restore a leading channel axis.
        # NOTE(review): AddChanneld is deprecated in recent MONAI releases - confirm the installed version provides it.
        transforms.AddChanneld(keys=["image"]),
        transforms.EnsureTyped(keys=["image", "label"]),
        transforms.Orientationd(keys=["image", "label"], axcodes="RAS"),
        transforms.Spacingd(keys=["image", "label"], pixdim=(3.0, 3.0, 2.0), mode=("bilinear", "nearest")),
        transforms.CenterSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 44)),
        transforms.ScaleIntensityRangePercentilesd(keys="image", lower=0, upper=99.5, b_min=0, b_max=1),
        # pick one random slice along the last axis and drop that singleton axis -> 2D sample
        transforms.RandSpatialCropd(keys=["image", "label"], roi_size=(64, 64, 1), random_size=False),
        transforms.Lambdad(keys=["image", "label"], func=lambda x: x.squeeze(-1)),
        # Slice-wise class label derived from the segmentation mask of the chosen slice:
        # 0.0 if the mask contains labelled (tumour) pixels, 1.0 if the mask is empty.
        transforms.CopyItemsd(keys=["label"], times=1, names=["slice_label"]),
        transforms.Lambdad(keys=["slice_label"], func=lambda x: 0.0 if x.sum() > 0 else 1.0),
    ]
)

# %% jupyter={"outputs_hidden": false}
batch_size = 64

train_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    section="training",  # the validation split is loaded separately below
    cache_rate=1.0,  # you may need a few Gb of RAM... Set to 0 otherwise
    num_workers=4,
    download=False,  # Set download to True if the dataset hasn't been downloaded yet
    seed=0,
    transform=train_transforms,
)

print(f"Length of training data: {len(train_ds)}")  # this gives the number of patients in the training set
print(f'Train image shape {train_ds[0]["image"].shape}')

train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, persistent_workers=True
)

# %% [markdown] tags=[]
# ## Preprocessing of the BRATS Dataset in 2D slices for validation
# We download the BRATS validation dataset from the Decathlon dataset, and define the dataloader to load 2D slices for validation.
#
#

# %%
val_ds = DecathlonDataset(
    root_dir=root_dir,
    task="Task01_BrainTumour",
    section="validation",
    cache_rate=1.0,  # you may need a few Gb of RAM... Set to 0 otherwise
    num_workers=4,
    download=False,  # Set download to True if the dataset hasn't been downloaded yet
    seed=0,
    transform=train_transforms,
)
# Fixed: this message previously claimed to report the size of the *training* data.
print(f"Length of validation data: {len(val_ds)}")
print(f'Validation Image shape {val_ds[0]["image"].shape}')

val_loader = DataLoader(
    val_ds, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=True, persistent_workers=True
)


# %% [markdown]
# ## Define network, scheduler, optimizer, and inferer
# At this step, we instantiate the MONAI components to create a DDIM, the UNET, the noise scheduler, and the inferer used for training and sampling. We are using
# the deterministic DDIM scheduler containing 1000 timesteps, and a 2D UNET with attention mechanisms
# in the 3rd level (`num_head_channels=64`).
#

# %% jupyter={"outputs_hidden": false}
device = torch.device("cuda")

model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(64, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    num_head_channels=64,
    with_conditioning=False,
)
model.to(device)

# Single source of truth for the length of the diffusion process; also bounds the sampled training timesteps.
num_train_timesteps = 1000
scheduler = DDIMScheduler(num_train_timesteps=num_train_timesteps)

optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)

inferer = DiffusionInferer(scheduler)


# %% [markdown] tags=[]
# ## Model training of the diffusion model
# We train our diffusion model for 2000 epochs.

# %% jupyter={"outputs_hidden": false}
n_epochs = 2000
val_interval = 20
epoch_loss_list = []
val_epoch_loss_list = []

scaler = GradScaler()
total_start = time.time()

for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0

    for step, data in enumerate(train_loader):
        images = data["image"].to(device)
        optimizer.zero_grad(set_to_none=True)
        # pick a random time step t for every image of the batch
        timesteps = torch.randint(0, num_train_timesteps, (len(images),)).to(device)

        with autocast(enabled=True):
            # Generate random noise (already created on the device of `images`)
            noise = torch.randn_like(images)

            # Get model prediction
            noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
            loss = F.mse_loss(noise_pred.float(), noise.float())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    epoch_loss_list.append(epoch_loss / (step + 1))

    # Validate after every `val_interval`-th epoch (epochs val_interval, 2 * val_interval, ... counted from 1),
    # which is exactly the x-axis used for the validation curve plotted below.
    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_epoch_loss = 0

        for step, data in enumerate(val_loader):
            images = data["image"].to(device)
            timesteps = torch.randint(0, num_train_timesteps, (len(images),)).to(device)
            with torch.no_grad():
                with autocast(enabled=True):
                    noise = torch.randn_like(images)
                    noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
                    val_loss = F.mse_loss(noise_pred.float(), noise.float())

            val_epoch_loss += val_loss.item()
        val_epoch_loss_list.append(val_epoch_loss / (step + 1))
        print("Epoch", epoch, "Validation loss", val_epoch_loss / (step + 1))

total_time = time.time() - total_start
print(f"train diffusion completed, total time: {total_time}.")

# Matplotlib 3.6 renamed the bundled seaborn styles ("seaborn-bright" -> "seaborn-v0_8-bright");
# use whichever name the installed version knows.
curve_style = "seaborn-v0_8-bright" if "seaborn-v0_8-bright" in plt.style.available else "seaborn-bright"
plt.style.use(curve_style)
plt.title("Learning Curves Diffusion Model", fontsize=20)
plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
plt.plot(
    np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
    val_epoch_loss_list,
    color="C1",
    linewidth=2.0,
    label="Validation",
)
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.show()


# %% [markdown]
# ## Check the performance of the diffusion model
#
# We generate a random image from noise to check whether our diffusion model works properly for an image generation task.
#
#

# %%
model.eval()
noise = torch.randn((1, 1, 64, 64))
noise = noise.to(device)
scheduler.set_timesteps(num_inference_steps=1000)
with autocast(enabled=True):
    image, intermediates = inferer.sample(
        input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=100
    )

# place the saved intermediate images side by side along the last axis
chain = torch.cat(intermediates, dim=-1)

plt.style.use("default")
plt.imshow(chain[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.tight_layout()
plt.axis("off")
plt.show()


# %% [markdown]
# ## Define the classification model
# First, we define the classification model. It follows the encoder architecture of the diffusion model, combined with linear layers for binary classification between healthy and diseased slices.
#

# %%
device = torch.device("cuda")
classifier = DiffusionModelEncoder(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    num_channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_res_blocks=(1, 1, 1),
    num_head_channels=64,
    with_conditioning=False,
)

classifier.to(device)


# %% [markdown]
# ## Model training of the classification model
# We train our classification model for 1000 epochs.
#

# %%

n_epochs = 1000
val_interval = 10
epoch_loss_list = []
val_epoch_loss_list = []
optimizer_cls = torch.optim.Adam(params=classifier.parameters(), lr=2.5e-5)

total_start = time.time()
for epoch in range(n_epochs):
    classifier.train()
    epoch_loss = 0

    for step, data in enumerate(train_loader):
        images = data["image"].to(device)
        classes = data["slice_label"].to(device)

        optimizer_cls.zero_grad(set_to_none=True)
        timesteps = torch.randint(0, 1000, (len(images),)).to(device)

        # the classifier is trained in full precision
        with autocast(enabled=False):
            # Generate random noise
            noise = torch.randn_like(images)

            # The classifier has to work on noisy inputs: add t steps of noise to the input image
            noisy_img = scheduler.add_noise(images, noise, timesteps)
            pred = classifier(noisy_img, timesteps)

            loss = F.cross_entropy(pred, classes.long())

        loss.backward()
        optimizer_cls.step()

        epoch_loss += loss.item()
    epoch_loss_list.append(epoch_loss / (step + 1))

    if (epoch + 1) % val_interval == 0:
        classifier.eval()
        val_epoch_loss = 0

        # NOTE: the loop variable `data_val` is read again after training (the last validation batch
        # is used as the source of the input slice), so its name must not change.
        for step, data_val in enumerate(val_loader):
            images = data_val["image"].to(device)
            classes = data_val["slice_label"].to(device)
            # check validation loss on the original images, i.e., timestep 0 and no added noise
            timesteps = torch.zeros(len(images), dtype=torch.long, device=device)

            with torch.no_grad():
                with autocast(enabled=False):
                    pred = classifier(images, timesteps)
                    val_loss = F.cross_entropy(pred, classes.long(), reduction="mean")

            val_epoch_loss += val_loss.item()
        val_epoch_loss_list.append(val_epoch_loss / (step + 1))
        print("Epoch", epoch, "Validation loss", val_epoch_loss / (step + 1))

total_time = time.time() - total_start
print(f"train completed, total time: {total_time}.")

## Learning curves for the Classifier

# Matplotlib 3.6 renamed the bundled seaborn styles ("seaborn-bright" -> "seaborn-v0_8-bright");
# use whichever name the installed version knows.
curve_style = "seaborn-v0_8-bright" if "seaborn-v0_8-bright" in plt.style.available else "seaborn-bright"
plt.style.use(curve_style)
plt.title("Learning Curves", fontsize=20)
plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
plt.plot(
    np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
    val_epoch_loss_list,
    color="C1",
    linewidth=2.0,
    label="Validation",
)
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.show()
# %% [markdown]
# # Image-to-image translation to a healthy subject
# We pick a diseased subject of the validation set as input image. We want to translate it to its healthy reconstruction.

# %%
# Label encoding used in this file (see the `slice_label` transform): 0.0 = the slice's segmentation
# mask contains labelled pixels (diseased), 1.0 = the mask is empty (healthy).
# `data_val` is the last batch produced by the classifier validation loop.
idx_unhealthy = np.flatnonzero(data_val["slice_label"].numpy() == 0)
if idx_unhealthy.size == 0:
    raise RuntimeError("The last validation batch does not contain a diseased slice to translate.")
# Take the fifth diseased slice of the batch (or the last one if there are fewer than five).
idx = idx_unhealthy[min(4, idx_unhealthy.size - 1)]
inputimg = data_val["image"][idx]  # Pick an input slice of the validation set to be transformed
inputlabel = data_val["slice_label"][idx]  # Check whether it is healthy or diseased
print("minmax", inputimg.min(), inputimg.max())

plt.figure("input" + str(inputlabel))
plt.imshow(inputimg[0, ...], vmin=0, vmax=1, cmap="gray")
plt.axis("off")
plt.tight_layout()
plt.show()

model.eval()
classifier.eval()

# %% [markdown]
# ### Encoding the input image in noise with the reversed DDIM sampling scheme
# In order to sample using gradient guidance, we first need to encode the input image in noise by using the reversed DDIM sampling scheme.\
# We define the number of steps in the noising and denoising process by L.\
# The encoding process is presented in Equation 6 of the paper "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/pdf/2203.04306.pdf).
#

# %% jupyter={"outputs_hidden": false}
L = 200
current_img = inputimg[None, ...].to(device)  # add the batch dimension
scheduler.set_timesteps(num_inference_steps=1000)

progress_bar = tqdm(range(L))  # go back and forth L timesteps
for t in progress_bar:  # go through the noising process
    # float timestep tensor with one entry, created directly on the device of the image
    timestep = torch.tensor([float(t)], device=current_img.device)
    with autocast(enabled=False):
        with torch.no_grad():
            model_output = model(current_img, timesteps=timestep)
    current_img, _ = scheduler.reversed_step(model_output, t, current_img)

plt.style.use("default")
plt.imshow(current_img[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.tight_layout()
plt.axis("off")
plt.show()


# %% [markdown]
# ### Denoising process using gradient guidance
# From the noisy image, we apply DDIM sampling scheme for denoising for L steps.\
# Additionally, we apply gradient guidance using the classifier network towards the desired class label y=1 (healthy). This is presented in Algorithm 2 of https://arxiv.org/pdf/2105.05233.pdf, and in Algorithm 1 of https://arxiv.org/pdf/2203.04306.pdf. \
# The scale s is used to amplify the gradient.

# %%
# Desired class label. The classifier was trained with slice_label 1.0 = healthy (empty mask) and
# 0.0 = diseased, so guiding towards a healthy reconstruction requires class index 1.
y = torch.tensor(1)
scale = 6  # define the desired gradient scale s
progress_bar = tqdm(range(L))  # go back and forth L timesteps

for i in progress_bar:  # go through the denoising process
    t = L - i
    # one timestep tensor shared by the diffusion model and the classifier
    timestep = torch.tensor([float(t)], device=current_img.device)
    with autocast(enabled=True):
        with torch.no_grad():
            model_output = model(current_img, timesteps=timestep).detach()  # this is supposed to be epsilon

        with torch.enable_grad():
            # gradient of log p(y | x_t) with respect to the current image
            x_in = current_img.detach().requires_grad_(True)
            logits = classifier(x_in, timesteps=timestep)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            a = torch.autograd.grad(selected.sum(), x_in)[0]
            alpha_prod_t = scheduler.alphas_cumprod[t]
            # update the predicted noise epsilon with the gradient of the classifier
            updated_noise = model_output - (1 - alpha_prod_t).sqrt() * scale * a

    current_img, _ = scheduler.step(updated_noise, t, current_img)
    torch.cuda.empty_cache()

plt.style.use("default")
plt.imshow(current_img[0, 0].cpu().detach().numpy(), vmin=0, vmax=1, cmap="gray")
plt.tight_layout()
plt.axis("off")
plt.show()


# %% [markdown]
# # Anomaly Detection
# To get the anomaly map, we compute the difference between the input image and the output of our image-to-image translation model towards the healthy reconstruction.


# %%

# absolute pixel-wise difference between the input slice and its reconstruction
diff = abs(inputimg.cpu() - current_img[0, 0].cpu()).detach().numpy()
plt.style.use("default")
plt.imshow(diff[0, ...], cmap="jet")
plt.tight_layout()
plt.axis("off")
plt.show()

# %%