85 changes: 44 additions & 41 deletions generative/inferers/inferer.py
@@ -459,7 +459,8 @@ def sample(
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
if save_intermediates:
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
for l in latent_intermediates
]

decode = autoencoder_model.decode_stage_2_outputs
@@ -592,13 +593,15 @@ def __call__(
raise NotImplementedError(f"{mode} condition is not supported")

noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
)

if mode == "concat":
noisy_image = torch.cat([noisy_image, condition], dim=1)
condition = None

down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)
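
Net effect of this hunk: the ControlNet forward now receives the same cross-attention context as the diffusion UNet, and in "concat" mode it runs on the already-concatenated input. Below is a minimal sketch of calling the updated inferer with cross-attention conditioning; it assumes `model`, `controlnet` and `scheduler` are already built as in the updated tests (with_conditioning=True and cross_attention_dim=3 on both networks), and all shapes are illustrative, not taken from this PR.

import torch

from generative.inferers import ControlNetDiffusionInferer

# Assumption: `model` (DiffusionModelUNet), `controlnet` (ControlNet) and `scheduler`
# already exist and share with_conditioning=True / cross_attention_dim=3.
inferer = ControlNetDiffusionInferer(scheduler)

images = torch.randn(2, 1, 32, 32)
masks = torch.randn(2, 1, 32, 32)     # ControlNet conditioning image (cn_cond)
conditioning = torch.randn(2, 1, 3)   # cross-attention context: (batch, seq_len, cross_attention_dim)
noise = torch.randn_like(images)
timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],)).long()

# After this change the inferer forwards `condition` to the ControlNet as well as the UNet.
noise_pred = inferer(
    inputs=images,
    diffusion_model=model,
    controlnet=controlnet,
    noise=noise,
    timesteps=timesteps,
    cn_cond=masks,
    condition=conditioning,
    mode="crossattn",
)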
@@ -654,32 +657,32 @@ def sample(
progress_bar = iter(scheduler.timesteps)
intermediates = []
for t in progress_bar:
if mode == "concat":
model_input = torch.cat([image, conditioning], dim=1)
context_ = None
else:
model_input = image
context_ = conditioning

# 1. ControlNet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
x=model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
controlnet_cond=cn_cond,
context=context_,
)
# 2. predict noise model_output
diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat":
model_input = torch.cat([image, conditioning], dim=1)
model_output = diffuse(
model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
context=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
else:
model_output = diffuse(
image,
timesteps=torch.Tensor((t,)).to(input_noise.device),
context=conditioning,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
model_output = diffuse(
model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
context=context_,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)

# 3. compute previous image: x_t -> x_t-1
image, _ = scheduler.step(model_output, t, image)
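
With this refactor, `sample` builds one `model_input`/`context_` pair per step and feeds it to both the ControlNet and the UNet, so the two networks always see identical inputs. A hedged sketch of sampling in "concat" mode follows; it assumes both networks were constructed for concat conditioning (in_channels enlarged by the two concatenated condition channels, no cross-attention) and that a ControlNetDiffusionInferer `inferer` exists as in the sketch above — shapes are illustrative only.

import torch

# Assumption: `model`, `controlnet`, `scheduler` and `inferer` exist, built with
# in_channels = 1 (image) + 2 (concatenated condition channels) on both networks.
noise = torch.randn(1, 1, 32, 32)
masks = torch.randn(1, 1, 32, 32)           # ControlNet conditioning image
conditioning = torch.randn(1, 2, 32, 32)    # concatenated onto the model input at every step

scheduler.set_timesteps(num_inference_steps=50)
sample = inferer.sample(
    input_noise=noise,
    diffusion_model=model,
    controlnet=controlnet,
    cn_cond=masks,
    scheduler=scheduler,
    conditioning=conditioning,
    mode="concat",
)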
@@ -743,31 +746,30 @@ def get_likelihood(
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)

if mode == "concat":
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
conditioning = None

down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
x=noisy_image,
timesteps=torch.Tensor((t,)).to(inputs.device),
controlnet_cond=cn_cond,
context=conditioning,
)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat":
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
model_output = diffuse(
noisy_image,
timesteps=timesteps,
context=None,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
else:
model_output = diffuse(
x=noisy_image,
timesteps=timesteps,
context=conditioning,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)
model_output = diffuse(
noisy_image,
timesteps=timesteps,
context=conditioning,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)

# get the model's predicted mean, and variance if it is predicted
if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
@@ -994,7 +996,8 @@ def sample(
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
if save_intermediates:
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
for l in latent_intermediates
]

decode = autoencoder_model.decode_stage_2_outputs
4 changes: 3 additions & 1 deletion generative/networks/schedulers/ddim.py
@@ -257,7 +257,9 @@ def reversed_step(

# 2. compute alphas, betas at timestep t+1
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod
alpha_prod_t_next = (
self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod
)

beta_prod_t = 1 - alpha_prod_t

14 changes: 10 additions & 4 deletions tests/test_controlnet_inferers.py
@@ -537,8 +537,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):

@parameterized.expand(CNDM_TEST_CASES)
def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
model_params["with_conditioning"] = True
model_params["cross_attention_dim"] = 3
model_params["with_conditioning"] = controlnet_params["with_conditioning"] = True
model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = 3
model = DiffusionModelUNet(**model_params)
controlnet = ControlNet(**controlnet_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -603,10 +603,12 @@ def test_normal_cdf(self):
def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape):
# copy the model_params dict to prevent from modifying test cases
model_params = model_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 2
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
model_params["cross_attention_dim"] = None
model_params["with_conditioning"] = False
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = None
model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
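
The same mirroring applies here: since the condition is concatenated before the ControlNet forward, the ControlNet's `in_channels` must grow by the same number of channels as the UNet's. An illustrative continuation of the hypothetical `common` dict from the sketch above:

# Illustrative only: concat mode, no cross-attention, enlarged in_channels on both networks.
n_concat_channel = 2
concat_params = dict(
    common,
    with_conditioning=False,
    cross_attention_dim=None,
    in_channels=common["in_channels"] + n_concat_channel,
)
model = DiffusionModelUNet(out_channels=1, **concat_params)
controlnet = ControlNet(conditioning_embedding_num_channels=(16,), **concat_params)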
@@ -986,8 +988,10 @@ def test_prediction_shape_conditioned_concat(
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
stage_2_params = stage_2_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
@@ -1066,8 +1070,10 @@ def test_sample_shape_conditioned_concat(
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
stage_2_params = stage_2_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
68 changes: 37 additions & 31 deletions tutorials/generative/2d_controlnet/2d_controlnet.py
@@ -211,7 +211,6 @@
inferer = DiffusionInferer(scheduler)



# %% [markdown]
# ### Run training
#
@@ -348,10 +347,14 @@
0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
).long()

noise_pred = controlnet_inferer(inputs = images, diffusion_model = model,
controlnet = controlnet, noise = noise,
timesteps = timesteps,
cn_cond = masks, )
noise_pred = controlnet_inferer(
inputs=images,
diffusion_model=model,
controlnet=controlnet,
noise=noise,
timesteps=timesteps,
cn_cond=masks,
)

loss = F.mse_loss(noise_pred.float(), noise.float())
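
The MSE above is the standard simplified denoising objective: the ControlNet-guided UNet predicts the noise that `add_noise` injected at the sampled timestep. Stated as an equation for reference (notation is mine; m denotes the mask passed as cn_cond):

\mathcal{L} = \mathbb{E}_{x_0,\; \epsilon \sim \mathcal{N}(0, I),\; t}
\left[\, \bigl\lVert \epsilon - \epsilon_\theta\bigl(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,\; t,\; m\bigr) \bigr\rVert^2 \,\right]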

@@ -378,13 +381,16 @@
0, controlnet_inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
).long()

noise_pred = controlnet_inferer(inputs = images, diffusion_model = model,
controlnet = controlnet, noise = noise,
timesteps = timesteps,
cn_cond = masks, )
noise_pred = controlnet_inferer(
inputs=images,
diffusion_model=model,
controlnet=controlnet,
noise=noise,
timesteps=timesteps,
cn_cond=masks,
)
val_loss = F.mse_loss(noise_pred.float(), noise.float())


val_epoch_loss += val_loss.item()

progress_bar.set_postfix({"val_loss": val_epoch_loss / (step + 1)})
@@ -398,30 +404,30 @@
with autocast(enabled=True):
noise = torch.randn((1, 1, 64, 64)).to(device)
sample = controlnet_inferer.sample(
input_noise = noise,
diffusion_model = model,
controlnet = controlnet,
cn_cond = masks[0, None, ...],
scheduler = scheduler,
input_noise=noise,
diffusion_model=model,
controlnet=controlnet,
cn_cond=masks[0, None, ...],
scheduler=scheduler,
)

# Without using an inferer:
# progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110)
# progress_bar_sampling.set_description("sampling...")
# sample = torch.randn((1, 1, 64, 64)).to(device)
# for t in progress_bar_sampling:
# with torch.no_grad():
# with autocast(enabled=True):
# down_block_res_samples, mid_block_res_sample = controlnet(
# x=sample, timesteps=torch.Tensor((t,)).to(device).long(), controlnet_cond=masks[0, None, ...]
# )
# noise_pred = model(
# sample,
# timesteps=torch.Tensor((t,)).to(device),
# down_block_additional_residuals=down_block_res_samples,
# mid_block_additional_residual=mid_block_res_sample,
# )
# sample, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=sample)

plt.subplots(1, 2, figsize=(4, 2))
plt.subplot(1, 2, 1)