From 01527d829b4bba9589fdeddc6a44300183cedc60 Mon Sep 17 00:00:00 2001 From: Walter Hugo Lopez Pinaya Date: Wed, 11 Jan 2023 23:10:20 +0000 Subject: [PATCH] Fix test Signed-off-by: Walter Hugo Lopez Pinaya --- tests/test_diffusion_inferer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index 74f725c4..b6b6bbf2 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -120,6 +120,8 @@ def test_ddim_sampler(self, model_params, input_shape): @parameterized.expand(TEST_CASES) def test_sampler_conditioned(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) @@ -130,7 +132,7 @@ def test_sampler_conditioned(self, model_params, input_shape): ) inferer = DiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - conditioning = torch.randn([input_shape[0], 1, 3]) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) sample, intermediates = inferer.sample( input_noise=noise, diffusion_model=model,