From d5a904d3d13629d91601d820e8c78f04efd3e096 Mon Sep 17 00:00:00 2001
From: statelesshz <3140102143@zju.edu.cn>
Date: Fri, 21 Jul 2023 14:43:51 +0800
Subject: [PATCH 1/2] make enable_sequential_cpu_offload more generic for
 third-party devices

---
 src/diffusers/pipelines/pipeline_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index ad52c6ac1c59..d0043efe3928 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1127,7 +1127,9 @@ def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.dev
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+            device_mod = getattr(torch, self.device.type, None)
+            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
+                device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
 
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module):

From c3b66d085df89473724d87a2c8617f35b5f0461b Mon Sep 17 00:00:00 2001
From: statelesshz <3140102143@zju.edu.cn>
Date: Fri, 21 Jul 2023 15:54:02 +0800
Subject: [PATCH 2/2] make style

---
 src/diffusers/pipelines/pipeline_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index d0043efe3928..3d827596d508 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1129,7 +1129,7 @@ def enable_sequential_cpu_offload(self, gpu_id: int = 0, device: Union[torch.dev
             self.to("cpu", silence_dtype_warnings=True)
             device_mod = getattr(torch, self.device.type, None)
             if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-                device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
         for name, model in self.components.items():
             if not isinstance(model, torch.nn.Module):
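
Note on the pattern in this patch: instead of hard-coding torch.cuda.empty_cache(), it looks up the backend module matching the pipeline's device type via getattr(torch, ...), so third-party accelerator backends that register a module on torch (e.g. torch.npu or torch.xpu) get their caches flushed too. Below is a minimal standalone sketch of that lookup, assuming a PyTorch build whose backend modules expose is_available() and empty_cache(); the helper name empty_device_cache is illustrative only and is not part of diffusers.

    import torch

    def empty_device_cache(device_type: str) -> None:
        # Hypothetical helper for illustration; not part of diffusers.
        # getattr returns the backend module (torch.cuda, torch.xpu, ...)
        # when it exists, or None for unknown device types.
        device_mod = getattr(torch, device_type, None)
        # hasattr(None, ...) is False, so unknown device types and backends
        # without an empty_cache() fall through safely.
        if hasattr(device_mod, "empty_cache") and device_mod.is_available():
            device_mod.empty_cache()

    empty_device_cache("cuda")  # no-op on machines without CUDA

Because both checks fail soft, the same call site works unchanged whether the pipeline last ran on CUDA, a third-party device, or the CPU.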