diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 768413e9e6a3..8b371703ec84 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -252,7 +252,8 @@ def add_noise(
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
-        self.timesteps = self.timesteps.to(original_samples.device)
+        dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype
+        self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype)
         timesteps = timesteps.to(original_samples.device)
 
         schedule_timesteps = self.timesteps
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index 899870a74a3f..6f86ffc85e05 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -27,6 +27,7 @@
     PNDMScheduler,
     ScoreSdeVeScheduler,
 )
+from diffusers.utils import torch_device
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
@@ -258,6 +259,23 @@ def test_scheduler_public_api(self):
         scaled_sample = scheduler.scale_model_input(sample, 0.0)
         self.assertEqual(sample.shape, scaled_sample.shape)
 
+    def test_add_noise_device(self):
+        for scheduler_class in self.scheduler_classes:
+            if scheduler_class == IPNDMScheduler:
+                # Skip until #990 is addressed
+                continue
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample.to(torch_device)
+            scaled_sample = scheduler.scale_model_input(sample, 0.0)
+            self.assertEqual(sample.shape, scaled_sample.shape)
+
+            noise = torch.randn_like(scaled_sample).to(torch_device)
+            t = torch.tensor([10]).to(torch_device)
+            noised = scheduler.add_noise(scaled_sample, noise, t)
+            self.assertEqual(noised.shape, scaled_sample.shape)
+
 
 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)