diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 49d0e4dbdca2..495667598488 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -265,10 +265,6 @@ def test_dict_tuple_outputs_equivalent(self): max_diff = np.abs(output - output_tuple).max() self.assertLess(max_diff, 1e-4) - @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") - def test_num_inference_steps_consistent(self): - super().test_num_inference_steps_consistent() - @unittest.skipIf(torch_device == "mps", reason="The depth model does not support MPS yet") def test_progress_bar(self): super().test_progress_bar() diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index a1d3122f875c..32f050a51d3c 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -4,7 +4,6 @@ import io import re import tempfile -import time import unittest from typing import Callable, Union @@ -294,36 +293,6 @@ def test_dict_tuple_outputs_equivalent(self): max_diff = np.abs(output - output_tuple).max() self.assertLess(max_diff, 1e-4) - def test_num_inference_steps_consistent(self): - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - # Warmup pass when using mps (see #372) - if torch_device == "mps": - _ = pipe(**self.get_dummy_inputs(torch_device)) - - outputs = [] - times = [] - for num_steps in [9, 6, 3]: - inputs = self.get_dummy_inputs(torch_device) - - for arg in self.num_inference_steps_args: - inputs[arg] = num_steps - - start_time = time.time() - output = pipe(**inputs)[0] - inference_time = time.time() - start_time - - outputs.append(output) - times.append(inference_time) - - # check that all outputs have the same shape - self.assertTrue(all(outputs[0].shape == output.shape for output in outputs)) - # check that the inference time increases with the number of inference steps - self.assertTrue(all(times[i] < times[i - 1] for i in range(1, len(times)))) - def test_components_function(self): init_components = self.get_dummy_components() pipe = self.pipeline_class(**init_components)