From 2986f8c285a3f0c2787288071bb99a7280b88173 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 22 Oct 2022 09:11:12 +0200 Subject: [PATCH 1/7] Add failing test for #940. --- tests/test_scheduler.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index c3d4b9bc76f9..28fc49209aae 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -20,7 +20,7 @@ import torch from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler - +from diffusers.utils import torch_device torch.backends.cuda.matmul.allow_tf32 = False @@ -247,6 +247,20 @@ def test_scheduler_public_api(self): scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) + def test_add_noise_device(self): + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample.to(torch_device) + scaled_sample = scheduler.scale_model_input(sample, 0.0) + self.assertEqual(sample.shape, scaled_sample.shape) + + noise = torch.randn_like(scaled_sample).to(torch_device) + t = torch.tensor([10]).to(torch_device) + noised = scheduler.add_noise(scaled_sample, noise, t) + self.assertEqual(noised.shape, scaled_sample.shape) + class DDPMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDPMScheduler,) From e6ee09d168ca1167f8df93e0cf763253596d0b64 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 22 Oct 2022 09:11:49 +0200 Subject: [PATCH 2/7] Do not use torch.float64 in mps. --- src/diffusers/schedulers/scheduling_lms_discrete.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 768413e9e6a3..8b371703ec84 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -252,7 +252,8 @@ def add_noise( ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - self.timesteps = self.timesteps.to(original_samples.device) + dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype + self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype) timesteps = timesteps.to(original_samples.device) schedule_timesteps = self.timesteps From 445a6e0d8f42f4fa6c290d79237c1aec47da3c77 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 22 Oct 2022 09:14:11 +0200 Subject: [PATCH 3/7] style --- tests/test_scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 28fc49209aae..a726af8385c4 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -22,6 +22,7 @@ from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler from diffusers.utils import torch_device + torch.backends.cuda.matmul.allow_tf32 = False From 56210ad82159db99dbb35569d3f1d01de6a76191 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 26 Oct 2022 11:31:12 +0200 Subject: [PATCH 4/7] Temporarily skip add_noise for IPNDMScheduler. Until #990 is addressed. --- tests/test_scheduler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index e9e9a068270e..6f86ffc85e05 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -261,6 +261,9 @@ def test_scheduler_public_api(self): def test_add_noise_device(self): for scheduler_class in self.scheduler_classes: + if scheduler_class == IPNDMScheduler: + # Skip until #990 is addressed + continue scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) From d6afc58a22d6b5560e2a4fc40733f5a203585bc9 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 26 Oct 2022 17:19:21 +0200 Subject: [PATCH 5/7] Fix additional float64 error in mps. --- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- tests/test_scheduler.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 8b371703ec84..261ea77d765d 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -254,7 +254,7 @@ def add_noise( self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype) - timesteps = timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device, dtype=dtype) schedule_timesteps = self.timesteps diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 6f86ffc85e05..8b71b8d5a641 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -276,6 +276,10 @@ def test_add_noise_device(self): noised = scheduler.add_noise(scaled_sample, noise, t) self.assertEqual(noised.shape, scaled_sample.shape) + t = torch.tensor([10.], dtype=torch.float64) + noised = scheduler.add_noise(scaled_sample, noise, t) + self.assertEqual(noised.shape, scaled_sample.shape) + class DDPMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDPMScheduler,) From cb7be5873b2c81e89d0655dc123015d15a032b39 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 26 Oct 2022 18:08:04 +0200 Subject: [PATCH 6/7] Improve add_noise test --- tests/test_scheduler.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 8b71b8d5a641..194fb66f663f 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -266,17 +266,14 @@ def test_add_noise_device(self): continue scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) + scheduler.set_timesteps(100) sample = self.dummy_sample.to(torch_device) scaled_sample = scheduler.scale_model_input(sample, 0.0) self.assertEqual(sample.shape, scaled_sample.shape) noise = torch.randn_like(scaled_sample).to(torch_device) - t = torch.tensor([10]).to(torch_device) - noised = scheduler.add_noise(scaled_sample, noise, t) - self.assertEqual(noised.shape, scaled_sample.shape) - - t = torch.tensor([10.], dtype=torch.float64) + t = scheduler.timesteps[5][None] noised = scheduler.add_noise(scaled_sample, noise, t) self.assertEqual(noised.shape, scaled_sample.shape) From 3311a12d2b1efb457d766e3bb9eadcaad2d30b89 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 26 Oct 2022 20:20:32 +0200 Subject: [PATCH 7/7] =?UTF-8?q?Slight=20edit=20=E2=80=93=20I=20think=20it'?= =?UTF-8?q?s=20clearer=20this=20way.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/diffusers/schedulers/scheduling_lms_discrete.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 261ea77d765d..6157f4b4dc65 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -252,9 +252,13 @@ def add_noise( ) -> torch.FloatTensor: # Make sure sigmas and timesteps have the same device and dtype as original_samples self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype - self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype) - timesteps = timesteps.to(original_samples.device, dtype=dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + self.timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) schedule_timesteps = self.timesteps