From ac5982beb0e36831bd33c2632b3e1820bc4a1edd Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Mon, 20 Mar 2023 23:02:59 +0000 Subject: [PATCH] Add more content to tutorial Signed-off-by: Walter Hugo Lopez Pinaya --- .../generative/3d_ldm/3d_ldm_tutorial.ipynb | 104 ++++++++++++------ .../generative/3d_ldm/3d_ldm_tutorial.py | 87 +++++++++------ 2 files changed, 121 insertions(+), 70 deletions(-) diff --git a/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb b/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb index 047fb431..c050db40 100644 --- a/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb +++ b/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb @@ -5,7 +5,12 @@ "id": "e0a3f076", "metadata": {}, "source": [ - "# 3D Latent Diffusion Model" + "# 3D Latent Diffusion Model\n", + "In this tutorial, we will walk through the process of using the MONAI Generative Models package to generate synthetic data using Latent Diffusion Models (LDM) [1, 2]. Specifically, we will focus on training an LDM to create synthetic brain images from the Brats dataset.\n", + "\n", + "[1] - Rombach et al. \"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n", + "\n", + "[2] - Pinaya et al. \"Brain imaging generation with latent diffusion models\" https://arxiv.org/abs/2209.07162" ] }, { @@ -13,7 +18,7 @@ "id": "da9e6b23", "metadata": {}, "source": [ - "## Set up imports" + "### Set up imports" ] }, { @@ -106,7 +111,7 @@ "id": "2b02aa6c", "metadata": {}, "source": [ - "## Setup a data directory and download dataset\n", + "### Setup a data directory and download dataset\n", "Specify a MONAI_DATA_DIRECTORY variable, where the data will be downloaded. If not specified a temporary directory will be used." ] }, @@ -135,7 +140,8 @@ "id": "74302407", "metadata": {}, "source": [ - "## Download the training set" + "### Prepare data loader for the training set\n", + "Here we will download the Brats dataset using MONAI's `DecathlonDataset` class, and we prepare the data loader for the training set." ] }, { @@ -219,7 +225,7 @@ "id": "1d36e0c4", "metadata": {}, "source": [ - "## Visualise examples from the training set" + "### Visualise examples from the training set" ] }, { @@ -272,7 +278,11 @@ "id": "513d7eee", "metadata": {}, "source": [ - "## Define Networks" + "## Autoencoder KL\n", + "\n", + "### Define Autoencoder KL network\n", + "\n", + "In this section, we will define an autoencoder with KL-regularization for the LDM. The autoencoder's primary purpose is to transform input images into a latent representation that the diffusion model will subsequently learn. By doing so, we can decrease the computational resources required to train the diffusion component, making this approach suitable for learning high-resolution medical images.\n" ] }, { @@ -305,9 +315,8 @@ " spatial_dims=3,\n", " in_channels=1,\n", " out_channels=1,\n", - " num_channels=32,\n", + " num_channels=(32, 64, 64),\n", " latent_channels=3,\n", - " ch_mult=(1, 2, 2),\n", " num_res_blocks=1,\n", " norm_num_groups=16,\n", " attention_levels=(False, False, True),\n", @@ -321,11 +330,6 @@ " num_channels=32,\n", " in_channels=1,\n", " out_channels=1,\n", - " kernel_size=4,\n", - " activation=\"LEAKYRELU\",\n", - " norm=\"BATCH\",\n", - " bias=False,\n", - " padding=1,\n", ")\n", "discriminator.to(device)" ] @@ -335,7 +339,9 @@ "id": "67f94d1b", "metadata": {}, "source": [ - "## Define Losses" + "### Defining Losses\n", + "\n", + "We will also specify the perceptual and adversarial losses, including the involved networks, and the optimizers to use during the training process." ] }, { @@ -386,7 +392,7 @@ "id": "be4fe2d4", "metadata": {}, "source": [ - "## Train AutoEncoder" + "### Train model" ] }, { @@ -569,7 +575,11 @@ " )\n", " epoch_recon_loss_list.append(epoch_loss / (step + 1))\n", " epoch_gen_loss_list.append(gen_epoch_loss / (step + 1))\n", - " epoch_disc_loss_list.append(disc_epoch_loss / (step + 1))" + " epoch_disc_loss_list.append(disc_epoch_loss / (step + 1))\n", + "\n", + "del discriminator\n", + "del loss_perceptual\n", + "torch.cuda.empty_cache()" ] }, { @@ -692,7 +702,33 @@ "id": "fe436141", "metadata": {}, "source": [ - "## Train Diffusion Model" + "## Diffusion Model\n", + "\n", + "### Define diffusion model and scheduler\n", + "\n", + "In this section, we will define the diffusion model that will learn data distribution of the latent representation of the autoencoder. Together with the diffusion model, we define a beta scheduler responsible for defining the amount of noise tahat is added across the diffusion's model Markov chain." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "88cbe609", + "metadata": {}, + "outputs": [], + "source": [ + "unet = DiffusionModelUNet(\n", + " spatial_dims=3,\n", + " in_channels=3,\n", + " out_channels=3,\n", + " num_res_blocks=1,\n", + " num_channels=[32, 64, 64],\n", + " attention_levels=(False, True, True),\n", + " num_head_channels=1,\n", + ")\n", + "unet.to(device)\n", + "\n", + "\n", + "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0015, beta_end=0.0195)" ] }, { @@ -740,26 +776,12 @@ }, { "cell_type": "code", - "execution_count": 42, - "id": "88cbe609", + "execution_count": null, + "id": "7de37f3a", "metadata": {}, "outputs": [], "source": [ - "unet = DiffusionModelUNet(\n", - " spatial_dims=3,\n", - " in_channels=3,\n", - " out_channels=3,\n", - " num_res_blocks=1,\n", - " num_channels=[32, 64, 64],\n", - " attention_levels=(False, True, True),\n", - " num_head_channels=1,\n", - ")\n", - "unet.to(device)\n", - "\n", - "\n", - "scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0015, beta_end=0.0195)\n", - "\n", - "inferer = LatentDiffusionInferer(scheduler)" + "inferer = LatentDiffusionInferer(scheduler, scale_factor=scale_factor)" ] }, { @@ -772,6 +794,14 @@ "optimizer_diff = torch.optim.Adam(params=unet.parameters(), lr=1e-4)" ] }, + { + "cell_type": "markdown", + "id": "4705c795", + "metadata": {}, + "source": [ + "### Train diffusion model" + ] + }, { "cell_type": "code", "execution_count": 44, @@ -1020,7 +1050,9 @@ "id": "c9de4288", "metadata": {}, "source": [ - "## Image generation" + "### Plotting sampling example\n", + "\n", + "Finally, we generate an image with our LDM. For that, we will initialize a latent representation with just noise. Then, we will use the `unet` to perform 1000 denoising steps. In the last step, we decode the latent representation and plot the sampled image." ] }, { @@ -1054,7 +1086,7 @@ "id": "fed68b96", "metadata": {}, "source": [ - "### Visualise Synthetic" + "### Visualise synthetic data" ] }, { diff --git a/tutorials/generative/3d_ldm/3d_ldm_tutorial.py b/tutorials/generative/3d_ldm/3d_ldm_tutorial.py index 7acab22d..6806072e 100644 --- a/tutorials/generative/3d_ldm/3d_ldm_tutorial.py +++ b/tutorials/generative/3d_ldm/3d_ldm_tutorial.py @@ -15,8 +15,13 @@ # --- # # 3D Latent Diffusion Model +# In this tutorial, we will walk through the process of using the MONAI Generative Models package to generate synthetic data using Latent Diffusion Models (LDM) [1, 2]. Specifically, we will focus on training an LDM to create synthetic brain images from the Brats dataset. +# +# [1] - Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 +# +# [2] - Pinaya et al. "Brain imaging generation with latent diffusion models" https://arxiv.org/abs/2209.07162 -# ## Set up imports +# ### Set up imports # + import os @@ -46,14 +51,15 @@ # for reproducibility purposes set a seed set_determinism(42) -# ## Setup a data directory and download dataset +# ### Setup a data directory and download dataset # Specify a MONAI_DATA_DIRECTORY variable, where the data will be downloaded. If not specified a temporary directory will be used. directory = os.environ.get("MONAI_DATA_DIRECTORY") root_dir = tempfile.mkdtemp() if directory is None else directory print(root_dir) -# ## Download the training set +# ### Prepare data loader for the training set +# Here we will download the Brats dataset using MONAI's `DecathlonDataset` class, and we prepare the data loader for the training set. # + batch_size = 2 @@ -87,7 +93,7 @@ print(f'Image shape {train_ds[0]["image"].shape}') # - -# ## Visualise examples from the training set +# ### Visualise examples from the training set # + # Plot axial, coronal and sagittal slices of a training sample @@ -107,7 +113,12 @@ # plt.savefig("training_examples.png") # - -# ## Define Networks +# ## Autoencoder KL +# +# ### Define Autoencoder KL network +# +# In this section, we will define an autoencoder with KL-regularization for the LDM. The autoencoder's primary purpose is to transform input images into a latent representation that the diffusion model will subsequently learn. By doing so, we can decrease the computational resources required to train the diffusion component, making this approach suitable for learning high-resolution medical images. +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using {device}") @@ -117,9 +128,8 @@ spatial_dims=3, in_channels=1, out_channels=1, - num_channels=32, + num_channels=(32, 64, 64), latent_channels=3, - ch_mult=(1, 2, 2), num_res_blocks=1, norm_num_groups=16, attention_levels=(False, False, True), @@ -133,16 +143,13 @@ num_channels=32, in_channels=1, out_channels=1, - kernel_size=4, - activation="LEAKYRELU", - norm="BATCH", - bias=False, - padding=1, ) discriminator.to(device) # - -# ## Define Losses +# ### Defining Losses +# +# We will also specify the perceptual and adversarial losses, including the involved networks, and the optimizers to use during the training process. # + l1_loss = L1Loss() @@ -164,7 +171,7 @@ def KL_loss(z_mu, z_sigma): optimizer_g = torch.optim.Adam(params=autoencoder.parameters(), lr=1e-4) optimizer_d = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4) -# ## Train AutoEncoder +# ### Train model # + n_epochs = 100 @@ -234,6 +241,10 @@ def KL_loss(z_mu, z_sigma): epoch_recon_loss_list.append(epoch_loss / (step + 1)) epoch_gen_loss_list.append(gen_epoch_loss / (step + 1)) epoch_disc_loss_list.append(disc_epoch_loss / (step + 1)) + +del discriminator +del loss_perceptual +torch.cuda.empty_cache() # - plt.style.use("ggplot") @@ -271,25 +282,11 @@ def KL_loss(z_mu, z_sigma): ax = axs[2] ax.imshow(img[img.shape[0] // 2, ...], cmap="gray") -# ## Train Diffusion Model - -# ### Scaling factor -# -# As mentioned in Rombach et al. [1] Section 4.3.2 and D.1, the signal-to-noise ratio (induced by the scale of the latent space) can affect the results obtained with the LDM, if the standard deviation of the latent space distribution drifts too much from that of a Gaussian. For this reason, it is best practice to use a scaling factor to adapt this standard deviation. +# ## Diffusion Model # -# _Note: In case where the latent space is close to a Gaussian distribution, the scaling factor will be close to one, and the results will not differ from those obtained when it is not used._ +# ### Define diffusion model and scheduler # - -# + -with torch.no_grad(): - with autocast(enabled=True): - z = autoencoder.encode_stage_2_inputs(check_data["image"].to(device)) - -print(f"Scaling factor set to {1/torch.std(z)}") -scale_factor = 1 / torch.std(z) -# - - -# We define the inferer using the scale factor: +# In this section, we will define the diffusion model that will learn data distribution of the latent representation of the autoencoder. Together with the diffusion model, we define a beta scheduler responsible for defining the amount of noise tahat is added across the diffusion's model Markov chain. # + unet = DiffusionModelUNet( @@ -305,12 +302,32 @@ def KL_loss(z_mu, z_sigma): scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="scaled_linear", beta_start=0.0015, beta_end=0.0195) +# - + +# ### Scaling factor +# +# As mentioned in Rombach et al. [1] Section 4.3.2 and D.1, the signal-to-noise ratio (induced by the scale of the latent space) can affect the results obtained with the LDM, if the standard deviation of the latent space distribution drifts too much from that of a Gaussian. For this reason, it is best practice to use a scaling factor to adapt this standard deviation. +# +# _Note: In case where the latent space is close to a Gaussian distribution, the scaling factor will be close to one, and the results will not differ from those obtained when it is not used._ +# + +# + +with torch.no_grad(): + with autocast(enabled=True): + z = autoencoder.encode_stage_2_inputs(check_data["image"].to(device)) -inferer = LatentDiffusionInferer(scheduler) +print(f"Scaling factor set to {1/torch.std(z)}") +scale_factor = 1 / torch.std(z) # - +# We define the inferer using the scale factor: + +inferer = LatentDiffusionInferer(scheduler, scale_factor=scale_factor) + optimizer_diff = torch.optim.Adam(params=unet.parameters(), lr=1e-4) +# ### Train diffusion model + # + n_epochs = 150 epoch_loss_list = [] @@ -365,7 +382,9 @@ def KL_loss(z_mu, z_sigma): plt.legend(prop={"size": 14}) plt.show() -# ## Image generation +# ### Plotting sampling example +# +# Finally, we generate an image with our LDM. For that, we will initialize a latent representation with just noise. Then, we will use the `unet` to perform 1000 denoising steps. In the last step, we decode the latent representation and plot the sampled image. # + autoencoder.eval() @@ -379,7 +398,7 @@ def KL_loss(z_mu, z_sigma): ) # - -# ### Visualise Synthetic +# ### Visualise synthetic data idx = 0 img = synthetic_images[idx, channel].detach().cpu().numpy() # images