diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py index 74f725c4..b6b6bbf2 100644 --- a/tests/test_diffusion_inferer.py +++ b/tests/test_diffusion_inferer.py @@ -120,6 +120,8 @@ def test_ddim_sampler(self, model_params, input_shape): @parameterized.expand(TEST_CASES) def test_sampler_conditioned(self, model_params, input_shape): + model_params["with_conditioning"] = True + model_params["cross_attention_dim"] = 3 model = DiffusionModelUNet(**model_params) device = "cuda:0" if torch.cuda.is_available() else "cpu" model.to(device) @@ -130,7 +132,7 @@ def test_sampler_conditioned(self, model_params, input_shape): ) inferer = DiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - conditioning = torch.randn([input_shape[0], 1, 3]) + conditioning = torch.randn([input_shape[0], 1, 3]).to(device) sample, intermediates = inferer.sample( input_noise=noise, diffusion_model=model,