Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 68 additions & 36 deletions tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
"id": "e0a3f076",
"metadata": {},
"source": [
"# 3D Latent Diffusion Model"
"# 3D Latent Diffusion Model\n",
"In this tutorial, we will walk through the process of using the MONAI Generative Models package to generate synthetic data using Latent Diffusion Models (LDM) [1, 2]. Specifically, we will focus on training an LDM to create synthetic brain images from the Brats dataset.\n",
"\n",
"[1] - Rombach et al. \"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n",
"\n",
"[2] - Pinaya et al. \"Brain imaging generation with latent diffusion models\" https://arxiv.org/abs/2209.07162"
]
},
{
"cell_type": "markdown",
"id": "da9e6b23",
"metadata": {},
"source": [
"## Set up imports"
"### Set up imports"
]
},
{
Expand Down Expand Up @@ -106,7 +111,7 @@
"id": "2b02aa6c",
"metadata": {},
"source": [
"## Setup a data directory and download dataset\n",
"### Setup a data directory and download dataset\n",
"Specify a MONAI_DATA_DIRECTORY variable, where the data will be downloaded. If not specified a temporary directory will be used."
]
},
Expand Down Expand Up @@ -135,7 +140,8 @@
"id": "74302407",
"metadata": {},
"source": [
"## Download the training set"
"### Prepare data loader for the training set\n",
"Here we will download the Brats dataset using MONAI's `DecathlonDataset` class, and we prepare the data loader for the training set."
]
},
{
Expand Down Expand Up @@ -219,7 +225,7 @@
"id": "1d36e0c4",
"metadata": {},
"source": [
"## Visualise examples from the training set"
"### Visualise examples from the training set"
]
},
{
Expand Down Expand Up @@ -272,7 +278,11 @@
"id": "513d7eee",
"metadata": {},
"source": [
"## Define Networks"
"## Autoencoder KL\n",
"\n",
"### Define Autoencoder KL network\n",
"\n",
"In this section, we will define an autoencoder with KL-regularization for the LDM. The autoencoder's primary purpose is to transform input images into a latent representation that the diffusion model will subsequently learn. By doing so, we can decrease the computational resources required to train the diffusion component, making this approach suitable for learning high-resolution medical images.\n"
]
},
{
Expand Down Expand Up @@ -305,9 +315,8 @@
" spatial_dims=3,\n",
" in_channels=1,\n",
" out_channels=1,\n",
" num_channels=32,\n",
" num_channels=(32, 64, 64),\n",
" latent_channels=3,\n",
" ch_mult=(1, 2, 2),\n",
" num_res_blocks=1,\n",
" norm_num_groups=16,\n",
" attention_levels=(False, False, True),\n",
Expand All @@ -321,11 +330,6 @@
" num_channels=32,\n",
" in_channels=1,\n",
" out_channels=1,\n",
" kernel_size=4,\n",
" activation=\"LEAKYRELU\",\n",
" norm=\"BATCH\",\n",
" bias=False,\n",
" padding=1,\n",
")\n",
"discriminator.to(device)"
]
Expand All @@ -335,7 +339,9 @@
"id": "67f94d1b",
"metadata": {},
"source": [
"## Define Losses"
"### Defining Losses\n",
"\n",
"We will also specify the perceptual and adversarial losses, including the involved networks, and the optimizers to use during the training process."
]
},
{
Expand Down Expand Up @@ -386,7 +392,7 @@
"id": "be4fe2d4",
"metadata": {},
"source": [
"## Train AutoEncoder"
"### Train model"
]
},
{
Expand Down Expand Up @@ -569,7 +575,11 @@
" )\n",
" epoch_recon_loss_list.append(epoch_loss / (step + 1))\n",
" epoch_gen_loss_list.append(gen_epoch_loss / (step + 1))\n",
" epoch_disc_loss_list.append(disc_epoch_loss / (step + 1))"
" epoch_disc_loss_list.append(disc_epoch_loss / (step + 1))\n",
"\n",
"del discriminator\n",
"del loss_perceptual\n",
"torch.cuda.empty_cache()"
]
},
{
Expand Down Expand Up @@ -692,7 +702,33 @@
"id": "fe436141",
"metadata": {},
"source": [
"## Train Diffusion Model"
"## Diffusion Model\n",
"\n",
"### Define diffusion model and scheduler\n",
"\n",
"In this section, we will define the diffusion model that will learn data distribution of the latent representation of the autoencoder. Together with the diffusion model, we define a beta scheduler responsible for defining the amount of noise tahat is added across the diffusion's model Markov chain."
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "88cbe609",
"metadata": {},
"outputs": [],
"source": [
"unet = DiffusionModelUNet(\n",
" spatial_dims=3,\n",
" in_channels=3,\n",
" out_channels=3,\n",
" num_res_blocks=1,\n",
" num_channels=[32, 64, 64],\n",
" attention_levels=(False, True, True),\n",
" num_head_channels=1,\n",
")\n",
"unet.to(device)\n",
"\n",
"\n",
"scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0015, beta_end=0.0195)"
]
},
{
Expand Down Expand Up @@ -740,26 +776,12 @@
},
{
"cell_type": "code",
"execution_count": 42,
"id": "88cbe609",
"execution_count": null,
"id": "7de37f3a",
"metadata": {},
"outputs": [],
"source": [
"unet = DiffusionModelUNet(\n",
" spatial_dims=3,\n",
" in_channels=3,\n",
" out_channels=3,\n",
" num_res_blocks=1,\n",
" num_channels=[32, 64, 64],\n",
" attention_levels=(False, True, True),\n",
" num_head_channels=1,\n",
")\n",
"unet.to(device)\n",
"\n",
"\n",
"scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule=\"scaled_linear\", beta_start=0.0015, beta_end=0.0195)\n",
"\n",
"inferer = LatentDiffusionInferer(scheduler)"
"inferer = LatentDiffusionInferer(scheduler, scale_factor=scale_factor)"
]
},
{
Expand All @@ -772,6 +794,14 @@
"optimizer_diff = torch.optim.Adam(params=unet.parameters(), lr=1e-4)"
]
},
{
"cell_type": "markdown",
"id": "4705c795",
"metadata": {},
"source": [
"### Train diffusion model"
]
},
{
"cell_type": "code",
"execution_count": 44,
Expand Down Expand Up @@ -1020,7 +1050,9 @@
"id": "c9de4288",
"metadata": {},
"source": [
"## Image generation"
"### Plotting sampling example\n",
"\n",
"Finally, we generate an image with our LDM. For that, we will initialize a latent representation with just noise. Then, we will use the `unet` to perform 1000 denoising steps. In the last step, we decode the latent representation and plot the sampled image."
]
},
{
Expand Down Expand Up @@ -1054,7 +1086,7 @@
"id": "fed68b96",
"metadata": {},
"source": [
"### Visualise Synthetic"
"### Visualise synthetic data"
]
},
{
Expand Down
Loading