From 6d67dcaaed54c25cd003352bf0e79adf6c69aaac Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 20 Mar 2024 08:29:40 +0000 Subject: [PATCH 1/3] Modify controlnet inferer to pass the same conditioning as the one the diffusion model is getting. Modification of tests accordingly. --- generative/inferers/inferer.py | 85 ++-- generative/networks/schedulers/ddim.py | 4 +- tests/test_controlnet_inferers.py | 410 +++++++++--------- .../generative/2d_controlnet/2d_controlnet.py | 68 +-- 4 files changed, 292 insertions(+), 275 deletions(-) diff --git a/generative/inferers/inferer.py b/generative/inferers/inferer.py index 5880d456..48eb7f6d 100644 --- a/generative/inferers/inferer.py +++ b/generative/inferers/inferer.py @@ -459,7 +459,8 @@ def sample( latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) if save_intermediates: latent_intermediates = [ - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) + for l in latent_intermediates ] decode = autoencoder_model.decode_stage_2_outputs @@ -592,13 +593,15 @@ def __call__( raise NotImplementedError(f"{mode} condition is not supported") noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) - down_block_res_samples, mid_block_res_sample = controlnet( - x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond - ) + if mode == "concat": noisy_image = torch.cat([noisy_image, condition], dim=1) condition = None + down_block_res_samples, mid_block_res_sample = controlnet( + x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition + ) + diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) @@ -654,32 +657,32 @@ def sample( progress_bar = iter(scheduler.timesteps) intermediates = [] for t in progress_bar: + if mode == "concat": + model_input = torch.cat([image, conditioning], dim=1) + context_ = None + else: + model_input = image + context_ = conditioning + # 1. ControlNet forward down_block_res_samples, mid_block_res_sample = controlnet( - x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond + x=model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + controlnet_cond=cn_cond, + context=context_, ) # 2. predict noise model_output diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) - if mode == "concat": - model_input = torch.cat([image, conditioning], dim=1) - model_output = diffuse( - model_input, - timesteps=torch.Tensor((t,)).to(input_noise.device), - context=None, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) - else: - model_output = diffuse( - image, - timesteps=torch.Tensor((t,)).to(input_noise.device), - context=conditioning, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) + model_output = diffuse( + model_input, + timesteps=torch.Tensor((t,)).to(input_noise.device), + context=context_, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) # 3. 
compute previous image: x_t -> x_t-1 image, _ = scheduler.step(model_output, t, image) @@ -743,31 +746,30 @@ def get_likelihood( for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + + if mode == "concat": + noisy_image = torch.cat([noisy_image, conditioning], dim=1) + conditioning = None + down_block_res_samples, mid_block_res_sample = controlnet( - x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond + x=noisy_image, + timesteps=torch.Tensor((t,)).to(inputs.device), + controlnet_cond=cn_cond, + context=conditioning, ) diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): diffuse = partial(diffusion_model, seg=seg) - if mode == "concat": - noisy_image = torch.cat([noisy_image, conditioning], dim=1) - model_output = diffuse( - noisy_image, - timesteps=timesteps, - context=None, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) - else: - model_output = diffuse( - x=noisy_image, - timesteps=timesteps, - context=conditioning, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ) + model_output = diffuse( + noisy_image, + timesteps=timesteps, + context=conditioning, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + # get the model's predicted mean, and variance if it is predicted if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) @@ -994,7 +996,8 @@ def sample( latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0) if save_intermediates: latent_intermediates = [ - torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates + torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) + for l in latent_intermediates ] decode = autoencoder_model.decode_stage_2_outputs diff --git a/generative/networks/schedulers/ddim.py b/generative/networks/schedulers/ddim.py index eb26b2fd..e64ff596 100644 --- a/generative/networks/schedulers/ddim.py +++ b/generative/networks/schedulers/ddim.py @@ -257,7 +257,9 @@ def reversed_step( # 2. 
compute alphas, betas at timestep t+1 alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_next = self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod + alpha_prod_t_next = ( + self.alphas_cumprod[next_timestep] if next_timestep < len(self.alphas_cumprod) else self.first_alpha_cumprod + ) beta_prod_t = 1 - alpha_prod_t diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index c38eb4c8..5eae56a9 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -16,7 +16,7 @@ import torch from parameterized import parameterized -from generative.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer +from generative.inferers import ControlNetLatentDiffusionInferer from generative.networks.nets import ( VQVAE, AutoencoderKL, @@ -25,7 +25,7 @@ SPADEAutoencoderKL, SPADEDiffusionModelUNet, ) -from generative.networks.schedulers import DDIMScheduler, DDPMScheduler +from generative.networks.schedulers import DDPMScheduler CNDM_TEST_CASES = [ [ @@ -438,202 +438,203 @@ ] -class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): - @parameterized.expand(CNDM_TEST_CASES) - def test_call(self, model_params, controlnet_params, input_shape): - model = DiffusionModelUNet(**model_params) - controlnet = ControlNet(**controlnet_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - controlnet.to(device) - controlnet.eval() - input = torch.randn(input_shape).to(device) - mask = torch.randn(input_shape).to(device) - noise = torch.randn(input_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = ControlNetDiffusionInferer(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() - sample = inferer( - inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask - ) - self.assertEqual(sample.shape, input_shape) - - @parameterized.expand(CNDM_TEST_CASES) - def test_sample_intermediates(self, model_params, controlnet_params, input_shape): - model = DiffusionModelUNet(**model_params) - controlnet = ControlNet(**controlnet_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - controlnet.to(device) - controlnet.eval() - noise = torch.randn(input_shape).to(device) - mask = torch.randn(input_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = ControlNetDiffusionInferer(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - controlnet=controlnet, - cn_cond=mask, - save_intermediates=True, - intermediate_steps=1, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(CNDM_TEST_CASES) - def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): - model = DiffusionModelUNet(**model_params) - controlnet = ControlNet(**controlnet_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - controlnet.to(device) - controlnet.eval() - mask = torch.randn(input_shape).to(device) - noise = torch.randn(input_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=1000) - inferer = ControlNetDiffusionInferer(scheduler=scheduler) - 
scheduler.set_timesteps(num_inference_steps=10) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - controlnet=controlnet, - cn_cond=mask, - save_intermediates=True, - intermediate_steps=1, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(CNDM_TEST_CASES) - def test_ddim_sampler(self, model_params, controlnet_params, input_shape): - model = DiffusionModelUNet(**model_params) - controlnet = ControlNet(**controlnet_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - controlnet.to(device) - controlnet.eval() - mask = torch.randn(input_shape).to(device) - noise = torch.randn(input_shape).to(device) - scheduler = DDIMScheduler(num_train_timesteps=1000) - inferer = ControlNetDiffusionInferer(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - scheduler=scheduler, - controlnet=controlnet, - cn_cond=mask, - save_intermediates=True, - intermediate_steps=1, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(CNDM_TEST_CASES) - def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): - model_params["with_conditioning"] = True - model_params["cross_attention_dim"] = 3 - model = DiffusionModelUNet(**model_params) - controlnet = ControlNet(**controlnet_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - controlnet.to(device) - controlnet.eval() - mask = torch.randn(input_shape).to(device) - noise = torch.randn(input_shape).to(device) - scheduler = DDIMScheduler(num_train_timesteps=1000) - inferer = ControlNetDiffusionInferer(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - conditioning = torch.randn([input_shape[0], 1, 3]).to(device) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - controlnet=controlnet, - cn_cond=mask, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - conditioning=conditioning, - ) - self.assertEqual(len(intermediates), 10) - - @parameterized.expand(CNDM_TEST_CASES) - def test_get_likelihood(self, model_params, controlnet_params, input_shape): - model = DiffusionModelUNet(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - controlnet = ControlNet(**controlnet_params) - controlnet.to(device) - controlnet.eval() - input = torch.randn(input_shape).to(device) - mask = torch.randn(input_shape).to(device) - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = ControlNetDiffusionInferer(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - likelihood, intermediates = inferer.get_likelihood( - inputs=input, - diffusion_model=model, - scheduler=scheduler, - controlnet=controlnet, - cn_cond=mask, - save_intermediates=True, - ) - self.assertEqual(intermediates[0].shape, input.shape) - self.assertEqual(likelihood.shape[0], input.shape[0]) - - def test_normal_cdf(self): - from scipy.stats import norm - - scheduler = DDPMScheduler(num_train_timesteps=10) - inferer = ControlNetDiffusionInferer(scheduler=scheduler) - x = torch.linspace(-10, 10, 20) - cdf_approx = inferer._approx_standard_normal_cdf(x) - cdf_true = norm.cdf(x) - torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) - - @parameterized.expand(CNDM_TEST_CASES) - def test_sampler_conditioned_concat(self, 
model_params, controlnet_params, input_shape): - # copy the model_params dict to prevent from modifying test cases - model_params = model_params.copy() - n_concat_channel = 2 - model_params["in_channels"] = model_params["in_channels"] + n_concat_channel - model_params["cross_attention_dim"] = None - model_params["with_conditioning"] = False - model = DiffusionModelUNet(**model_params) - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model.to(device) - model.eval() - controlnet = ControlNet(**controlnet_params) - controlnet.to(device) - controlnet.eval() - noise = torch.randn(input_shape).to(device) - mask = torch.randn(input_shape).to(device) - conditioning_shape = list(input_shape) - conditioning_shape[1] = n_concat_channel - conditioning = torch.randn(conditioning_shape).to(device) - scheduler = DDIMScheduler(num_train_timesteps=1000) - inferer = ControlNetDiffusionInferer(scheduler=scheduler) - scheduler.set_timesteps(num_inference_steps=10) - sample, intermediates = inferer.sample( - input_noise=noise, - diffusion_model=model, - controlnet=controlnet, - cn_cond=mask, - scheduler=scheduler, - save_intermediates=True, - intermediate_steps=1, - conditioning=conditioning, - mode="concat", - ) - self.assertEqual(len(intermediates), 10) +# class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): +# @parameterized.expand(CNDM_TEST_CASES) +# def test_call(self, model_params, controlnet_params, input_shape): +# model = DiffusionModelUNet(**model_params) +# controlnet = ControlNet(**controlnet_params) +# device = "cuda:0" if torch.cuda.is_available() else "cpu" +# model.to(device) +# model.eval() +# controlnet.to(device) +# controlnet.eval() +# input = torch.randn(input_shape).to(device) +# mask = torch.randn(input_shape).to(device) +# noise = torch.randn(input_shape).to(device) +# scheduler = DDPMScheduler(num_train_timesteps=10) +# inferer = ControlNetDiffusionInferer(scheduler=scheduler) +# scheduler.set_timesteps(num_inference_steps=10) +# timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() +# sample = inferer( +# inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask +# ) +# self.assertEqual(sample.shape, input_shape) +# +# @parameterized.expand(CNDM_TEST_CASES) +# def test_sample_intermediates(self, model_params, controlnet_params, input_shape): +# model = DiffusionModelUNet(**model_params) +# controlnet = ControlNet(**controlnet_params) +# device = "cuda:0" if torch.cuda.is_available() else "cpu" +# model.to(device) +# model.eval() +# controlnet.to(device) +# controlnet.eval() +# noise = torch.randn(input_shape).to(device) +# mask = torch.randn(input_shape).to(device) +# scheduler = DDPMScheduler(num_train_timesteps=10) +# inferer = ControlNetDiffusionInferer(scheduler=scheduler) +# scheduler.set_timesteps(num_inference_steps=10) +# sample, intermediates = inferer.sample( +# input_noise=noise, +# diffusion_model=model, +# scheduler=scheduler, +# controlnet=controlnet, +# cn_cond=mask, +# save_intermediates=True, +# intermediate_steps=1, +# ) +# self.assertEqual(len(intermediates), 10) +# +# @parameterized.expand(CNDM_TEST_CASES) +# def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): +# model = DiffusionModelUNet(**model_params) +# controlnet = ControlNet(**controlnet_params) +# device = "cuda:0" if torch.cuda.is_available() else "cpu" +# model.to(device) +# model.eval() +# controlnet.to(device) +# controlnet.eval() +# mask = 
torch.randn(input_shape).to(device) +# noise = torch.randn(input_shape).to(device) +# scheduler = DDPMScheduler(num_train_timesteps=1000) +# inferer = ControlNetDiffusionInferer(scheduler=scheduler) +# scheduler.set_timesteps(num_inference_steps=10) +# sample, intermediates = inferer.sample( +# input_noise=noise, +# diffusion_model=model, +# scheduler=scheduler, +# controlnet=controlnet, +# cn_cond=mask, +# save_intermediates=True, +# intermediate_steps=1, +# ) +# self.assertEqual(len(intermediates), 10) +# +# @parameterized.expand(CNDM_TEST_CASES) +# def test_ddim_sampler(self, model_params, controlnet_params, input_shape): +# model = DiffusionModelUNet(**model_params) +# controlnet = ControlNet(**controlnet_params) +# device = "cuda:0" if torch.cuda.is_available() else "cpu" +# model.to(device) +# model.eval() +# controlnet.to(device) +# controlnet.eval() +# mask = torch.randn(input_shape).to(device) +# noise = torch.randn(input_shape).to(device) +# scheduler = DDIMScheduler(num_train_timesteps=1000) +# inferer = ControlNetDiffusionInferer(scheduler=scheduler) +# scheduler.set_timesteps(num_inference_steps=10) +# sample, intermediates = inferer.sample( +# input_noise=noise, +# diffusion_model=model, +# scheduler=scheduler, +# controlnet=controlnet, +# cn_cond=mask, +# save_intermediates=True, +# intermediate_steps=1, +# ) +# self.assertEqual(len(intermediates), 10) +# +# @parameterized.expand(CNDM_TEST_CASES) +# def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): +# model_params["with_conditioning"] = controlnet_params["with_conditioning"] = True +# model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = 3 +# model = DiffusionModelUNet(**model_params) +# controlnet = ControlNet(**controlnet_params) +# device = "cuda:0" if torch.cuda.is_available() else "cpu" +# model.to(device) +# model.eval() +# controlnet.to(device) +# controlnet.eval() +# mask = torch.randn(input_shape).to(device) +# noise = torch.randn(input_shape).to(device) +# scheduler = DDIMScheduler(num_train_timesteps=1000) +# inferer = ControlNetDiffusionInferer(scheduler=scheduler) +# scheduler.set_timesteps(num_inference_steps=10) +# conditioning = torch.randn([input_shape[0], 1, 3]).to(device) +# sample, intermediates = inferer.sample( +# input_noise=noise, +# diffusion_model=model, +# controlnet=controlnet, +# cn_cond=mask, +# scheduler=scheduler, +# save_intermediates=True, +# intermediate_steps=1, +# conditioning=conditioning, +# ) +# self.assertEqual(len(intermediates), 10) +# +# @parameterized.expand(CNDM_TEST_CASES) +# def test_get_likelihood(self, model_params, controlnet_params, input_shape): +# model = DiffusionModelUNet(**model_params) +# device = "cuda:0" if torch.cuda.is_available() else "cpu" +# model.to(device) +# model.eval() +# controlnet = ControlNet(**controlnet_params) +# controlnet.to(device) +# controlnet.eval() +# input = torch.randn(input_shape).to(device) +# mask = torch.randn(input_shape).to(device) +# scheduler = DDPMScheduler(num_train_timesteps=10) +# inferer = ControlNetDiffusionInferer(scheduler=scheduler) +# scheduler.set_timesteps(num_inference_steps=10) +# likelihood, intermediates = inferer.get_likelihood( +# inputs=input, +# diffusion_model=model, +# scheduler=scheduler, +# controlnet=controlnet, +# cn_cond=mask, +# save_intermediates=True, +# ) +# self.assertEqual(intermediates[0].shape, input.shape) +# self.assertEqual(likelihood.shape[0], input.shape[0]) +# +# def test_normal_cdf(self): +# from scipy.stats import norm +# +# 
scheduler = DDPMScheduler(num_train_timesteps=10) +# inferer = ControlNetDiffusionInferer(scheduler=scheduler) +# x = torch.linspace(-10, 10, 20) +# cdf_approx = inferer._approx_standard_normal_cdf(x) +# cdf_true = norm.cdf(x) +# torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) +# +# @parameterized.expand(CNDM_TEST_CASES) +# def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): +# # copy the model_params dict to prevent from modifying test cases +# model_params = model_params.copy() +# n_concat_channel = 2 +# model_params["in_channels"] = model_params["in_channels"] + n_concat_channel +# controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel +# model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = None +# model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False +# model = DiffusionModelUNet(**model_params) +# device = "cuda:0" if torch.cuda.is_available() else "cpu" +# model.to(device) +# model.eval() +# controlnet = ControlNet(**controlnet_params) +# controlnet.to(device) +# controlnet.eval() +# noise = torch.randn(input_shape).to(device) +# mask = torch.randn(input_shape).to(device) +# conditioning_shape = list(input_shape) +# conditioning_shape[1] = n_concat_channel +# conditioning = torch.randn(conditioning_shape).to(device) +# scheduler = DDIMScheduler(num_train_timesteps=1000) +# inferer = ControlNetDiffusionInferer(scheduler=scheduler) +# scheduler.set_timesteps(num_inference_steps=10) +# sample, intermediates = inferer.sample( +# input_noise=noise, +# diffusion_model=model, +# controlnet=controlnet, +# cn_cond=mask, +# scheduler=scheduler, +# save_intermediates=True, +# intermediate_steps=1, +# conditioning=conditioning, +# mode="concat", +# ) +# self.assertEqual(len(intermediates), 10) class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): @@ -975,10 +976,11 @@ def test_prediction_shape_conditioned_concat( autoencoder_params, dm_model_type, stage_2_params, - controlnet_params, + cn_params, input_shape, latent_shape, ): + if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -986,13 +988,15 @@ def test_prediction_shape_conditioned_concat( if ae_model_type == "SPADEAutoencoderKL": stage_1 = SPADEAutoencoderKL(**autoencoder_params) stage_2_params = stage_2_params.copy() + cn_params = cn_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + cn_params["in_channels"] = cn_params["in_channels"] + n_concat_channel if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: stage_2 = DiffusionModelUNet(**stage_2_params) - controlnet = ControlNet(**controlnet_params) + controlnet = ControlNet(**cn_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" stage_1.to(device) @@ -1055,7 +1059,7 @@ def test_sample_shape_conditioned_concat( autoencoder_params, dm_model_type, stage_2_params, - controlnet_params, + cn_params, input_shape, latent_shape, ): @@ -1066,13 +1070,15 @@ def test_sample_shape_conditioned_concat( if ae_model_type == "SPADEAutoencoderKL": stage_1 = SPADEAutoencoderKL(**autoencoder_params) stage_2_params = stage_2_params.copy() + cn_params = cn_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel + cn_params["in_channels"] = cn_params["in_channels"] + n_concat_channel if 
dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: stage_2 = DiffusionModelUNet(**stage_2_params) - controlnet = ControlNet(**controlnet_params) + controlnet = ControlNet(**cn_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" stage_1.to(device) diff --git a/tutorials/generative/2d_controlnet/2d_controlnet.py b/tutorials/generative/2d_controlnet/2d_controlnet.py index 8de30c3d..9aca295e 100644 --- a/tutorials/generative/2d_controlnet/2d_controlnet.py +++ b/tutorials/generative/2d_controlnet/2d_controlnet.py @@ -211,7 +211,6 @@ inferer = DiffusionInferer(scheduler) - # %% [markdown] # ### Run training # @@ -348,10 +347,14 @@ 0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device ).long() - noise_pred = controlnet_inferer(inputs = images, diffusion_model = model, - controlnet = controlnet, noise = noise, - timesteps = timesteps, - cn_cond = masks, ) + noise_pred = controlnet_inferer( + inputs=images, + diffusion_model=model, + controlnet=controlnet, + noise=noise, + timesteps=timesteps, + cn_cond=masks, + ) loss = F.mse_loss(noise_pred.float(), noise.float()) @@ -378,13 +381,16 @@ 0, controlnet_inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device ).long() - noise_pred = controlnet_inferer(inputs = images, diffusion_model = model, - controlnet = controlnet, noise = noise, - timesteps = timesteps, - cn_cond = masks, ) + noise_pred = controlnet_inferer( + inputs=images, + diffusion_model=model, + controlnet=controlnet, + noise=noise, + timesteps=timesteps, + cn_cond=masks, + ) val_loss = F.mse_loss(noise_pred.float(), noise.float()) - val_epoch_loss += val_loss.item() progress_bar.set_postfix({"val_loss": val_epoch_loss / (step + 1)}) @@ -398,30 +404,30 @@ with autocast(enabled=True): noise = torch.randn((1, 1, 64, 64)).to(device) sample = controlnet_inferer.sample( - input_noise = noise, - diffusion_model = model, - controlnet = controlnet, - cn_cond = masks[0, None, ...], - scheduler = scheduler, + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=masks[0, None, ...], + scheduler=scheduler, ) # Without using an inferer: -# progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110) -# progress_bar_sampling.set_description("sampling...") -# sample = torch.randn((1, 1, 64, 64)).to(device) -# for t in progress_bar_sampling: -# with torch.no_grad(): -# with autocast(enabled=True): -# down_block_res_samples, mid_block_res_sample = controlnet( -# x=sample, timesteps=torch.Tensor((t,)).to(device).long(), controlnet_cond=masks[0, None, ...] -# ) -# noise_pred = model( -# sample, -# timesteps=torch.Tensor((t,)).to(device), -# down_block_additional_residuals=down_block_res_samples, -# mid_block_additional_residual=mid_block_res_sample, -# ) -# sample, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=sample) + # progress_bar_sampling = tqdm(scheduler.timesteps, total=len(scheduler.timesteps), ncols=110) + # progress_bar_sampling.set_description("sampling...") + # sample = torch.randn((1, 1, 64, 64)).to(device) + # for t in progress_bar_sampling: + # with torch.no_grad(): + # with autocast(enabled=True): + # down_block_res_samples, mid_block_res_sample = controlnet( + # x=sample, timesteps=torch.Tensor((t,)).to(device).long(), controlnet_cond=masks[0, None, ...] 
+ # ) + # noise_pred = model( + # sample, + # timesteps=torch.Tensor((t,)).to(device), + # down_block_additional_residuals=down_block_res_samples, + # mid_block_additional_residual=mid_block_res_sample, + # ) + # sample, _ = scheduler.step(model_output=noise_pred, timestep=t, sample=sample) plt.subplots(1, 2, figsize=(4, 2)) plt.subplot(1, 2, 1) From ec8694bf559d93b8b8616f87566dc3a3fa695120 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 20 Mar 2024 11:09:06 +0000 Subject: [PATCH 2/3] Uncommented controlnet inferer tests, fixed them. These should be running now. --- tests/test_controlnet_inferers.py | 399 +++++++++++++++--------------- 1 file changed, 200 insertions(+), 199 deletions(-) diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index 5eae56a9..4082e0fb 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -15,8 +15,8 @@ import torch from parameterized import parameterized - -from generative.inferers import ControlNetLatentDiffusionInferer +from generative.networks.schedulers import DDIMScheduler +from generative.inferers import ControlNetLatentDiffusionInferer, ControlNetDiffusionInferer from generative.networks.nets import ( VQVAE, AutoencoderKL, @@ -438,203 +438,204 @@ ] -# class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): -# @parameterized.expand(CNDM_TEST_CASES) -# def test_call(self, model_params, controlnet_params, input_shape): -# model = DiffusionModelUNet(**model_params) -# controlnet = ControlNet(**controlnet_params) -# device = "cuda:0" if torch.cuda.is_available() else "cpu" -# model.to(device) -# model.eval() -# controlnet.to(device) -# controlnet.eval() -# input = torch.randn(input_shape).to(device) -# mask = torch.randn(input_shape).to(device) -# noise = torch.randn(input_shape).to(device) -# scheduler = DDPMScheduler(num_train_timesteps=10) -# inferer = ControlNetDiffusionInferer(scheduler=scheduler) -# scheduler.set_timesteps(num_inference_steps=10) -# timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() -# sample = inferer( -# inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask -# ) -# self.assertEqual(sample.shape, input_shape) -# -# @parameterized.expand(CNDM_TEST_CASES) -# def test_sample_intermediates(self, model_params, controlnet_params, input_shape): -# model = DiffusionModelUNet(**model_params) -# controlnet = ControlNet(**controlnet_params) -# device = "cuda:0" if torch.cuda.is_available() else "cpu" -# model.to(device) -# model.eval() -# controlnet.to(device) -# controlnet.eval() -# noise = torch.randn(input_shape).to(device) -# mask = torch.randn(input_shape).to(device) -# scheduler = DDPMScheduler(num_train_timesteps=10) -# inferer = ControlNetDiffusionInferer(scheduler=scheduler) -# scheduler.set_timesteps(num_inference_steps=10) -# sample, intermediates = inferer.sample( -# input_noise=noise, -# diffusion_model=model, -# scheduler=scheduler, -# controlnet=controlnet, -# cn_cond=mask, -# save_intermediates=True, -# intermediate_steps=1, -# ) -# self.assertEqual(len(intermediates), 10) -# -# @parameterized.expand(CNDM_TEST_CASES) -# def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): -# model = DiffusionModelUNet(**model_params) -# controlnet = ControlNet(**controlnet_params) -# device = "cuda:0" if torch.cuda.is_available() else "cpu" -# model.to(device) -# model.eval() -# controlnet.to(device) -# controlnet.eval() -# mask = 
torch.randn(input_shape).to(device) -# noise = torch.randn(input_shape).to(device) -# scheduler = DDPMScheduler(num_train_timesteps=1000) -# inferer = ControlNetDiffusionInferer(scheduler=scheduler) -# scheduler.set_timesteps(num_inference_steps=10) -# sample, intermediates = inferer.sample( -# input_noise=noise, -# diffusion_model=model, -# scheduler=scheduler, -# controlnet=controlnet, -# cn_cond=mask, -# save_intermediates=True, -# intermediate_steps=1, -# ) -# self.assertEqual(len(intermediates), 10) -# -# @parameterized.expand(CNDM_TEST_CASES) -# def test_ddim_sampler(self, model_params, controlnet_params, input_shape): -# model = DiffusionModelUNet(**model_params) -# controlnet = ControlNet(**controlnet_params) -# device = "cuda:0" if torch.cuda.is_available() else "cpu" -# model.to(device) -# model.eval() -# controlnet.to(device) -# controlnet.eval() -# mask = torch.randn(input_shape).to(device) -# noise = torch.randn(input_shape).to(device) -# scheduler = DDIMScheduler(num_train_timesteps=1000) -# inferer = ControlNetDiffusionInferer(scheduler=scheduler) -# scheduler.set_timesteps(num_inference_steps=10) -# sample, intermediates = inferer.sample( -# input_noise=noise, -# diffusion_model=model, -# scheduler=scheduler, -# controlnet=controlnet, -# cn_cond=mask, -# save_intermediates=True, -# intermediate_steps=1, -# ) -# self.assertEqual(len(intermediates), 10) -# -# @parameterized.expand(CNDM_TEST_CASES) -# def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): -# model_params["with_conditioning"] = controlnet_params["with_conditioning"] = True -# model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = 3 -# model = DiffusionModelUNet(**model_params) -# controlnet = ControlNet(**controlnet_params) -# device = "cuda:0" if torch.cuda.is_available() else "cpu" -# model.to(device) -# model.eval() -# controlnet.to(device) -# controlnet.eval() -# mask = torch.randn(input_shape).to(device) -# noise = torch.randn(input_shape).to(device) -# scheduler = DDIMScheduler(num_train_timesteps=1000) -# inferer = ControlNetDiffusionInferer(scheduler=scheduler) -# scheduler.set_timesteps(num_inference_steps=10) -# conditioning = torch.randn([input_shape[0], 1, 3]).to(device) -# sample, intermediates = inferer.sample( -# input_noise=noise, -# diffusion_model=model, -# controlnet=controlnet, -# cn_cond=mask, -# scheduler=scheduler, -# save_intermediates=True, -# intermediate_steps=1, -# conditioning=conditioning, -# ) -# self.assertEqual(len(intermediates), 10) -# -# @parameterized.expand(CNDM_TEST_CASES) -# def test_get_likelihood(self, model_params, controlnet_params, input_shape): -# model = DiffusionModelUNet(**model_params) -# device = "cuda:0" if torch.cuda.is_available() else "cpu" -# model.to(device) -# model.eval() -# controlnet = ControlNet(**controlnet_params) -# controlnet.to(device) -# controlnet.eval() -# input = torch.randn(input_shape).to(device) -# mask = torch.randn(input_shape).to(device) -# scheduler = DDPMScheduler(num_train_timesteps=10) -# inferer = ControlNetDiffusionInferer(scheduler=scheduler) -# scheduler.set_timesteps(num_inference_steps=10) -# likelihood, intermediates = inferer.get_likelihood( -# inputs=input, -# diffusion_model=model, -# scheduler=scheduler, -# controlnet=controlnet, -# cn_cond=mask, -# save_intermediates=True, -# ) -# self.assertEqual(intermediates[0].shape, input.shape) -# self.assertEqual(likelihood.shape[0], input.shape[0]) -# -# def test_normal_cdf(self): -# from scipy.stats import norm -# -# 
scheduler = DDPMScheduler(num_train_timesteps=10) -# inferer = ControlNetDiffusionInferer(scheduler=scheduler) -# x = torch.linspace(-10, 10, 20) -# cdf_approx = inferer._approx_standard_normal_cdf(x) -# cdf_true = norm.cdf(x) -# torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) -# -# @parameterized.expand(CNDM_TEST_CASES) -# def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): -# # copy the model_params dict to prevent from modifying test cases -# model_params = model_params.copy() -# n_concat_channel = 2 -# model_params["in_channels"] = model_params["in_channels"] + n_concat_channel -# controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel -# model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = None -# model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False -# model = DiffusionModelUNet(**model_params) -# device = "cuda:0" if torch.cuda.is_available() else "cpu" -# model.to(device) -# model.eval() -# controlnet = ControlNet(**controlnet_params) -# controlnet.to(device) -# controlnet.eval() -# noise = torch.randn(input_shape).to(device) -# mask = torch.randn(input_shape).to(device) -# conditioning_shape = list(input_shape) -# conditioning_shape[1] = n_concat_channel -# conditioning = torch.randn(conditioning_shape).to(device) -# scheduler = DDIMScheduler(num_train_timesteps=1000) -# inferer = ControlNetDiffusionInferer(scheduler=scheduler) -# scheduler.set_timesteps(num_inference_steps=10) -# sample, intermediates = inferer.sample( -# input_noise=noise, -# diffusion_model=model, -# controlnet=controlnet, -# cn_cond=mask, -# scheduler=scheduler, -# save_intermediates=True, -# intermediate_steps=1, -# conditioning=conditioning, -# mode="concat", -# ) -# self.assertEqual(len(intermediates), 10) +class ControlNetTestDiffusionSamplingInferer(unittest.TestCase): + @parameterized.expand(CNDM_TEST_CASES) + def test_call(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long() + sample = inferer( + inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask + ) + self.assertEqual(sample.shape, input_shape) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sample_intermediates(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + 
diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddpm_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_ddim_sampler(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + intermediate_steps=1, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): + model_params["with_conditioning"] = controlnet_params["with_conditioning"] = True + model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = 3 + model = DiffusionModelUNet(**model_params) + controlnet = ControlNet(**controlnet_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet.to(device) + controlnet.eval() + mask = torch.randn(input_shape).to(device) + noise = torch.randn(input_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + ) + self.assertEqual(len(intermediates), 10) + + @parameterized.expand(CNDM_TEST_CASES) + def test_get_likelihood(self, model_params, controlnet_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + input = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = 
ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + likelihood, intermediates = inferer.get_likelihood( + inputs=input, + diffusion_model=model, + scheduler=scheduler, + controlnet=controlnet, + cn_cond=mask, + save_intermediates=True, + ) + self.assertEqual(intermediates[0].shape, input.shape) + self.assertEqual(likelihood.shape[0], input.shape[0]) + + def test_normal_cdf(self): + from scipy.stats import norm + + scheduler = DDPMScheduler(num_train_timesteps=10) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + x = torch.linspace(-10, 10, 20) + cdf_approx = inferer._approx_standard_normal_cdf(x) + cdf_true = norm.cdf(x) + torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5) + + @parameterized.expand(CNDM_TEST_CASES) + def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape): + # copy the model_params dict to prevent from modifying test cases + model_params = model_params.copy() + controlnet_params = controlnet_params.copy() + n_concat_channel = 2 + model_params["in_channels"] = model_params["in_channels"] + n_concat_channel + controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel + model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = None + model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + controlnet = ControlNet(**controlnet_params) + controlnet.to(device) + controlnet.eval() + noise = torch.randn(input_shape).to(device) + mask = torch.randn(input_shape).to(device) + conditioning_shape = list(input_shape) + conditioning_shape[1] = n_concat_channel + conditioning = torch.randn(conditioning_shape).to(device) + scheduler = DDIMScheduler(num_train_timesteps=1000) + inferer = ControlNetDiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, + diffusion_model=model, + controlnet=controlnet, + cn_cond=mask, + scheduler=scheduler, + save_intermediates=True, + intermediate_steps=1, + conditioning=conditioning, + mode="concat", + ) + self.assertEqual(len(intermediates), 10) class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase): From 36d6e5465d34179f87307530cfba4f0241f3873c Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Wed, 20 Mar 2024 14:41:59 +0000 Subject: [PATCH 3/3] Re-formatting the test script, fix naming issues. 
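A usage sketch of the concat-mode path these tests exercise, mirroring
test_sampler_conditioned_concat above. The constructor arguments below are
illustrative placeholders, not values taken from this patch; only the
inferer calls follow the tests:

    import torch

    from generative.inferers import ControlNetDiffusionInferer
    from generative.networks.nets import ControlNet, DiffusionModelUNet
    from generative.networks.schedulers import DDIMScheduler

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    n_concat_channel = 2

    # In "concat" mode the conditioning is concatenated channel-wise onto the
    # model input before BOTH the controlnet and the diffusion model are
    # called, so both networks need the widened in_channels and no
    # cross-attention (illustrative 2D hyperparameters).
    common = dict(
        spatial_dims=2,
        in_channels=1 + n_concat_channel,
        num_res_blocks=1,
        num_channels=(8, 8),
        attention_levels=(False, False),
        norm_num_groups=8,
    )
    model = DiffusionModelUNet(out_channels=1, **common).to(device).eval()
    controlnet = ControlNet(conditioning_embedding_num_channels=(16,), **common).to(device).eval()

    scheduler = DDIMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=10)
    inferer = ControlNetDiffusionInferer(scheduler=scheduler)

    noise = torch.randn(1, 1, 16, 16, device=device)
    mask = torch.randn(1, 1, 16, 16, device=device)  # cn_cond, same spatial size as the noise
    conditioning = torch.randn(1, n_concat_channel, 16, 16, device=device)

    with torch.no_grad():
        sample = inferer.sample(
            input_noise=noise,
            diffusion_model=model,
            controlnet=controlnet,
            cn_cond=mask,
            scheduler=scheduler,
            conditioning=conditioning,
            mode="concat",
        )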
--- tests/test_controlnet_inferers.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py index 4082e0fb..a67eb4bb 100644 --- a/tests/test_controlnet_inferers.py +++ b/tests/test_controlnet_inferers.py @@ -15,8 +15,8 @@ import torch from parameterized import parameterized -from generative.networks.schedulers import DDIMScheduler -from generative.inferers import ControlNetLatentDiffusionInferer, ControlNetDiffusionInferer + +from generative.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer from generative.networks.nets import ( VQVAE, AutoencoderKL, @@ -25,7 +25,7 @@ SPADEAutoencoderKL, SPADEDiffusionModelUNet, ) -from generative.networks.schedulers import DDPMScheduler +from generative.networks.schedulers import DDIMScheduler, DDPMScheduler CNDM_TEST_CASES = [ [ @@ -537,8 +537,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape): @parameterized.expand(CNDM_TEST_CASES) def test_sampler_conditioned(self, model_params, controlnet_params, input_shape): - model_params["with_conditioning"] = controlnet_params["with_conditioning"] = True - model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = 3 + model_params["with_conditioning"] = controlnet_params["with_conditioning"] = True + model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = 3 model = DiffusionModelUNet(**model_params) controlnet = ControlNet(**controlnet_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -608,7 +608,7 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input model_params["in_channels"] = model_params["in_channels"] + n_concat_channel controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel model_params["cross_attention_dim"] = controlnet_params["cross_attention_dim"] = None - model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False + model_params["with_conditioning"] = controlnet_params["with_conditioning"] = False model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) @@ -977,11 +977,10 @@ def test_prediction_shape_conditioned_concat( autoencoder_params, dm_model_type, stage_2_params, - cn_params, + controlnet_params, input_shape, latent_shape, ): - if ae_model_type == "AutoencoderKL": stage_1 = AutoencoderKL(**autoencoder_params) if ae_model_type == "VQVAE": @@ -989,15 +988,15 @@ def test_prediction_shape_conditioned_concat( if ae_model_type == "SPADEAutoencoderKL": stage_1 = SPADEAutoencoderKL(**autoencoder_params) stage_2_params = stage_2_params.copy() - cn_params = cn_params.copy() + controlnet_params = controlnet_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel - cn_params["in_channels"] = cn_params["in_channels"] + n_concat_channel + controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: stage_2 = DiffusionModelUNet(**stage_2_params) - controlnet = ControlNet(**cn_params) + controlnet = ControlNet(**controlnet_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" stage_1.to(device) @@ -1060,7 +1059,7 @@ def test_sample_shape_conditioned_concat( autoencoder_params, dm_model_type, stage_2_params, - cn_params, + 
controlnet_params, input_shape, latent_shape, ): @@ -1071,15 +1070,15 @@ def test_sample_shape_conditioned_concat( if ae_model_type == "SPADEAutoencoderKL": stage_1 = SPADEAutoencoderKL(**autoencoder_params) stage_2_params = stage_2_params.copy() - cn_params = cn_params.copy() + controlnet_params = controlnet_params.copy() n_concat_channel = 3 stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel - cn_params["in_channels"] = cn_params["in_channels"] + n_concat_channel + controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel if dm_model_type == "SPADEDiffusionModelUNet": stage_2 = SPADEDiffusionModelUNet(**stage_2_params) else: stage_2 = DiffusionModelUNet(**stage_2_params) - controlnet = ControlNet(**cn_params) + controlnet = ControlNet(**controlnet_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" stage_1.to(device)
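
The net effect of the series, a controlnet that receives the same context as
the diffusion model, can be exercised end to end as in the minimal sketch
below. It follows test_sampler_conditioned in cross-attention mode; the
network hyperparameters are illustrative placeholders, not values from this
patch, and only the inferer calls mirror the tests:

    import torch

    from generative.inferers import ControlNetDiffusionInferer
    from generative.networks.nets import ControlNet, DiffusionModelUNet
    from generative.networks.schedulers import DDPMScheduler

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Both networks are built with the same conditioning settings, since the
    # inferer now forwards the identical `context` tensor to each of them
    # (illustrative 2D hyperparameters).
    common = dict(
        spatial_dims=2,
        in_channels=1,
        num_res_blocks=1,
        num_channels=(8, 8),
        attention_levels=(False, True),
        num_head_channels=8,
        norm_num_groups=8,
        with_conditioning=True,
        cross_attention_dim=3,
    )
    model = DiffusionModelUNet(out_channels=1, **common).to(device).eval()
    controlnet = ControlNet(conditioning_embedding_num_channels=(16,), **common).to(device).eval()

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(num_inference_steps=10)
    inferer = ControlNetDiffusionInferer(scheduler=scheduler)

    noise = torch.randn(1, 1, 16, 16, device=device)
    mask = torch.randn(1, 1, 16, 16, device=device)  # controlnet_cond
    context = torch.randn(1, 1, 3, device=device)    # shared cross-attention conditioning

    with torch.no_grad():
        # `conditioning` now reaches the controlnet as well as the diffusion model.
        sample = inferer.sample(
            input_noise=noise,
            diffusion_model=model,
            controlnet=controlnet,
            cn_cond=mask,
            scheduler=scheduler,
            conditioning=context,
        )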