diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 8b371703ec84..6157f4b4dc65 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -252,9 +252,13 @@ def add_noise( ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype - self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype) - timesteps = timesteps.to(original_samples.device) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) schedule_timesteps = self.timesteps diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 6f86ffc85e05..194fb66f663f 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -266,13 +266,14 @@ def test_add_noise_device(self): continue scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(100) sample = self.dummy_sample.to(torch_device) scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) noise = torch.randn_like(scaled_sample).to(torch_device) - t = torch.tensor([10]).to(torch_device) + t = scheduler.timesteps[5][None] noised = scheduler.add_noise(scaled_sample, noise, t) self.assertEqual(noised.shape, scaled_sample.shape)