This repository was archived by the owner on Feb 7, 2025. It is now read-only.
Merged
875 changes: 471 additions & 404 deletions tutorials/generative/2d_ldm/2d_ldm_tutorial.ipynb

Large diffs are not rendered by default.

69 changes: 57 additions & 12 deletions tutorials/generative/2d_ldm/2d_ldm_tutorial.py
@@ -1,3 +1,19 @@
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# formats: ipynb,py
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.14.4
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# # 2D Latent Diffusion Model

# +
@@ -31,7 +47,7 @@
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.inferers import DiffusionInferer
from generative.inferers import LatentDiffusionInferer
from generative.losses.adversarial_loss import PatchAdversarialLoss
from generative.losses.perceptual import PerceptualLoss
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
@@ -125,8 +141,6 @@

scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", beta_start=0.0015, beta_end=0.0195)

inferer = DiffusionInferer(scheduler)

discriminator = PatchDiscriminator(
spatial_dims=2,
num_layers_d=3,
@@ -274,6 +288,26 @@

# ## Train Diffusion Model

# ### Scaling factor
#
# As mentioned in Rombach et al. [1] (Sections 4.3.2 and D.1), the signal-to-noise ratio induced by the scale of the latent space can affect the results obtained with an LDM if the standard deviation of the latent distribution drifts too far from that of a standard Gaussian. For this reason, it is best practice to rescale the latents with a scaling factor so that their standard deviation is close to one.
#
# _Note: If the latent distribution is already close to a standard Gaussian, the scaling factor will be close to one and the results will not differ noticeably from those obtained without it._
#

# +
with torch.no_grad():
with autocast(enabled=True):
z = autoencoderkl.encode_stage_2_inputs(check_data["image"].to(device))

print(f"Scaling factor set to {1/torch.std(z)}")
scale_factor = 1 / torch.std(z)
# -
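
# As a quick sanity check (a minimal sketch, assuming the `z` and `scale_factor` computed in the
# cell above), the latents multiplied by the scale factor should have a standard deviation close
# to one, which is exactly the property the scaling is meant to enforce:

print(f"Std of scaled latents: {torch.std(z * scale_factor).item():.3f}")  # ~1.0 by construction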

# We define the inferer using the scale factor:

inferer = LatentDiffusionInferer(scheduler, scale_factor=scale_factor)

# It takes about 80 minutes to train the model.

# +
@@ -295,14 +329,14 @@
for step, batch in progress_bar:
images = batch["image"].to(device)
optimizer.zero_grad(set_to_none=True)

with autocast(enabled=True):
z_mu, z_sigma = autoencoderkl.encode(images)
z = autoencoderkl.sampling(z_mu, z_sigma)

noise = torch.randn_like(z).to(device)
timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (z.shape[0],), device=z.device).long()
noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise, timesteps=timesteps)
noise_pred = inferer(
inputs=images, diffusion_model=unet, noise=noise, timesteps=timesteps, autoencoder_model=autoencoderkl
)
loss = F.mse_loss(noise_pred.float(), noise.float())

scaler.scale(loss).backward()
@@ -329,7 +363,13 @@
timesteps = torch.randint(
0, inferer.scheduler.num_train_timesteps, (z.shape[0],), device=z.device
).long()
noise_pred = inferer(inputs=z, diffusion_model=unet, noise=noise, timesteps=timesteps)
noise_pred = inferer(
inputs=images,
diffusion_model=unet,
noise=noise,
timesteps=timesteps,
autoencoder_model=autoencoderkl,
)

loss = F.mse_loss(noise_pred.float(), noise.float())

@@ -343,8 +383,9 @@
z = z.to(device)
scheduler.set_timesteps(num_inference_steps=1000)
with autocast(enabled=True):
z = inferer.sample(input_noise=z, diffusion_model=unet, scheduler=scheduler)
decoded = autoencoderkl.decode(z)
decoded = inferer.sample(
input_noise=z, diffusion_model=unet, scheduler=scheduler, autoencoder_model=autoencoderkl
)

plt.figure(figsize=(2, 2))
plt.style.use("default")
@@ -371,7 +412,12 @@

noise = torch.randn_like(z).to(device)
image, intermediates = inferer.sample(
input_noise=z, diffusion_model=unet, scheduler=scheduler, save_intermediates=True, intermediate_steps=100
input_noise=z,
diffusion_model=unet,
scheduler=scheduler,
save_intermediates=True,
intermediate_steps=100,
autoencoder_model=autoencoderkl,
)


@@ -381,8 +427,7 @@
decoded_images = []
for image in intermediates:
with torch.no_grad():
decoded = autoencoderkl.decode(image)
decoded_images.append(decoded)
decoded_images.append(image)
plt.figure(figsize=(10, 12))
chain = torch.cat(decoded_images, dim=-1)
plt.style.use("default")
Original file line number Diff line number Diff line change
@@ -493,8 +493,8 @@
" num_levels=2,\n",
" downsample_parameters=((2, 4, 1, 1), (2, 4, 1, 1)),\n",
" upsample_parameters=((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),\n",
" num_channels=(256,256),\n",
" num_res_channels=(256,256),\n",
" num_channels=(256, 256),\n",
" num_res_channels=(256, 256),\n",
" num_embeddings=256,\n",
" embedding_dim=32,\n",
")\n",
@@ -603,13 +603,7 @@
"Epoch 69: 100%|██████████████| 125/125 [00:31<00:00, 3.92it/s, recons_loss=0.0162, quantization_loss=2.33e-5]\n",
"Epoch 70: 100%|███████████████| 125/125 [00:31<00:00, 3.91it/s, recons_loss=0.0162, quantization_loss=2.5e-5]\n",
"Epoch 71: 100%|██████████████| 125/125 [00:31<00:00, 3.91it/s, recons_loss=0.0168, quantization_loss=2.34e-5]\n",
"Epoch 72: 100%|██████████████| 125/125 [00:31<00:00, 3.92it/s, recons_loss=0.0171, quantization_loss=2.01e-5]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 72: 100%|██████████████| 125/125 [00:31<00:00, 3.92it/s, recons_loss=0.0171, quantization_loss=2.01e-5]\n",
"Epoch 73: 100%|██████████████| 125/125 [00:31<00:00, 3.92it/s, recons_loss=0.0166, quantization_loss=2.05e-5]\n",
"Epoch 74: 100%|██████████████| 125/125 [00:31<00:00, 3.92it/s, recons_loss=0.0165, quantization_loss=2.36e-5]\n",
"Epoch 75: 100%|██████████████| 125/125 [00:31<00:00, 3.91it/s, recons_loss=0.0161, quantization_loss=1.96e-5]\n",
@@ -679,10 +673,7 @@
" epoch_loss += recons_loss.item()\n",
"\n",
" progress_bar.set_postfix(\n",
" {\n",
" \"recons_loss\": epoch_loss / (step + 1),\n",
" \"quantization_loss\": quantization_loss.item() / (step + 1),\n",
" }\n",
" {\"recons_loss\": epoch_loss / (step + 1), \"quantization_loss\": quantization_loss.item() / (step + 1)}\n",
" )\n",
" epoch_recon_loss_list.append(epoch_loss / (step + 1))\n",
" epoch_quant_loss_list.append(quantization_loss.item() / (step + 1))\n",
@@ -902,11 +893,9 @@
"# Get spatial dimensions of data\n",
"# We divide the spatial shape by 4 as the vqvae downsamples the image by a factor of 4 along each dimension\n",
"spatial_shape = next(iter(train_loader))[\"image\"].shape[2:]\n",
"spatial_shape = (int(spatial_shape[0]/4),int(spatial_shape[1]/4))\n",
"spatial_shape = (int(spatial_shape[0] / 4), int(spatial_shape[1] / 4))\n",
"\n",
"ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value,\n",
" spatial_dims=2,\n",
" dimensions=(1,) + spatial_shape)\n",
"ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(1,) + spatial_shape)\n",
"\n",
"sequence_ordering = ordering.get_sequence_ordering()\n",
"revert_sequence_ordering = ordering.get_revert_sequence_ordering()"
@@ -1367,11 +1356,11 @@
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"transformer_model = DecoderOnlyTransformer(\n",
" num_tokens= 256, # must be equal to num_embeddings input of VQVAE\n",
" max_seq_len=spatial_shape[0]*spatial_shape[1],\n",
" attn_layers_dim=64,\n",
" attn_layers_depth=12,\n",
" attn_layers_heads=8,\n",
" num_tokens=256, # must be equal to num_embeddings input of VQVAE\n",
" max_seq_len=spatial_shape[0] * spatial_shape[1],\n",
" attn_layers_dim=64,\n",
" attn_layers_depth=12,\n",
" attn_layers_heads=8,\n",
")\n",
"transformer_model.to(device)"
]
@@ -1403,14 +1392,8 @@
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def generate(\n",
" net,\n",
" vqvae_model,\n",
" starting_tokens,\n",
" seq_len,\n",
" **kwargs\n",
"):\n",
" \n",
"def generate(net, vqvae_model, starting_tokens, seq_len, **kwargs):\n",
"\n",
" progress_bar = iter(range(seq_len))\n",
"\n",
" latent_seq = starting_tokens.long()\n",
@@ -1427,18 +1410,17 @@
" logits = logits[:, -1, :]\n",
" # optionally crop the logits to only the top k options\n",
"\n",
" \n",
" # apply softmax to convert logits to (normalized) probabilities\n",
" probs = F.softmax(logits, dim=-1)\n",
" # remove the chance to be sampled the BOS token\n",
" probs[:, vqvae_model.num_embeddings-1] = 0\n",
" probs[:, vqvae_model.num_embeddings - 1] = 0\n",
"\n",
" # sample from the distribution\n",
" idx_next = torch.multinomial(probs, num_samples=1)\n",
" latent_seq = torch.cat((latent_seq, idx_next), dim=1)\n",
"\n",
" latent_seq = latent_seq[:, 1:]\n",
" \n",
"\n",
" return latent_seq"
]
},
@@ -1533,13 +1515,7 @@
"Epoch 69: 100%|███████████████████████████████████████████████| 999/999 [00:56<00:00, 17.61it/s, ce_loss=2.19]\n",
"Epoch 70: 100%|███████████████████████████████████████████████| 999/999 [00:55<00:00, 17.86it/s, ce_loss=2.19]\n",
"Epoch 71: 100%|███████████████████████████████████████████████| 999/999 [00:57<00:00, 17.37it/s, ce_loss=2.19]\n",
"Epoch 72: 100%|███████████████████████████████████████████████| 999/999 [00:57<00:00, 17.50it/s, ce_loss=2.18]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Epoch 72: 100%|███████████████████████████████████████████████| 999/999 [00:57<00:00, 17.50it/s, ce_loss=2.18]\n",
"Epoch 73: 100%|███████████████████████████████████████████████| 999/999 [00:56<00:00, 17.56it/s, ce_loss=2.19]\n",
"Epoch 74: 100%|███████████████████████████████████████████████| 999/999 [00:57<00:00, 17.48it/s, ce_loss=2.18]\n",
"Epoch 75: 100%|███████████████████████████████████████████████| 999/999 [00:56<00:00, 17.79it/s, ce_loss=2.18]\n",
@@ -1618,14 +1594,9 @@
"\n",
" epoch_loss += loss.item()\n",
"\n",
" progress_bar.set_postfix(\n",
" {\n",
" \"ce_loss\": epoch_loss / (step + 1),\n",
" }\n",
" )\n",
" progress_bar.set_postfix({\"ce_loss\": epoch_loss / (step + 1)})\n",
" epoch_ce_loss_list.append(epoch_loss / (step + 1))\n",
"\n",
"\n",
" if (epoch + 1) % val_interval == 0:\n",
" transformer_model.eval()\n",
" val_loss = 0\n",
@@ -1653,10 +1624,12 @@
" # Generate a random sample to visualise progress\n",
" if val_step == 1:\n",
" starting_token = 255 * torch.ones((1, 1), device=device)\n",
" generated_latent = generate(transformer_model, vqvae_model, starting_token, spatial_shape[0]*spatial_shape[1])\n",
" generated_latent = generate(\n",
" transformer_model, vqvae_model, starting_token, spatial_shape[0] * spatial_shape[1]\n",
" )\n",
" generated_latent = generated_latent[0]\n",
" vqvae_latent = generated_latent[revert_sequence_ordering]\n",
" vqvae_latent = vqvae_latent.reshape((1,)+spatial_shape)\n",
" vqvae_latent = vqvae_latent.reshape((1,) + spatial_shape)\n",
" decoded = vqvae_model.decode_samples(vqvae_latent)\n",
" intermediary_images.append(decoded[:, 0])\n",
"\n",
@@ -1779,10 +1752,10 @@
"samples = []\n",
"for i in range(5):\n",
" starting_token = 255 * torch.ones((1, 1), device=device)\n",
" generated_latent = generate(transformer_model, vqvae_model, starting_token, spatial_shape[0]*spatial_shape[1])\n",
" generated_latent = generate(transformer_model, vqvae_model, starting_token, spatial_shape[0] * spatial_shape[1])\n",
" generated_latent = generated_latent[0]\n",
" vqvae_latent = generated_latent[revert_sequence_ordering]\n",
" vqvae_latent = vqvae_latent.reshape((1,)+spatial_shape)\n",
" vqvae_latent = vqvae_latent.reshape((1,) + spatial_shape)\n",
" decoded = vqvae_model.decode_samples(vqvae_latent)\n",
" samples.append(decoded[:, 0])"
]