Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.
288 changes: 187 additions & 101 deletions tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.ipynb

Large diffs are not rendered by default.

257 changes: 155 additions & 102 deletions tutorials/generative/2d_ddpm/2d_ddpm_compare_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# format_version: '1.3'
# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
Expand Down Expand Up @@ -204,124 +204,177 @@
# %% [markdown]
# ### Model training
# Here, we are training our model for 100 epochs (training time: ~40 minutes). It is necessary to train for a bit longer than other tutorials because the DDIM and PNDM schedules seem to require a model trained longer before they start producing good samples, when compared to DDPM.
#
# If you would like to skip the training and use a pre-trained model instead, set `use_pretrained=True`. This model was trained using the code in `tutorials/generative/distributed_training/ddpm_training_ddp.py`.

# %%
# Train the noise-prediction model, recording the mean training loss of every epoch.
n_epochs = 100
val_interval = 10
epoch_loss_list = []
val_epoch_loss_list = []
for epoch in range(n_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    progress_bar.set_description(f"Epoch {epoch}")
    for step, batch in progress_bar:
        images = batch["image"].to(device)
        optimizer.zero_grad(set_to_none=True)

        # Randomly select the timesteps to be used for the minibatch
        timesteps = torch.randint(0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device).long()

        # Add noise to the minibatch images with intensity defined by the scheduler and timesteps
        noise = torch.randn_like(images).to(device)
        noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

        # In this example, we are parametrising our DDPM to learn the added noise (epsilon).
        # For this reason, we are using our network to predict the added noise and then using L1 loss to
        # evaluate its performance.
        noise_pred = model(x=noisy_image, timesteps=timesteps)
        loss = F.l1_loss(noise_pred.float(), noise.float())

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        # Show the running mean of the loss over the steps seen so far in this epoch
        progress_bar.set_postfix(
            {
                "loss": epoch_loss / (step + 1),
            }
        )
    epoch_loss_list.append(epoch_loss / (step + 1))

    if (epoch + 1) % val_interval == 0:
        model.eval()
        val_epoch_loss = 0
        # Size the bar from the loader that is actually iterated (was len(train_loader))
        progress_bar = tqdm(enumerate(val_loader), total=len(val_loader))
        progress_bar.set_description(f"Epoch {epoch} - Validation set")
use_pretrained = False

if use_pretrained:
model = torch.hub.load("marksgraham/pretrained_generative_models", model="ddpm_2d", verbose=True).to(device)
else:
n_epochs = 100
val_interval = 10
epoch_loss_list = []
val_epoch_loss_list = []
for epoch in range(n_epochs):
model.train()
epoch_loss = 0
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in progress_bar:
images = batch["image"].to(device)
optimizer.zero_grad(set_to_none=True)

# Randomly select the timesteps to be used for the minibatch
timesteps = torch.randint(0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device).long()

# Add noise to the minibatch images with intensity defined by the scheduler and timesteps
noise = torch.randn_like(images).to(device)
with torch.no_grad():
noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
noise_pred = model(x=noisy_image, timesteps=timesteps)
val_loss = F.l1_loss(noise_pred.float(), noise.float())
noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)

# In this example, we are parametrising our DDPM to learn the added noise (epsilon).
# For this reason, we are using our network to predict the added noise and then using L1 loss to evaluate
# its performance.
noise_pred = model(x=noisy_image, timesteps=timesteps)
loss = F.l1_loss(noise_pred.float(), noise.float())

loss.backward()
optimizer.step()
epoch_loss += loss.item()

val_epoch_loss += val_loss.item()
progress_bar.set_postfix(
{
"val_loss": val_epoch_loss / (step + 1),
"loss": epoch_loss / (step + 1),
}
)
val_epoch_loss_list.append(val_epoch_loss / (step + 1))

# Sampling image during training
noise = torch.randn((1, 1, 64, 64))
noise = noise.to(device)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
plt.figure(figsize=(8, 4))
plt.subplot(3, len(sampling_steps), 1)
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
plt.ylabel("DDPM")
plt.title("1000 steps")
# DDIM
for idx, reduced_sampling_steps in enumerate(sampling_steps):
ddim_scheduler.set_timesteps(reduced_sampling_steps)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddim_scheduler)
plt.subplot(3, len(sampling_steps), len(sampling_steps) + idx + 1)
epoch_loss_list.append(epoch_loss / (step + 1))

if (epoch + 1) % val_interval == 0:
    # Periodic validation: same noise-prediction L1 objective as training, without gradient tracking.
    model.eval()
    val_epoch_loss = 0
    # Size the bar from the loader that is actually iterated (was len(train_loader)),
    # otherwise the displayed progress/ETA for the validation pass is wrong.
    progress_bar = tqdm(enumerate(val_loader), total=len(val_loader))
    progress_bar.set_description(f"Epoch {epoch} - Validation set")
    for step, batch in progress_bar:
        images = batch["image"].to(device)
        timesteps = torch.randint(
            0, ddpm_scheduler.num_train_timesteps, (images.shape[0],), device=device
        ).long()
        noise = torch.randn_like(images).to(device)
        with torch.no_grad():
            noisy_image = ddpm_scheduler.add_noise(original_samples=images, noise=noise, timesteps=timesteps)
            noise_pred = model(x=noisy_image, timesteps=timesteps)
            val_loss = F.l1_loss(noise_pred.float(), noise.float())

        val_epoch_loss += val_loss.item()
        # Show the running mean of the validation loss over the steps seen so far
        progress_bar.set_postfix(
            {
                "val_loss": val_epoch_loss / (step + 1),
            }
        )
    val_epoch_loss_list.append(val_epoch_loss / (step + 1))

