From 9e868e96591fb49f71757bb3849f68f596cfd273 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 28 Nov 2022 11:41:40 +0000 Subject: [PATCH 1/5] Add test --- src/diffusers/modeling_utils.py | 15 +++++++++++++++ tests/test_modeling_common.py | 18 +++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 5f79e7fe0155..bfcba2916a6b 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -472,6 +472,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file) + dtype = set(v.dtype for v in state_dict.values()) + + if len(dtype) > 1 and torch.float32 not in dtype: + raise ValueError( + f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please" + f" make sure that {model_file} weights have only one dtype." + ) + elif len(dtype) > 1 and torch.float32 in dtype: + dtype = torch.float32 + else: + dtype = dtype.pop() + + # move model to correct dtype + model = model.to(dtype) + model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, state_dict, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cad1887f4df8..b038fa16bb74 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -27,7 +27,7 @@ class ModelTesterMixin: - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) @@ -57,6 +57,22 @@ def test_from_pretrained_save_pretrained(self): max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") + def test_from_save_pretrained_dtype(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + with tempfile.TemporaryDirectory() as tmpdirname: + model.to(dtype) + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True) + assert new_model.dtype == dtype + new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False) + assert new_model.dtype == dtype + def test_determinism(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) From 1968baf88cda16d5f0fea417f14871355d8f3eb6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Nov 2022 11:12:13 +0000 Subject: [PATCH 2/5] up --- tests/models/test_models_unet_1d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 089d935651a5..13ea68c2ccf5 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -64,7 +64,7 @@ def test_outputs_equivalence(self): @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_from_pretrained_save_pretrained(self): - super().test_from_pretrained_save_pretrained() + super().test_from_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_model_from_pretrained(self): @@ -184,7 +184,7 @@ def test_outputs_equivalence(self): @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_from_pretrained_save_pretrained(self): - super().test_from_pretrained_save_pretrained() + super().test_from_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_model_from_pretrained(self): From 9ee2357673cd6aeef99617eab8b6c0a3f647c020 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Nov 2022 11:25:38 +0000 Subject: [PATCH 3/5] no bfloat16 for mps --- tests/test_modeling_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index b038fa16bb74..5fc3e2bfb263 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -65,6 +65,8 @@ def test_from_save_pretrained_dtype(self): model.eval() for dtype in [torch.float32, torch.float16, torch.bfloat16]: + if torch_device == "mps" and dtype == torch.bfloat16: + pass with tempfile.TemporaryDirectory() as tmpdirname: model.to(dtype) model.save_pretrained(tmpdirname) From 58b8933f1ea9b8e9094b0b4f9ba9c4118d33e046 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 29 Nov 2022 12:20:56 +0000 Subject: [PATCH 4/5] fix --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5fc3e2bfb263..68ab914b4209 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -66,7 +66,7 @@ def test_from_save_pretrained_dtype(self): for dtype in [torch.float32, torch.float16, torch.bfloat16]: if torch_device == "mps" and dtype == torch.bfloat16: - pass + continue with tempfile.TemporaryDirectory() as tmpdirname: model.to(dtype) model.save_pretrained(tmpdirname) From 8041033c073dc9d484f9a0a7bb5019e8977fa7ac Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 30 Nov 2022 11:20:10 +0100 Subject: [PATCH 5/5] rename test --- tests/models/test_models_unet_1d.py | 4 ++-- .../versatile_diffusion/test_versatile_diffusion_mega.py | 2 +- tests/test_pipelines.py | 2 +- tests/test_scheduler.py | 8 ++++---- tests/test_scheduler_flax.py | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 13ea68c2ccf5..b494c231b5fe 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -63,7 +63,7 @@ def test_outputs_equivalence(self): super().test_outputs_equivalence() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): super().test_from_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") @@ -183,7 +183,7 @@ def test_outputs_equivalence(self): super().test_outputs_equivalence() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): super().test_from_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") diff --git a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py index ab4580dae1fe..ad24ec01f633 100644 --- a/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py +++ b/tests/pipelines/versatile_diffusion/test_versatile_diffusion_mega.py @@ -42,7 +42,7 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pipe = VersatileDiffusionPipeline.from_pretrained("shi-labs/versatile-diffusion", torch_dtype=torch.float16) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 033f363ff41f..ec44b69cb1f9 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -656,7 +656,7 @@ def test_warning_unused_kwargs(self): assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): # 1. Load models model = UNet2DModel( block_out_channels=(32, 64), diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 6a76581632ad..0243e8840522 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -333,7 +333,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -860,7 +860,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pass def check_over_forward(self, time_step=0, **forward_kwargs): @@ -1037,7 +1037,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pass def check_over_forward(self, time_step=0, **forward_kwargs): @@ -1717,7 +1717,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pass def check_over_forward(self, time_step=0, **forward_kwargs): diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py index 5ada689b724d..da1042f3d698 100644 --- a/tests/test_scheduler_flax.py +++ b/tests/test_scheduler_flax.py @@ -126,7 +126,7 @@ def check_over_forward(self, time_step=0, **forward_kwargs): assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -408,7 +408,7 @@ def check_over_configs(self, time_step=0, **config): assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): kwargs = dict(self.forward_default_kwargs) num_inference_steps = kwargs.pop("num_inference_steps", None) @@ -690,7 +690,7 @@ def check_over_configs(self, time_step=0, **config): assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_from_pretrained_save_pretrained(self): + def test_from_save_pretrained(self): pass def test_scheduler_outputs_equivalence(self):