From 70e3de61ba8d1a0e7d29e7b19725b2c3a7df8d89 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 3 Jan 2023 14:37:37 +0100
Subject: [PATCH 01/10] [Deterministic torch randn] Allow tensors to be generated on CPU

---
 .../pipelines/unclip/pipeline_unclip.py        |  8 +--
 .../unclip/pipeline_unclip_image_variation.py  |  8 +--
 src/diffusers/utils/__init__.py                |  1 +
 src/diffusers/utils/torch_utils.py             | 70 +++++++++++++++++++
 4 files changed, 75 insertions(+), 12 deletions(-)
 create mode 100644 src/diffusers/utils/torch_utils.py

diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index 0f35d004a09a..9eb2c829a8bf 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -24,7 +24,7 @@
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel
 
 
@@ -105,11 +105,7 @@ def __init__(
 
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
index 0b83407d8ccd..0dbe44b012a0 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -29,7 +29,7 @@
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel
 
 
@@ -113,11 +113,7 @@ def __init__(
     # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index a729bead4b93..67b95c0b684d 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -64,6 +64,7 @@
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
+from .torch_utils import torch_randn
 
 
 if is_torch_available():
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
new file mode 100644
index 000000000000..e7ac521e21a2
--- /dev/null
+++ b/src/diffusers/utils/torch_utils.py
@@ -0,0 +1,70 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Import utilities: Utilities related to imports and our lazy inits.
+"""
+import importlib.util
+import operator as op
+import os
+import sys
+from collections import OrderedDict
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from packaging import version
+from packaging.version import Version, parse
+
+from . import logging
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def torch_randn(
+    shape: Union[Tuple, List],
+    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
+    device: Optional[torch.device] = None,
+    dtype: Optional[torch.dtype] = None,
+):
+    """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
+passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor will
+always be created on CPU.
+    """
+    # device on which tensor is createad defaults to device
+    rand_device = device
+    batch_size = shape[0]
+
+    if generator is not None:
+        if generator.device != device and generator.device.type == "cpu":
+            rand_device = "cpu"
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Generator will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slighly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif generator.device != device and generator.device.type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {generator.device.type}.")
+
+    if isinstance(generator, list):
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.rand(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
+    else:
+        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+
+    return latents
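Note (editorial illustration, not part of the patch series): the sketch below shows how the `torch_randn` helper introduced in PATCH 01 is meant to be called. The shape, seed, and dtype are arbitrary, and the CUDA fallback is an assumption about the local machine, not something the patch prescribes.

    import torch

    from diffusers.utils import torch_randn

    # Assumption: any machine; CUDA is used only if available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The helper draws the noise on the generator's device (CPU here) and then
    # moves it, so the same seed reproduces the same latents on every backend.
    generator = torch.Generator(device="cpu").manual_seed(0)
    latents = torch_randn((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)

Trading one extra device copy for cross-device reproducibility is the design choice the rest of the series builds on.
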
From b2d85ea4c695e4e7e3505f6bee54dbcc302490b1 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 3 Jan 2023 13:54:48 +0000
Subject: [PATCH 02/10] fix more

---
 .../textual_inversion/textual_inversion_bf16.py    |  2 +-
 examples/textual_inversion/textual_inversion.py    |  2 +-
 .../textual_inversion/textual_inversion_flax.py    |  2 +-
 .../convert_kakao_brain_unclip_to_diffusers.py     |  1 +
 .../models/modeling_flax_pytorch_utils.py          |  1 +
 src/diffusers/schedulers/scheduling_unclip.py      | 14 ++++----------
 src/diffusers/utils/torch_utils.py                 | 16 ++++------------
 tests/pipelines/unclip/test_unclip.py              |  4 +++-
 .../unclip/test_unclip_image_variation.py          |  2 +-
 utils/custom_init_isort.py                         |  2 ++
 10 files changed, 19 insertions(+), 27 deletions(-)

diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
index f08ba52cfb96..4a7540aa161b 100644
--- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
+++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
@@ -336,7 +336,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 4532b83db857..4bbf9a960d37 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -381,7 +381,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 320b194a0d38..2977ee24a3a7 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -306,7 +306,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
diff --git a/scripts/convert_kakao_brain_unclip_to_diffusers.py b/scripts/convert_kakao_brain_unclip_to_diffusers.py
index ebb98eb66a8c..59bcd2c684eb 100644
--- a/scripts/convert_kakao_brain_unclip_to_diffusers.py
+++ b/scripts/convert_kakao_brain_unclip_to_diffusers.py
@@ -564,6 +564,7 @@ def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model,
 
 # unet utils
 
+
 # .time_embed -> .time_embedding
 def unet_time_embeddings(checkpoint, original_unet_prefix):
     diffusers_checkpoint = {}
diff --git a/src/diffusers/models/modeling_flax_pytorch_utils.py b/src/diffusers/models/modeling_flax_pytorch_utils.py
index 7463b408ed21..e75e0419c4d4 100644
--- a/src/diffusers/models/modeling_flax_pytorch_utils.py
+++ b/src/diffusers/models/modeling_flax_pytorch_utils.py
@@ -37,6 +37,7 @@ def rename_key(key):
 # PyTorch => Flax #
 #####################
 
+
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
 def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py
index 243ca2f009b6..09ffbe2ff104 100644
--- a/src/diffusers/schedulers/scheduling_unclip.py
+++ b/src/diffusers/schedulers/scheduling_unclip.py
@@ -20,7 +20,7 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, torch_randn
 from .scheduling_utils import SchedulerMixin
 
 
@@ -273,15 +273,9 @@ def step(
         # 6. Add noise
         variance = 0
         if t > 0:
-            device = model_output.device
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-                variance_noise = variance_noise.to(device)
-            else:
-                variance_noise = torch.randn(
-                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                )
+            variance_noise = torch_randn(
+                model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
+            )
 
             variance = self._get_variance(
                 t,
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index e7ac521e21a2..021ff1a0c54b 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -12,20 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Import utilities: Utilities related to imports and our lazy inits.
+PyTorch utilities: Utilities related to PyTorch
 """
-import importlib.util
-import operator as op
-import os
-import sys
-from collections import OrderedDict
 from typing import List, Optional, Tuple, Union
 
 import torch
 
-from packaging import version
-from packaging.version import Version, parse
-
 from . import logging
 
 
@@ -39,8 +31,8 @@ def torch_randn(
     dtype: Optional[torch.dtype] = None,
 ):
     """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
-passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor will
-always be created on CPU.
+    passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
+    will always be created on CPU.
     """
     # device on which tensor is createad defaults to device
     rand_device = device
@@ -55,7 +47,7 @@ def torch_randn(
                     f" Generator will be created on 'cpu' and then moved to {device}. Note that one can probably"
                     f" slighly speed up this function by passing a generator that was created on the {device} device."
                 )
-        elif generator.device != device and generator.device.type == "cuda":
+        elif generator.device.type != device.type and generator.device.type == "cuda":
             raise ValueError(f"Cannot generate a {device} tensor from a generator of type {generator.device.type}.")
 
     if isinstance(generator, list):
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
index 670082c20c24..39b01703b0b1 100644
--- a/tests/pipelines/unclip/test_unclip.py
+++ b/tests/pipelines/unclip/test_unclip.py
@@ -382,7 +382,7 @@ def test_unclip_karlo(self):
         pipeline = pipeline.to(torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             "horse",
             num_images_per_prompt=1,
@@ -392,6 +392,8 @@ def test_unclip_karlo(self):
 
         image = output.images[0]
 
+        np.save("/home/patrick_huggingface_co/diffusers-images/unclip/karlo_v1_alpha_horse_fp16.npy", image)
+
         assert image.shape == (256, 256, 3)
         assert np.abs(expected_image - image).max() < 1e-2
diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
index 87ad14146a11..daf84c19b966 100644
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ b/tests/pipelines/unclip/test_unclip_image_variation.py
@@ -480,7 +480,7 @@ def test_unclip_image_variation_karlo(self):
         pipeline.set_progress_bar_config(disable=None)
         pipeline.enable_sequential_cpu_offload()
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             input_image,
             num_images_per_prompt=1,
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
index 44165d1fce23..a204a155feb1 100644
--- a/utils/custom_init_isort.py
+++ b/utils/custom_init_isort.py
@@ -96,6 +96,7 @@ def _inner(x):
 
 def sort_objects(objects, key=None):
     "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
+
     # If no key is provided, we use a noop.
     def noop(x):
         return x
@@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
     """
     Return the same `import_statement` but with objects properly sorted.
     """
+
     # This inner function sort imports between [ ].
     def _replace(match):
         imports = match.groups()[0]
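Note (editorial illustration, not part of the patch series): after PATCH 02 the UnCLIP scheduler draws its variance noise through the same helper, so the determinism the updated tests rely on can be checked directly against it. A sketch with an arbitrary shape and seed:

    import torch

    from diffusers.utils import torch_randn

    # Two identically seeded CPU generators must yield identical tensors; this
    # is the property the refactoring centralizes in a single code path.
    g1 = torch.Generator(device="cpu").manual_seed(42)
    g2 = torch.Generator(device="cpu").manual_seed(42)

    a = torch_randn((2, 3, 8, 8), generator=g1, device=torch.device("cpu"), dtype=torch.float32)
    b = torch_randn((2, 3, 8, 8), generator=g2, device=torch.device("cpu"), dtype=torch.float32)
    assert torch.equal(a, b)
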
From 63353bf10bc8f9eeaf074c46b18094d55bb1b9f4 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 3 Jan 2023 13:57:10 +0000
Subject: [PATCH 03/10] up

---
 src/diffusers/utils/torch_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 021ff1a0c54b..2a518659c0d7 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -16,10 +16,12 @@
 """
 from typing import List, Optional, Tuple, Union
 
-import torch
-
 from . import logging
+from .import_uitls import is_torch_available
+
+if is_torch_available():
+    import torch
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

From a6394480f563c3e73bc3365bb23532beddbc7e98 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 3 Jan 2023 14:01:12 +0000
Subject: [PATCH 04/10] fix more

---
 src/diffusers/utils/torch_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 2a518659c0d7..cb569563ecd3 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -17,7 +17,7 @@
 from typing import List, Optional, Tuple, Union
 
 from . import logging
-from .import_uitls import is_torch_available
+from .import_utils import is_torch_available
 
 if is_torch_available():

From 80ba55a2769d2f1148acfba92282f5480d8d2ee1 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 3 Jan 2023 14:48:12 +0000
Subject: [PATCH 05/10] up

---
 src/diffusers/utils/torch_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index cb569563ecd3..83223cda2d61 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -28,9 +28,9 @@
 
 def torch_randn(
     shape: Union[Tuple, List],
-    generator: Optional[Union[List[torch.Generator], torch.Generator]] = None,
-    device: Optional[torch.device] = None,
-    dtype: Optional[torch.dtype] = None,
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional["torch.device"] = None,
+    dtype: Optional["torch.dtype"] = None,
 ):
     """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When
     passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor
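Note (editorial illustration, not part of the patch series): PATCHES 03-05 make `torch_utils.py` importable in environments without PyTorch by guarding the `torch` import and quoting the type annotations (the `.import_uitls` typo introduced in PATCH 03 is corrected to `.import_utils` in PATCH 04). The guard pattern in isolation, assuming a checkout with these patches applied:

    from diffusers.utils.import_utils import is_torch_available

    if is_torch_available():
        import torch  # only bound when PyTorch is actually installed

    # With the annotations quoted ("torch.Generator" instead of
    # torch.Generator), the signature of torch_randn stays valid even
    # when `torch` was never imported.
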
From 14794248375b866aa43274a3030fa74ad519bb04 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 3 Jan 2023 16:09:14 +0100
Subject: [PATCH 06/10] Update src/diffusers/utils/torch_utils.py

Co-authored-by: Anton Lozhkov
---
 src/diffusers/utils/torch_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index 83223cda2d61..79b64fd3f9b3 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -55,7 +55,7 @@ def torch_randn(
     if isinstance(generator, list):
         shape = (1,) + shape[1:]
         latents = [
-            torch.rand(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
+            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
         ]
         latents = torch.cat(latents, dim=0).to(device)
     else:

From f6db58bbefeacd7fb2819a3fec73b6e1689162c8 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 3 Jan 2023 16:09:37 +0100
Subject: [PATCH 07/10] Apply suggestions from code review

---
 tests/pipelines/unclip/test_unclip.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
index 39b01703b0b1..035db0b8cde4 100644
--- a/tests/pipelines/unclip/test_unclip.py
+++ b/tests/pipelines/unclip/test_unclip.py
@@ -392,7 +392,6 @@ def test_unclip_karlo(self):
 
         image = output.images[0]
 
-        np.save("/home/patrick_huggingface_co/diffusers-images/unclip/karlo_v1_alpha_horse_fp16.npy", image)
 
         assert image.shape == (256, 256, 3)
         assert np.abs(expected_image - image).max() < 1e-2
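Note (editorial illustration, not part of the patch series): PATCH 06 corrects a one-character slip: the list-of-generators branch sampled from `torch.rand` (uniform on [0, 1)) instead of `torch.randn` (standard normal), which would have silently skewed every batch seeded with a generator list. A pure-torch sketch of the difference, with an arbitrary seed and sample count:

    import torch

    g = torch.Generator(device="cpu").manual_seed(0)
    uniform = torch.rand((10000,), generator=g)   # what the buggy branch drew
    g = torch.Generator(device="cpu").manual_seed(0)
    normal = torch.randn((10000,), generator=g)   # what diffusion latents need

    # Uniform noise never leaves [0, 1); Gaussian noise is zero-centered with
    # unit variance, which is what the schedulers assume.
    assert uniform.min() >= 0 and uniform.max() < 1
    assert abs(normal.mean().item()) < 0.05
    assert abs(normal.std().item() - 1.0) < 0.05
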
""" # device on which tensor is createad defaults to device rand_device = device From b535251b47127258ed086c62426e86d36e00d67b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 3 Jan 2023 16:22:01 +0100 Subject: [PATCH 09/10] up --- tests/pipelines/unclip/test_unclip.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py index 035db0b8cde4..cc67182b34d0 100644 --- a/tests/pipelines/unclip/test_unclip.py +++ b/tests/pipelines/unclip/test_unclip.py @@ -392,7 +392,6 @@ def test_unclip_karlo(self): image = output.images[0] - assert image.shape == (256, 256, 3) assert np.abs(expected_image - image).max() < 1e-2 From 575b74c71b904613d5c3ba4e915f99eddebfc3d4 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 3 Jan 2023 16:49:03 +0100 Subject: [PATCH 10/10] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- src/diffusers/utils/torch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 79b64fd3f9b3..8242907bea32 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -36,7 +36,7 @@ def torch_randn( passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor will always be created on CPU. """ - # device on which tensor is createad defaults to device + # device on which tensor is created defaults to device rand_device = device batch_size = shape[0] @@ -46,7 +46,7 @@ def torch_randn( if device != "mps": logger.info( f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." - f" Generator will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" f" slighly speed up this function by passing a generator that was created on the {device} device." ) elif generator.device.type != device.type and generator.device.type == "cuda":