diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
index f08ba52cfb96..4a7540aa161b 100644
--- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
+++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
@@ -336,7 +336,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 4532b83db857..4bbf9a960d37 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -381,7 +381,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 320b194a0d38..2977ee24a3a7 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -306,7 +306,7 @@ def __getitem__(self, i):
 
         if self.center_crop:
             crop = min(img.shape[0], img.shape[1])
-            h, w, = (
+            (h, w,) = (
                 img.shape[0],
                 img.shape[1],
             )
diff --git a/scripts/convert_kakao_brain_unclip_to_diffusers.py b/scripts/convert_kakao_brain_unclip_to_diffusers.py
index ebb98eb66a8c..59bcd2c684eb 100644
--- a/scripts/convert_kakao_brain_unclip_to_diffusers.py
+++ b/scripts/convert_kakao_brain_unclip_to_diffusers.py
@@ -564,6 +564,7 @@ def super_res_unet_last_step_original_checkpoint_to_diffusers_checkpoint(model,
 
 # unet utils
 
+
 # .time_embed -> .time_embedding
 def unet_time_embeddings(checkpoint, original_unet_prefix):
     diffusers_checkpoint = {}
diff --git a/src/diffusers/models/modeling_flax_pytorch_utils.py b/src/diffusers/models/modeling_flax_pytorch_utils.py
index 7463b408ed21..e75e0419c4d4 100644
--- a/src/diffusers/models/modeling_flax_pytorch_utils.py
+++ b/src/diffusers/models/modeling_flax_pytorch_utils.py
@@ -37,6 +37,7 @@ def rename_key(key):
 # PyTorch => Flax #
 #####################
 
+
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
 def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index 0f35d004a09a..9eb2c829a8bf 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -24,7 +24,7 @@
 from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel
 
 
@@ -105,11 +105,7 @@ def __init__(
 
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
        else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
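Note on the `prepare_latents` change above: both unCLIP pipelines used to branch on `device.type == "mps"` because `torch.randn` is not reproducible on `mps`, drawing the noise on the CPU instead and moving it afterwards. The new `torch_randn` helper (added below in `src/diffusers/utils/torch_utils.py`) centralizes exactly that pattern. A minimal sketch of the underlying trick, where `target` is an illustrative local name rather than anything from the diffusers API:

```python
import torch

# Draw the noise with a CPU generator so a fixed seed produces identical
# values on every backend, then move the tensor to the execution device.
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = torch.Generator(device="cpu").manual_seed(0)

latents = torch.randn((1, 4, 64, 64), generator=generator, device="cpu").to(target)
```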
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
index 0b83407d8ccd..0dbe44b012a0 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -29,7 +29,7 @@
 from ...models import UNet2DConditionModel, UNet2DModel
 from ...pipelines import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import UnCLIPScheduler
-from ...utils import is_accelerate_available, logging
+from ...utils import is_accelerate_available, logging, torch_randn
 from .text_proj import UnCLIPTextProjModel
 
 
@@ -113,11 +113,7 @@ def __init__(
     # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
     def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
         if latents is None:
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
-            else:
-                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+            latents = torch_randn(shape, generator=generator, device=device, dtype=dtype)
         else:
             if latents.shape != shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py
index 243ca2f009b6..09ffbe2ff104 100644
--- a/src/diffusers/schedulers/scheduling_unclip.py
+++ b/src/diffusers/schedulers/scheduling_unclip.py
@@ -20,7 +20,7 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
+from ..utils import BaseOutput, torch_randn
 from .scheduling_utils import SchedulerMixin
 
 
@@ -273,15 +273,9 @@ def step(
 
         # 6. Add noise
         variance = 0
         if t > 0:
-            device = model_output.device
-            if device.type == "mps":
-                # randn does not work reproducibly on mps
-                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
-                variance_noise = variance_noise.to(device)
-            else:
-                variance_noise = torch.randn(
-                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
-                )
+            variance_noise = torch_randn(
+                model_output.shape, dtype=model_output.dtype, generator=generator, device=model_output.device
+            )
             variance = self._get_variance(
                 t,
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index a729bead4b93..67b95c0b684d 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -64,6 +64,7 @@
 from .logging import get_logger
 from .outputs import BaseOutput
 from .pil_utils import PIL_INTERPOLATION
+from .torch_utils import torch_randn
 
 
 if is_torch_available():
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
new file mode 100644
index 000000000000..8242907bea32
--- /dev/null
+++ b/src/diffusers/utils/torch_utils.py
@@ -0,0 +1,67 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch utilities: Utilities related to PyTorch
+"""
+from typing import List, Optional, Tuple, Union
+
+from . import logging
+from .import_utils import is_torch_available
+
+
+if is_torch_available():
+    import torch
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def torch_randn(
+    shape: Union[Tuple, List],
+    generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
+    device: Optional["torch.device"] = None,
+    dtype: Optional["torch.dtype"] = None,
+):
+    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
+    passing a list of generators, one can seed each batch item individually. If CPU generators are passed, the
+    tensor is always created on the CPU first and then moved to `device`.
+    """
+    # device on which tensor is created defaults to device
+    rand_device = device
+    batch_size = shape[0]
+
+    if generator is not None:
+        # a list of generators is checked against the device of its first entry
+        gen_device = generator[0].device if isinstance(generator, list) else generator.device
+        if gen_device != device and gen_device.type == "cpu":
+            rand_device = "cpu"
+            if device != "mps":
+                logger.info(
+                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
+                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
+                    f" slightly speed up this function by passing a generator that was created on the {device} device."
+                )
+        elif gen_device.type != device.type and gen_device.type == "cuda":
+            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device.type}.")
+
+    if isinstance(generator, list):
+        # draw each batch item from its own generator, then concatenate along the batch dimension
+        shape = (1,) + shape[1:]
+        latents = [
+            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) for i in range(batch_size)
+        ]
+        latents = torch.cat(latents, dim=0).to(device)
+    else:
+        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
+
+    return latents
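The helper above is exported from `diffusers.utils` (see the `__init__.py` change), so callers get the same behavior as the pipelines. A minimal usage sketch, assuming only what the new file defines; `shape`, `device`, and the generator names are local to the example:

```python
import torch

from diffusers.utils import torch_randn

shape = (2, 4, 64, 64)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One CPU generator: the noise is drawn on the CPU and then moved to
# `device`, so a fixed seed yields the same latents on any backend.
generator = torch.Generator(device="cpu").manual_seed(0)
latents = torch_randn(shape, generator=generator, device=device, dtype=torch.float32)

# A list of generators seeds each batch item individually: item i is drawn
# from generator i, and the results are concatenated along the batch axis.
generators = [torch.Generator(device="cpu").manual_seed(i) for i in range(shape[0])]
per_sample = torch_randn(shape, generator=generators, device=device, dtype=torch.float32)
```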
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
index 670082c20c24..cc67182b34d0 100644
--- a/tests/pipelines/unclip/test_unclip.py
+++ b/tests/pipelines/unclip/test_unclip.py
@@ -382,7 +382,7 @@ def test_unclip_karlo(self):
         pipeline = pipeline.to(torch_device)
         pipeline.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             "horse",
             num_images_per_prompt=1,
diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
index 87ad14146a11..daf84c19b966 100644
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ b/tests/pipelines/unclip/test_unclip_image_variation.py
@@ -480,7 +480,7 @@ def test_unclip_image_variation_karlo(self):
         pipeline.set_progress_bar_config(disable=None)
         pipeline.enable_sequential_cpu_offload()
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipeline(
             input_image,
             num_images_per_prompt=1,
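The two test changes above follow from the same refactor: because `torch_randn` draws CPU-generator noise on the CPU before moving it, pinning the test generator to `"cpu"` makes the Karlo integration outputs independent of `torch_device`. A small sketch of the property being relied on:

```python
import torch

# For a fixed seed, a CPU generator produces the same stream regardless of
# which device the resulting tensor is later moved to; a CUDA or MPS
# generator's stream is specific to that backend.
noise_a = torch.randn((2, 3), generator=torch.Generator(device="cpu").manual_seed(0))
noise_b = torch.randn((2, 3), generator=torch.Generator(device="cpu").manual_seed(0))
assert torch.equal(noise_a, noise_b)
```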
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
index 44165d1fce23..a204a155feb1 100644
--- a/utils/custom_init_isort.py
+++ b/utils/custom_init_isort.py
@@ -96,6 +96,7 @@ def _inner(x):
 
 
 def sort_objects(objects, key=None):
     "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
+    # If no key is provided, we use a noop.
     def noop(x):
         return x
@@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
     """
     Return the same `import_statement` but with objects properly sorted.
     """
+    # This inner function sorts imports between [ ].
     def _replace(match):
         imports = match.groups()[0]
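The comments added to `utils/custom_init_isort.py` document its default-key pattern. A simplified, self-contained sketch of what `sort_objects` does — an assumption drawn from the surrounding context, not the full implementation, which also handles leading underscores through its `key` argument:

```python
def sort_objects_sketch(objects, key=None):
    # Simplified stand-in for sort_objects: isort-style ordering puts
    # ALL_CAPS constants first, then CamelCase classes, then functions.
    def noop(x):
        # If no key is provided, entries are compared as-is.
        return x

    if key is None:
        key = noop
    constants = sorted((o for o in objects if key(o).isupper()), key=key)
    classes = sorted((o for o in objects if key(o)[0].isupper() and not key(o).isupper()), key=key)
    functions = sorted((o for o in objects if not key(o)[0].isupper()), key=key)
    return constants + classes + functions


# e.g. with names exported by this PR's utils/__init__.py change:
print(sort_objects_sketch(["torch_randn", "BaseOutput", "PIL_INTERPOLATION"]))
# -> ['PIL_INTERPOLATION', 'BaseOutput', 'torch_randn']
```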