From 55f63f705f9391c6facd8213ce7b87ab3ffdc102 Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 10 Sep 2023 12:08:46 -0700 Subject: [PATCH 1/2] =?UTF-8?q?Revert=20"Temp=20Revert=20"[Core]=20better?= =?UTF-8?q?=20support=20offloading=20when=20side=20loading=20is=20enabled?= =?UTF-8?q?=E2=80=A6=20(#4927)"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 2ab170499eaaf7adfa24a80e0e2717c916f598f1. --- src/diffusers/loaders.py | 43 ++++++++++++++ .../pipeline_controlnet_inpaint_sd_xl.py | 26 +++++++++ .../controlnet/pipeline_controlnet_sd_xl.py | 26 +++++++++ .../pipeline_stable_diffusion_xl.py | 26 +++++++++ .../pipeline_stable_diffusion_xl_img2img.py | 26 +++++++++ .../pipeline_stable_diffusion_xl_inpaint.py | 26 +++++++++ tests/models/test_lora_layers.py | 56 ++++++++++++++++++- .../stable_diffusion/test_stable_diffusion.py | 50 +++++++++++++++++ 8 files changed, 278 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e4d8f1a3a5eb..1de899cad927 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -45,6 +45,7 @@ if is_accelerate_available(): from accelerate import init_empty_weights + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module from accelerate.utils import set_module_tensor_to_device logger = logging.get_logger(__name__) @@ -778,6 +779,21 @@ def load_textual_inversion( f" `{self.load_textual_inversion.__name__}`" ) + # Remove any existing hooks. + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -931,6 +947,12 @@ def load_textual_inversion( for token_id, embedding in token_ids_and_embeddings: text_encoder.get_input_embeddings().weight.data[token_id] = embedding + # offload back + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + class LoraLoaderMixin: r""" @@ -962,6 +984,21 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. """ + # Remove any existing hooks. + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recurive = False + for _, component in self.components.items(): + if isinstance(component, nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recurive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recurive) + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) self.load_lora_into_text_encoder( @@ -971,6 +1008,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod def lora_state_dict( cls, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index b20d1f0c636e..c64204501b97 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -1549,6 +1549,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -1576,6 +1596,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights def save_lora_weights( diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 6f2b36ba6976..ef6b54e81548 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1216,6 +1216,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -1243,6 +1263,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights def save_lora_weights( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 459b47de7ea1..7b7755085ed6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -922,6 +922,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -949,6 +969,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod def save_lora_weights( self, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index b9e2b263b893..04902234d54e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -1072,6 +1072,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -1099,6 +1119,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights def save_lora_weights( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 0b00e31a0a50..1d86dff702ef 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -1392,6 +1392,26 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) state_dict, network_alphas = self.lora_state_dict( pretrained_model_name_or_path_or_dict, unet_config=self.unet.config, @@ -1419,6 +1439,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di lora_scale=self.lora_scale, ) + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + @classmethod # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights def save_lora_weights( diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6b3498ae5f41..c49ea7f2d960 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -1081,6 +1081,42 @@ def test_a1111(self): self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_a1111_with_model_cpu_offload(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) + pipe.enable_model_cpu_offload() + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + + def test_a1111_with_sequential_cpu_offload(self): + generator = torch.Generator().manual_seed(0) + + pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/Counterfeit-V2.5", safety_checker=None) + pipe.enable_sequential_cpu_offload() + lora_model_id = "hf-internal-testing/civitai-light-shadow-lora" + lora_filename = "light_and_shadow.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) + def test_kohya_sd_v15_with_higher_dimensions(self): generator = torch.Generator().manual_seed(0) @@ -1257,10 +1293,10 @@ def test_sdxl_1_0_lora(self): generator = torch.Generator().manual_seed(0) pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_model_cpu_offload() lora_model_id = "hf-internal-testing/sdxl-1.0-lora" lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.enable_model_cpu_offload() images = pipe( "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 @@ -1411,3 +1447,21 @@ def test_sdxl_1_0_fuse_unfuse_all(self): assert state_dicts_almost_equal(text_encoder_1_sd, pipe.text_encoder.state_dict()) assert state_dicts_almost_equal(text_encoder_2_sd, pipe.text_encoder_2.state_dict()) assert state_dicts_almost_equal(unet_sd, pipe.unet.state_dict()) + + def test_sdxl_1_0_lora_with_sequential_cpu_offloading(self): + generator = torch.Generator().manual_seed(0) + + pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0") + pipe.enable_sequential_cpu_offload() + lora_model_id = "hf-internal-testing/sdxl-1.0-lora" + lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + + images = pipe( + "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + ).images + + images = images[0, -3:, -3:, -1].flatten() + expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535]) + + self.assertTrue(np.allclose(images, expected, atol=1e-3)) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 7935a63eceaa..31de557a0ac3 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -1019,6 +1019,56 @@ def test_stable_diffusion_textual_inversion(self): max_diff = np.abs(expected_image - image).max() assert max_diff < 8e-1 + def test_stable_diffusion_textual_inversion_with_model_cpu_offload(self): + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + pipe.enable_model_cpu_offload() + pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons") + + a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt") + a111_file_neg = hf_hub_download( + "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt" + ) + pipe.load_textual_inversion(a111_file) + pipe.load_textual_inversion(a111_file_neg) + + generator = torch.Generator(device="cpu").manual_seed(1) + + prompt = "An logo of a turtle in strong Style-Winter with " + neg_prompt = "Style-Winter-neg" + + image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0] + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy" + ) + + max_diff = np.abs(expected_image - image).max() + assert max_diff < 8e-1 + + def test_stable_diffusion_textual_inversion_with_sequential_cpu_offload(self): + pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") + pipe.enable_sequential_cpu_offload() + pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons") + + a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt") + a111_file_neg = hf_hub_download( + "hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt" + ) + pipe.load_textual_inversion(a111_file) + pipe.load_textual_inversion(a111_file_neg) + + generator = torch.Generator(device="cpu").manual_seed(1) + + prompt = "An logo of a turtle in strong Style-Winter with " + neg_prompt = "Style-Winter-neg" + + image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0] + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy" + ) + + max_diff = np.abs(expected_image - image).max() + assert max_diff < 8e-1 + @require_torch_2 def test_stable_diffusion_compile(self): seed = 0 From e11219aad1650d8e8600ff9a85245f331cf55ab1 Mon Sep 17 00:00:00 2001 From: William Berman Date: Sun, 10 Sep 2023 12:20:05 -0700 Subject: [PATCH 2/2] tests: install accelerate from main --- .github/workflows/pr_tests.yml | 1 + .github/workflows/push_tests.yml | 1 + .github/workflows/push_tests_mps.yml | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index defd418edcef..42c0c8e42252 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -67,6 +67,7 @@ jobs: run: | apt-get update && apt-get install libsndfile1-dev libgl1 -y python -m pip install -e .[quality,test] + python -m pip install git+https://github.com/huggingface/accelerate.git - name: Environment run: | diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 5ec8dbdc4026..a13519ec5876 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -63,6 +63,7 @@ jobs: run: | apt-get update && apt-get install libsndfile1-dev libgl1 -y python -m pip install -e .[quality,test] + python -m pip install git+https://github.com/huggingface/accelerate.git - name: Environment run: | diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index 6b95815f1ea5..c92aa6426d55 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -40,7 +40,7 @@ jobs: ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] ${CONDA_RUN} python -m pip install torch torchvision torchaudio - ${CONDA_RUN} python -m pip install accelerate --upgrade + ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate.git ${CONDA_RUN} python -m pip install transformers --upgrade - name: Environment