diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py
index 5880d456..48eb7f6d 100644
--- a/generative/inferers/inferer.py
+++ b/generative/inferers/inferer.py
@@ -459,7 +459,8 @@ def sample(
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
             if save_intermediates:
                 latent_intermediates = [
-                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
                 ]

         decode = autoencoder_model.decode_stage_2_outputs
@@ -592,13 +593,15 @@ def __call__(
             raise NotImplementedError(f"{mode} condition is not supported")

         noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
-        down_block_res_samples, mid_block_res_sample = controlnet(
-            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
-        )
+
         if mode == "concat":
             noisy_image = torch.cat([noisy_image, condition], dim=1)
             condition = None

+        down_block_res_samples, mid_block_res_sample = controlnet(
+            x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
+        )
+
         diffuse = diffusion_model
         if isinstance(diffusion_model, SPADEDiffusionModelUNet):
             diffuse = partial(diffusion_model, seg=seg)
@@ -654,32 +657,32 @@ def sample(
             progress_bar = iter(scheduler.timesteps)
         intermediates = []
         for t in progress_bar:
+            if mode == "concat":
+                model_input = torch.cat([image, conditioning], dim=1)
+                context_ = None
+            else:
+                model_input = image
+                context_ = conditioning
+
             # 1. ControlNet forward
             down_block_res_samples, mid_block_res_sample = controlnet(
-                x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
+                x=model_input,
+                timesteps=torch.Tensor((t,)).to(input_noise.device),
+                controlnet_cond=cn_cond,
+                context=context_,
             )
             # 2. predict noise model_output
             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)

-            if mode == "concat":
-                model_input = torch.cat([image, conditioning], dim=1)
-                model_output = diffuse(
-                    model_input,
-                    timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=None,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                )
-            else:
-                model_output = diffuse(
-                    image,
-                    timesteps=torch.Tensor((t,)).to(input_noise.device),
-                    context=conditioning,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                )
+            model_output = diffuse(
+                model_input,
+                timesteps=torch.Tensor((t,)).to(input_noise.device),
+                context=context_,
+                down_block_additional_residuals=down_block_res_samples,
+                mid_block_additional_residual=mid_block_res_sample,
+            )

             # 3. compute previous image: x_t -> x_t-1
             image, _ = scheduler.step(model_output, t, image)
@@ -743,31 +746,30 @@ def get_likelihood(
         for t in progress_bar:
             timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
             noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+
+            if mode == "concat":
+                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+                conditioning = None
+
             down_block_res_samples, mid_block_res_sample = controlnet(
-                x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
+                x=noisy_image,
+                timesteps=torch.Tensor((t,)).to(inputs.device),
+                controlnet_cond=cn_cond,
+                context=conditioning,
             )

             diffuse = diffusion_model
             if isinstance(diffusion_model, SPADEDiffusionModelUNet):
                 diffuse = partial(diffusion_model, seg=seg)

-            if mode == "concat":
-                noisy_image = torch.cat([noisy_image, conditioning], dim=1)
-                model_output = diffuse(
-                    noisy_image,
-                    timesteps=timesteps,
-                    context=None,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                )
-            else:
-                model_output = diffuse(
-                    x=noisy_image,
-                    timesteps=timesteps,
-                    context=conditioning,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
-                )
+            model_output = diffuse(
+                noisy_image,
+                timesteps=timesteps,
+                context=conditioning,
+                down_block_additional_residuals=down_block_res_samples,
+                mid_block_additional_residual=mid_block_res_sample,
+            )
+
             # get the model's predicted mean, and variance if it is predicted
             if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
                 model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
@@ -994,7 +996,8 @@ def sample(
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
             if save_intermediates:
                 latent_intermediates = [
-                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
                ]

         decode = autoencoder_model.decode_stage_2_outputs
diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py
index eb26b2fd..e64ff596 100644
--- a/generative/networks/schedulers/ddim.py
+++ b/generative/networks/schedulers/ddim.py
@@ -257,7 +257,9 @@ def reversed_step(
         # 2. compute alphas, betas at timestep t+1
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod
+        alpha_prod_t_next = (
+            self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod
+        )

         beta_prod_t = 1 - alpha_prod_t
diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py
index c38eb4c8..a67eb4bb 100644
--- a/tests/test_controlnet_inferers.py
+++ b/tests/test_controlnet_inferers.py
@@ -537,8 +537,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
     @parameterized.expand(CNDM_TEST_CASES)
     def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
-        model_params["with_conditioning"] = True
-        model_params["cross_attention_dim"] = 3
+        model_params["with_conditioning"] = controlnet_params["with_conditioning"] = True
+        model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = 3
         model = DiffusionModelUNet(**model_params)
         controlnet = ControlNet(**controlnet_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -603,10 +603,12 @@ def test_normal_cdf(self):
     def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape):
         # copy the model_params dict to prevent from modifying test cases
         model_params = model_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 2
         model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
-        model_params["cross_attention_dim"] = None
-        model_params["with_conditioning"] = False
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
+        model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = None
+        model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False
         model = DiffusionModelUNet(**model_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
         model.to(device)
@@ -986,8 +988,10 @@ def test_prediction_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
@@ -1066,8 +1070,10 @@ def test_sample_shape_conditioned_concat(
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
         stage_2_params = stage_2_params.copy()
+        controlnet_params = controlnet_params.copy()
         n_concat_channel = 3
         stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
diff --git a/tutorials/generative/2d_controlnet/2d_controlnet.py b/tutorials/generative/2d_controlnet/2d_controlnet.py
index 8de30c3d..9aca295e 100644
--- a/tutorials/generative/2d_controlnet/2d_controlnet.py
+++ b/tutorials/generative/2d_controlnet/2d_controlnet.py
@@ -211,7 +211,6 @@
 inferer = DiffusionInferer(scheduler)


-
 # %% [markdown]
 # ### Run training
 #
@@ -348,10 +347,14 @@
             0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
         ).long()

-        noise_pred = controlnet_inferer(inputs = images, diffusion_model = model,
-                                        controlnet = controlnet, noise = noise,
-                                        timesteps = timesteps,
-                                        cn_cond = masks, )
+        noise_pred = controlnet_inferer(
+            inputs=images,
+            diffusion_model=model,
+            controlnet=controlnet,
+            noise=noise,
+            timesteps=timesteps,
+            cn_cond=masks,
+        )

         loss = F.mse_loss(noise_pred.float(), noise.float())
@@ -378,13 +381,16 @@
                 0, controlnet_inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
             ).long()

-            noise_pred = controlnet_inferer(inputs = images, diffusion_model = model,
-                                            controlnet = controlnet, noise = noise,
-                                            timesteps = timesteps,
-                                            cn_cond = masks, )
+            noise_pred = controlnet_inferer(
+                inputs=images,
+                diffusion_model=model,
+                controlnet=controlnet,
+                noise=noise,
+                timesteps=timesteps,
+                cn_cond=masks,
+            )

             val_loss = F.mse_loss(noise_pred.float(), noise.float())
-
             val_epoch_loss += val_loss.item()

             progress_bar.set_postfix({"val_loss": val_epoch_loss / (step + 1)})
@@ -398,30 +404,30 @@ with autocast(enabled=True):
     noise = torch.randn((1, 1, 64, 64)).to(device)
     sample = controlnet_inferer.sample(
-        input_noise = noise,
-        diffusion_model = model,
-        controlnet = controlnet,
-        cn_cond = masks[0, None, ...],
-        scheduler = scheduler,
+        input_noise=noise,
+        diffusion_model=model,
+        controlnet=controlnet,
+        cn_cond=masks[0, None, ...],
+        scheduler=scheduler,
     )

 # Without using an inferer:
-# progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110)
-# progress_bar_sampling.set_description("sampling...")
-# sample = torch.randn((1, 1, 64, 64)).to(device)
-# for t in progress_bar_sampling:
-#     with torch.no_grad():
-#         with autocast(enabled=True):
-#             down_block_res_samples, mid_block_res_sample = controlnet(
-#                 x=sample, timesteps=torch.Tensor((t,)).to(device).long(), controlnet_cond=masks[0, None, ...]
-#             )
-#             noise_pred = model(
-#                 sample,
-#                 timesteps=torch.Tensor((t,)).to(device),
-#                 down_block_additional_residuals=down_block_res_samples,
-#                 mid_block_additional_residual=mid_block_res_sample,
-#             )
-#             sample, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=sample)
+    # progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110)
+    # progress_bar_sampling.set_description("sampling...")
+    # sample = torch.randn((1, 1, 64, 64)).to(device)
+    # for t in progress_bar_sampling:
+    #     with torch.no_grad():
+    #         with autocast(enabled=True):
+    #             down_block_res_samples, mid_block_res_sample = controlnet(
+    #                 x=sample, timesteps=torch.Tensor((t,)).to(device).long(), controlnet_cond=masks[0, None, ...]
+    #             )
+    #             noise_pred = model(
+    #                 sample,
+    #                 timesteps=torch.Tensor((t,)).to(device),
+    #                 down_block_additional_residuals=down_block_res_samples,
+    #                 mid_block_additional_residual=mid_block_res_sample,
+    #             )
+    #             sample, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=sample)

 plt.subplots(1, 2, figsize=(4, 2))
 plt.subplot(1, 2, 1)
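For context, a minimal sketch of the call pattern this diff exercises: cross-attention conditioning that reaches the ControlNet as well as the UNet. It is an assumption-laden example, not part of the patch: it presumes `images`, `masks`, `model`, and `controlnet` are built as in the tutorial above but with matching `with_conditioning=True` / `cross_attention_dim` on both networks (as the updated tests do), that `ControlNetDiffusionInferer` accepts the same `condition`/`mode` keywords as `DiffusionInferer`, and `prompt_embeds` is a hypothetical conditioning tensor; verify the exact signature against the released API.

```python
import torch

from generative.inferers import ControlNetDiffusionInferer
from generative.networks.schedulers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
controlnet_inferer = ControlNetDiffusionInferer(scheduler)

# Hypothetical cross-attention conditioning of shape (batch, seq_len, cross_attention_dim).
prompt_embeds = torch.randn(images.shape[0], 1, 3, device=images.device)

noise = torch.randn_like(images)
timesteps = torch.randint(
    0, controlnet_inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device
).long()

# In "crossattn" mode, `condition` is passed to the ControlNet as `context`;
# in "concat" mode it is concatenated to the noisy input before the ControlNet
# call, so the ControlNet and the diffusion UNet see the same conditioned input.
noise_pred = controlnet_inferer(
    inputs=images,
    diffusion_model=model,
    controlnet=controlnet,
    noise=noise,
    timesteps=timesteps,
    cn_cond=masks,
    condition=prompt_embeds,
    mode="crossattn",
)
```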