# Sampling image during training
noise = torch.randn((1, 1, 64, 64))
noise = noise.to(device)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
plt.figure(figsize=(8, 4))
plt.subplot(3, len(sampling_steps), 1)
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.ylabel("DDIM")
if idx == 0:
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
else:
plt.axis("off")
plt.title(f"{reduced_sampling_steps} steps")
# PNDM
for idx, reduced_sampling_steps in enumerate(sampling_steps):
pndm_scheduler.set_timesteps(reduced_sampling_steps)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=pndm_scheduler)
plt.subplot(3, len(sampling_steps), len(sampling_steps) * 2 + idx + 1)
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.ylabel("PNDM")
if idx == 0:
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
else:
plt.axis("off")
plt.title(f"{reduced_sampling_steps} steps")
plt.suptitle(f"Epoch {epoch+1}")
plt.show()
plt.tick_params(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)
plt.ylabel("DDPM")
plt.title("1000 steps")
# DDIM
for idx, reduced_sampling_steps in enumerate(sampling_steps):
ddim_scheduler.set_timesteps(reduced_sampling_steps)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddim_scheduler)
plt.subplot(3, len(sampling_steps), len(sampling_steps) + idx + 1)
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.ylabel("DDIM")
if idx == 0:
plt.tick_params(
top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False
)
else:
plt.axis("off")
plt.title(f"{reduced_sampling_steps} steps")
# PNDM
for idx, reduced_sampling_steps in enumerate(sampling_steps):
pndm_scheduler.set_timesteps(reduced_sampling_steps)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=pndm_scheduler)
plt.subplot(3, len(sampling_steps), len(sampling_steps) * 2 + idx + 1)
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.ylabel("PNDM")
if idx == 0:
plt.tick_params(
top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False
)
else:
plt.axis("off")
plt.title(f"{reduced_sampling_steps} steps")
plt.suptitle(f"Epoch {epoch+1}")
plt.show()
# %% [markdown]
# ### Learning curves

# %%
# Plot the per-epoch training loss and the validation loss recorded every `val_interval` epochs.
# matplotlib 3.6 renamed the bundled "seaborn" style to "seaborn-v0_8"; prefer the new name when
# this matplotlib provides it and fall back to the legacy name on older releases.
seaborn_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
plt.style.use(seaborn_style)
plt.title("Learning Curves", fontsize=20)
plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
# Validation points are placed at epochs val_interval, 2 * val_interval, ..., n_epochs
plt.plot(
    np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
    val_epoch_loss_list,
    color="C1",
    linewidth=2.0,
    label="Validation",
)
plt.yticks(fontsize=12)
plt.xticks(fontsize=12)
plt.xlabel("Epochs", fontsize=16)
plt.ylabel("Loss", fontsize=16)
plt.legend(prop={"size": 14})
plt.show()
if not use_pretrained:
    # The loss histories are only filled in by the training branch, so there is nothing to plot
    # when a pre-trained model was loaded instead.
    # matplotlib 3.6 renamed the bundled "seaborn" style to "seaborn-v0_8"; prefer the new name when
    # this matplotlib provides it and fall back to the legacy name on older releases.
    seaborn_style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
    plt.style.use(seaborn_style)
    plt.title("Learning Curves", fontsize=20)
    plt.plot(np.linspace(1, n_epochs, n_epochs), epoch_loss_list, color="C0", linewidth=2.0, label="Train")
    # Validation points are placed at epochs val_interval, 2 * val_interval, ..., n_epochs
    plt.plot(
        np.linspace(val_interval, n_epochs, int(n_epochs / val_interval)),
        val_epoch_loss_list,
        color="C1",
        linewidth=2.0,
        label="Validation",
    )
    plt.yticks(fontsize=12)
    plt.xticks(fontsize=12)
    plt.xlabel("Epochs", fontsize=16)
    plt.ylabel("Loss", fontsize=16)
    plt.legend(prop={"size": 14})
    plt.show()


# %% [markdown]
# ### Compare samples from trained model

# %%
# Denoise one shared noise sample with every scheduler so the grid rows are directly comparable:
# row 1 = DDPM, row 2 = DDIM, row 3 = PNDM (one column per entry in `sampling_steps`).
noise = torch.randn((1, 1, 64, 64)).to(device)
n_cols = len(sampling_steps)
hidden_ticks = dict(top=False, bottom=False, left=False, right=False, labelleft=False, labelbottom=False)

# DDPM reference image in the first cell of the grid
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=ddpm_scheduler)
plt.figure(figsize=(8, 4))
plt.subplot(3, n_cols, 1)
plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
plt.tick_params(**hidden_ticks)
plt.ylabel("DDPM")
plt.title("1000 steps")

# DDIM fills the second row and PNDM the third, each with a reduced number of sampling steps
for row, (label, scheduler) in enumerate((("DDIM", ddim_scheduler), ("PNDM", pndm_scheduler)), start=1):
    for col, n_steps in enumerate(sampling_steps):
        scheduler.set_timesteps(n_steps)
        image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)
        plt.subplot(3, n_cols, n_cols * row + col + 1)
        plt.imshow(image[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
        plt.ylabel(label)
        # Only the first column keeps its (tick-less) axes so the row label stays visible
        if col > 0:
            plt.axis("off")
        else:
            plt.tick_params(**hidden_ticks)
        plt.title(f"{n_steps} steps")
plt.show()

# %% [markdown]
# ### Cleanup data directory
Expand Down
Loading