From 72f1b13edcbf3bcabd1dbd903aa4930b9a1f78db Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 17 Mar 2023 21:44:44 +0100 Subject: [PATCH 01/11] Workaround for saving dynamo-wrapped models. --- src/diffusers/pipelines/pipeline_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6560f305c18e..74a44892537d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -484,6 +484,11 @@ def is_saveable_module(name, value): sub_model = getattr(self, pipeline_component_name) model_cls = sub_model.__class__ + # Dynamo wraps the original mode and changes the class. + # Is there a principled way to obtain the original class? + if "_dynamo.eval_frame" in str(model_cls): + model_cls = sub_model._orig_mod.__class__ + save_method_name = None # search for the model's base class in LOADABLE_CLASSES for library_name, library_classes in LOADABLE_CLASSES.items(): From 9351bc82abfcf7b32805a94049ff851d2a4877b9 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Mar 2023 14:09:17 +0000 Subject: [PATCH 02/11] Accept suggestion from code review Co-authored-by: Patrick von Platen --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 74a44892537d..dae7826888bb 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -486,7 +486,7 @@ def is_saveable_module(name, value): # Dynamo wraps the original mode and changes the class. # Is there a principled way to obtain the original class? - if "_dynamo.eval_frame" in str(model_cls): + if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule): model_cls = sub_model._orig_mod.__class__ save_method_name = None From 351b47b5af7350adeee90b040b7b685d1cbb78a8 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Mar 2023 15:34:24 +0100 Subject: [PATCH 03/11] Apply workaround when overriding pipeline components. --- src/diffusers/pipelines/pipeline_utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index dae7826888bb..2f286bde0ac8 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -255,7 +255,14 @@ def maybe_raise_or_warn( if class_candidate is not None and issubclass(class_obj, class_candidate): expected_class_obj = class_candidate - if not issubclass(passed_class_obj[name].__class__, expected_class_obj): + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. + sub_model = passed_class_obj[name] + model_cls = sub_model.__class__ + if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule): + model_cls = sub_model._orig_mod.__class__ + + if not issubclass(model_cls, expected_class_obj): raise ValueError( f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" @@ -484,8 +491,8 @@ def is_saveable_module(name, value): sub_model = getattr(self, pipeline_component_name) model_cls = sub_model.__class__ - # Dynamo wraps the original mode and changes the class. - # Is there a principled way to obtain the original class? + # Dynamo wraps the original model in a private class. + # I didn't find a public API to get the original class. if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule): model_cls = sub_model._orig_mod.__class__ From cb76bdee20822693d7c45a3e1eb2e8b07e9e965a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Mar 2023 16:08:46 +0100 Subject: [PATCH 04/11] Ensure the correct config.json is saved to disk. Instead of the dynamo class. --- src/diffusers/pipelines/pipeline_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 2f286bde0ac8..a3e25e1bfaaf 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -426,6 +426,10 @@ def register_modules(self, **kwargs): if module is None: register_dict = {name: (None, None)} else: + # register the original module, not the dynamo compiled one + if is_torch_version(">=", "2.0.0") and isinstance(module, torch._dynamo.eval_frame.OptimizedModule): + module = module._orig_mod + library = module.__module__.split(".")[0] # check if the module is a pipeline module From e990b53dbe9a2a8ebb903fb46a805cae1a6394cb Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Mar 2023 16:10:04 +0100 Subject: [PATCH 05/11] Save correct module (not compiled one) --- src/diffusers/pipelines/pipeline_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index a3e25e1bfaaf..ebf492ed4dc4 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -498,7 +498,8 @@ def is_saveable_module(name, value): # Dynamo wraps the original model in a private class. # I didn't find a public API to get the original class. if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule): - model_cls = sub_model._orig_mod.__class__ + sub_model = sub_model._orig_mod + model_cls = sub_model.__class__ save_method_name = None # search for the model's base class in LOADABLE_CLASSES From f803679ef390c3774f2b591a1b7d78394aca4b43 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Mar 2023 16:12:18 +0100 Subject: [PATCH 06/11] Add test --- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/testing_utils.py | 11 ++++++-- tests/test_pipelines.py | 38 +++++++++++++++++++++++++--- 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 196b3b0279d0..a46f01701ec6 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -85,6 +85,7 @@ nightly, parse_flag_from_env, print_tensor_test, + require_torch_2, require_torch_gpu, skip_mps, slow, diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index cea2869b3193..a1c4335c6f18 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -16,7 +16,7 @@ import requests from packaging import version -from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available +from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available, is_torch_version from .logging import get_logger @@ -151,11 +151,18 @@ def nightly(test_case): def require_torch(test_case): """ - Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. + Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. """ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) +def require_torch_2(test_case): + """ + Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. + """ + return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch")(test_case) + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index daf88417227f..7eb22f8133b1 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -54,7 +54,7 @@ logging, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, nightly, slow, torch_device +from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, is_torch_version, nightly, require_torch_2, slow, torch_device from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu @@ -966,9 +966,41 @@ def test_from_save_pretrained(self): down_block_types=("DownBlock2D", "AttnDownBlock2D"), up_block_types=("AttnUpBlock2D", "UpBlock2D"), ) - schedular = DDPMScheduler(num_train_timesteps=10) + scheduler = DDPMScheduler(num_train_timesteps=10) + + ddpm = DDPMPipeline(model, scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) + new_ddpm.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + generator = torch.Generator(device=torch_device).manual_seed(0) + new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + + @require_torch_2 + def test_from_save_pretrained_dynamo(self): + # 1. Load models + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + model = torch.compile(model) + scheduler = DDPMScheduler(num_train_timesteps=10) - ddpm = DDPMPipeline(model, schedular) + ddpm = DDPMPipeline(model, scheduler) ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) From ed7a625ae3da8d58ecd2e11f99f3cc3de04fcd10 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Mar 2023 16:14:07 +0100 Subject: [PATCH 07/11] style --- src/diffusers/models/cross_attention.py | 4 +--- src/diffusers/utils/testing_utils.py | 12 ++++++++++-- tests/test_pipelines.py | 11 ++++++++++- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py index 91a0dbfa1238..4fdb2acaabed 100644 --- a/src/diffusers/models/cross_attention.py +++ b/src/diffusers/models/cross_attention.py @@ -24,9 +24,7 @@ SlicedAttnProcessor, XFormersAttnProcessor, ) -from .attention_processor import ( # noqa: F401 - AttnProcessor as AttnProcessorRename, -) +from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401 deprecate( diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index a1c4335c6f18..7a1e3a6bbe41 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -16,7 +16,13 @@ import requests from packaging import version -from .import_utils import is_compel_available, is_flax_available, is_onnx_available, is_torch_available, is_torch_version +from .import_utils import ( + is_compel_available, + is_flax_available, + is_onnx_available, + is_torch_available, + is_torch_version, +) from .logging import get_logger @@ -160,7 +166,9 @@ def require_torch_2(test_case): """ Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. """ - return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch")(test_case) + return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")( + test_case + ) def require_torch_gpu(test_case): diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 7eb22f8133b1..de0ae49adb73 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -54,7 +54,16 @@ logging, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME -from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, is_flax_available, is_torch_version, nightly, require_torch_2, slow, torch_device +from diffusers.utils import ( + CONFIG_NAME, + WEIGHTS_NAME, + floats_tensor, + is_flax_available, + nightly, + require_torch_2, + slow, + torch_device, +) from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu From 0f3a8f0da7816994c34651850c8cd5c8b7a4897d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Mar 2023 18:11:37 +0100 Subject: [PATCH 08/11] fix docstrings --- src/diffusers/utils/testing_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7a1e3a6bbe41..fb65f4a2848e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -157,14 +157,14 @@ def nightly(test_case): def require_torch(test_case): """ - Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. + Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. """ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) def require_torch_2(test_case): """ - Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. + Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. """ return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")( test_case From 2c4ba90a83e9326d8b66af1679a163737ab0aa5e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 20 Mar 2023 18:31:24 +0100 Subject: [PATCH 09/11] Go back to using string comparisons. PyTorch CPU does not have _dynamo. --- src/diffusers/pipelines/pipeline_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index ebf492ed4dc4..c76d877d2b60 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -259,7 +259,7 @@ def maybe_raise_or_warn( # I didn't find a public API to get the original class. sub_model = passed_class_obj[name] model_cls = sub_model.__class__ - if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule): + if "_dynamo.eval_frame" in str(model_cls): model_cls = sub_model._orig_mod.__class__ if not issubclass(model_cls, expected_class_obj): @@ -427,7 +427,7 @@ def register_modules(self, **kwargs): register_dict = {name: (None, None)} else: # register the original module, not the dynamo compiled one - if is_torch_version(">=", "2.0.0") and isinstance(module, torch._dynamo.eval_frame.OptimizedModule): + if "_dynamo.eval_frame" in str(module.__class__): module = module._orig_mod library = module.__module__.split(".")[0] @@ -497,7 +497,7 @@ def is_saveable_module(name, value): # Dynamo wraps the original model in a private class. # I didn't find a public API to get the original class. - if is_torch_version(">=", "2.0.0") and isinstance(sub_model, torch._dynamo.eval_frame.OptimizedModule): + if "_dynamo.eval_frame" in str(model_cls): sub_model = sub_model._orig_mod model_cls = sub_model.__class__ From 4680fafaec06fe1612a1e99760b08565a8849333 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 21 Mar 2023 20:37:18 +0100 Subject: [PATCH 10/11] Simple test for save_pretrained of compiled models. --- tests/test_modeling_common.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e9b7d5f34e82..9599346a6b2e 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -27,6 +27,7 @@ from diffusers.models import ModelMixin, UNet2DConditionModel from diffusers.training_utils import EMAModel from diffusers.utils import torch_device +from diffusers.utils.testing_utils import require_torch_gpu class ModelUtilsTest(unittest.TestCase): @@ -169,6 +170,21 @@ def test_from_save_pretrained_variant(self): max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") + @require_torch_gpu + def test_from_save_pretrained_dynamo(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model = torch.compile(model) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + assert new_model.__class__ == self.model_class + def test_from_save_pretrained_dtype(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() From e41331b77fdcf8ce6bd50bbc5f83321f7d8a9ef8 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 25 Mar 2023 19:36:53 +0100 Subject: [PATCH 11/11] Helper function to test whether module is compiled. --- src/diffusers/pipelines/pipeline_utils.py | 7 ++++--- src/diffusers/utils/__init__.py | 2 +- src/diffusers/utils/torch_utils.py | 9 ++++++++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index e024df70776d..40a7a4a2e98e 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -50,6 +50,7 @@ get_class_from_dynamic_module, is_accelerate_available, is_accelerate_version, + is_compiled_module, is_safetensors_available, is_torch_version, is_transformers_available, @@ -259,7 +260,7 @@ def maybe_raise_or_warn( # I didn't find a public API to get the original class. sub_model = passed_class_obj[name] model_cls = sub_model.__class__ - if "_dynamo.eval_frame" in str(model_cls): + if is_compiled_module(sub_model): model_cls = sub_model._orig_mod.__class__ if not issubclass(model_cls, expected_class_obj): @@ -427,7 +428,7 @@ def register_modules(self, **kwargs): register_dict = {name: (None, None)} else: # register the original module, not the dynamo compiled one - if "_dynamo.eval_frame" in str(module.__class__): + if is_compiled_module(module): module = module._orig_mod library = module.__module__.split(".")[0] @@ -497,7 +498,7 @@ def is_saveable_module(name, value): # Dynamo wraps the original model in a private class. # I didn't find a public API to get the original class. - if "_dynamo.eval_frame" in str(model_cls): + if is_compiled_module(sub_model): sub_model = sub_model._orig_mod model_cls = sub_model.__class__ diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 0c3f8f9ba1d9..30d559d19244 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -73,7 +73,7 @@ from .logging import get_logger from .outputs import BaseOutput from .pil_utils import PIL_INTERPOLATION -from .torch_utils import randn_tensor +from .torch_utils import is_compiled_module, randn_tensor if is_torch_available(): diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 113e64c16bac..b9815cbceede 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -17,7 +17,7 @@ from typing import List, Optional, Tuple, Union from . import logging -from .import_utils import is_torch_available +from .import_utils import is_torch_available, is_torch_version if is_torch_available(): @@ -68,3 +68,10 @@ def randn_tensor( latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) return latents + + +def is_compiled_module(module): + """Check whether the module was compiled with torch.compile()""" + if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): + return False + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)