diff --git a/README.md b/README.md
index c9e6c3bb1..fc92df556 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,59 @@
+# Genetic Stable Diffusion
+
+This fork of Stable Diffusion combines image generation with a genetic (evolutionary) algorithm and a graphical user interface.
+Prompts can be written in many languages (tested with French and German; most others should work too).
+It generates many images.
+It should work directly on a Mac M1.
+It should be easy to adapt to a machine with a GPU.
+Without a GPU it will be more complicated.
+If you need help, ping us in the Nevergrad users group; we'll do our best.
+
+[**Nevergrad Users**](https://www.facebook.com/groups/nevergradusers/)
+[**Doc**](https://docs.google.com/document/d/12Bz095QNuo_ojxSlGENXKL5Law75IUx5_Nm5L5guKgo/edit?usp=sharing)
+
+
+
+## Get a Hugging Face token! This is a fork of Hugging Face's Stable Diffusion.
+
+Just open the link below and copy-paste your token:
+[**Hugging face tokens**](https://huggingface.co/login?next=%2Fsettings%2Ftokens)
+
+## Install Stable Diffusion as usual, plus a few extra packages. Basically:
+
+You need Homebrew.
+On a Mac, PyTorch needs some extra setup to use the MPS backend: we recommend
+[**this page**](https://towardsdatascience.com/gpu-acceleration-comes-to-pytorch-on-m1-macs-195c399efcc1)
+
+Open a terminal, then run:
+```
+mkdir stablediffusion
+cd stablediffusion
+git clone git@github.com:teytaud/genetic-stable-diffusion.git .
+brew install wget
+conda env create -f environment.yaml
+conda activate ldm # you can change that name in the environment.yaml file...
+conda install pytorch torchvision -c pytorch
+pip install transformers diffusers invisible-watermark
+pip install -e .
+pip install pygame
+pip install joblib
+pip install pyttsx3
+pip install einops
+pip install webbrowser
+pip install pyfiglet
+pip install nevergrad
+pip install langdetect
+pip install deep-translator
+pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git
+wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth -P weights
+wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
+```
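+
+As a quick sanity check (optional, not part of the original instructions), you can verify from Python that PyTorch sees the MPS backend on your Mac:
+
+```
+import torch
+
+# Both should print True on an Apple Silicon Mac with a working MPS setup (PyTorch >= 1.12).
+print("MPS available:", torch.backends.mps.is_available())
+print("MPS built:", torch.backends.mps.is_built())
+```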
+
+## Then run `python geneticsd.py`.
+You will be asked for a prompt (just press Enter to keep the proposed hardcoded prompt), and then a window should open.
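+
+For reference, here is a minimal sketch of how a non-English prompt could be detected and translated with the `langdetect` and `deep-translator` packages installed above; the fork's actual pipeline may differ:
+
+```
+from langdetect import detect
+from deep_translator import GoogleTranslator
+
+prompt = "un chat astronaute sur la lune"  # hypothetical French prompt
+if detect(prompt) != "en":
+    # Translate to English before passing the prompt to the diffusion model.
+    prompt = GoogleTranslator(source="auto", target="en").translate(prompt)
+print(prompt)
+```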
+
+## Send feedback to [**Nevergrad Users**](https://www.facebook.com/groups/nevergradusers/)
+
# Stable Diffusion
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
diff --git a/archimulti_minisd.sh b/archimulti_minisd.sh
new file mode 100755
index 000000000..e730b58d7
--- /dev/null
+++ b/archimulti_minisd.sh
@@ -0,0 +1 @@
+echo deprecated.
diff --git a/diffusers/__init__.py b/diffusers/__init__.py
new file mode 100644
index 000000000..bf2f183c9
--- /dev/null
+++ b/diffusers/__init__.py
@@ -0,0 +1,60 @@
+from .utils import (
+ is_inflect_available,
+ is_onnx_available,
+ is_scipy_available,
+ is_transformers_available,
+ is_unidecode_available,
+)
+
+
+__version__ = "0.3.0"
+
+from .configuration_utils import ConfigMixin
+from .modeling_utils import ModelMixin
+from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from .onnx_utils import OnnxRuntimeModel
+from .optimization import (
+ get_constant_schedule,
+ get_constant_schedule_with_warmup,
+ get_cosine_schedule_with_warmup,
+ get_cosine_with_hard_restarts_schedule_with_warmup,
+ get_linear_schedule_with_warmup,
+ get_polynomial_decay_schedule_with_warmup,
+ get_scheduler,
+)
+from .pipeline_utils import DiffusionPipeline
+from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .schedulers import (
+ DDIMScheduler,
+ DDPMScheduler,
+ KarrasVeScheduler,
+ PNDMScheduler,
+ SchedulerMixin,
+ ScoreSdeVeScheduler,
+)
+from .utils import logging
+
+
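+# Optional backends are handled as soft dependencies: when a backend (scipy, transformers,
+# onnxruntime) is missing, the corresponding dummy objects are imported instead, so that using
+# the affected classes raises an informative error rather than failing at import time.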
+if is_scipy_available():
+ from .schedulers import LMSDiscreteScheduler
+else:
+ from .utils.dummy_scipy_objects import * # noqa F403
+
+from .training_utils import EMAModel
+
+
+if is_transformers_available():
+ from .pipelines import (
+ LDMTextToImagePipeline,
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ )
+else:
+ from .utils.dummy_transformers_objects import * # noqa F403
+
+
+if is_transformers_available() and is_onnx_available():
+ from .pipelines import StableDiffusionOnnxPipeline
+else:
+ from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
diff --git a/diffusers/commands/__init__.py b/diffusers/commands/__init__.py
new file mode 100644
index 000000000..902bd46ce
--- /dev/null
+++ b/diffusers/commands/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from argparse import ArgumentParser
+
+
+class BaseDiffusersCLICommand(ABC):
+ @staticmethod
+ @abstractmethod
+ def register_subcommand(parser: ArgumentParser):
+ raise NotImplementedError()
+
+ @abstractmethod
+ def run(self):
+ raise NotImplementedError()
diff --git a/diffusers/commands/diffusers_cli.py b/diffusers/commands/diffusers_cli.py
new file mode 100644
index 000000000..30084e55b
--- /dev/null
+++ b/diffusers/commands/diffusers_cli.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from argparse import ArgumentParser
+
+from .env import EnvironmentCommand
+
+
+def main():
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
+
+ # Register commands
+ EnvironmentCommand.register_subcommand(commands_parser)
+
+ # Let's go
+ args = parser.parse_args()
+
+ if not hasattr(args, "func"):
+ parser.print_help()
+ exit(1)
+
+ # Run
+ service = args.func(args)
+ service.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/diffusers/commands/env.py b/diffusers/commands/env.py
new file mode 100644
index 000000000..81a878bff
--- /dev/null
+++ b/diffusers/commands/env.py
@@ -0,0 +1,70 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+from argparse import ArgumentParser
+
+import huggingface_hub
+
+from .. import __version__ as version
+from ..utils import is_torch_available, is_transformers_available
+from . import BaseDiffusersCLICommand
+
+
+def info_command_factory(_):
+ return EnvironmentCommand()
+
+
+class EnvironmentCommand(BaseDiffusersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ download_parser = parser.add_parser("env")
+ download_parser.set_defaults(func=info_command_factory)
+
+ def run(self):
+ hub_version = huggingface_hub.__version__
+
+ pt_version = "not installed"
+ pt_cuda_available = "NA"
+ if is_torch_available():
+ import torch
+
+ pt_version = torch.__version__
+ pt_cuda_available = torch.cuda.is_available()
+
+ transformers_version = "not installed"
+ if is_transformers_available():
+ import transformers
+
+ transformers_version = transformers.__version__
+
+ info = {
+ "`diffusers` version": version,
+ "Platform": platform.platform(),
+ "Python version": platform.python_version(),
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
+ "Huggingface_hub version": hub_version,
+ "Transformers version": transformers_version,
+ "Using GPU in script?": "",
+ "Using distributed or parallel set-up in script?": "",
+ }
+
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
+ print(self.format_dict(info))
+
+ return info
+
+ @staticmethod
+ def format_dict(d):
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
diff --git a/diffusers/configuration_utils.py b/diffusers/configuration_utils.py
new file mode 100644
index 000000000..fbe75f3f1
--- /dev/null
+++ b/diffusers/configuration_utils.py
@@ -0,0 +1,403 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" ConfigMixinuration base class and utilities."""
+import functools
+import inspect
+import json
+import os
+import re
+from collections import OrderedDict
+from typing import Any, Dict, Tuple, Union
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from . import __version__
+from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class ConfigMixin:
+ r"""
+ Base class for all configuration classes. Stores all configuration parameters under `self.config`. Also handles all
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
+ - [`~ConfigMixin.from_config`]
+ - [`~ConfigMixin.save_config`]
+
+ Class attributes:
+ - **config_name** (`str`) -- A filename under which the config should be stored when calling
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
+ overridden by parent class).
+ """
+ config_name = None
+ ignore_for_config = []
+
+ def register_to_config(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+ kwargs["_class_name"] = self.__class__.__name__
+ kwargs["_diffusers_version"] = __version__
+
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ if not hasattr(self, "_internal_dict"):
+ internal_dict = kwargs
+ else:
+ previous_dict = dict(self._internal_dict)
+ internal_dict = {**self._internal_dict, **kwargs}
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
+
+ self._internal_dict = FrozenDict(internal_dict)
+
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~ConfigMixin.from_config`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"ConfigMixinuration saved in {output_config_file}")
+
+ @classmethod
+ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
+ r"""
+ Instantiate a Python class from a pre-defined JSON-file.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
+ organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+ checkpoint with 3 labels).
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+ Please refer to the mirror site for more information.
+
+
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+
+
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+ use this method in a firewalled environment.
+
+
+
+ """
+ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
+
+ init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
+
+ model = cls(**init_dict)
+
+ if return_unused_kwargs:
+ return model, unused_kwargs
+ else:
+ return model
+
+ @classmethod
+ def get_config_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "config"}
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ ):
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
+ " login` and pass `use_auth_token=True`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ return config_dict
+
+ @classmethod
+ def extract_init_dict(cls, config_dict, **kwargs):
+ expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
+ expected_keys.remove("self")
+ # remove general kwargs if present in dict
+ if "kwargs" in expected_keys:
+ expected_keys.remove("kwargs")
+ # remove keys to be ignored
+ if len(cls.ignore_for_config) > 0:
+ expected_keys = expected_keys - set(cls.ignore_for_config)
+ init_dict = {}
+ for key in expected_keys:
+ if key in kwargs:
+ # overwrite key
+ init_dict[key] = kwargs.pop(key)
+ elif key in config_dict:
+ # use value from config dict
+ init_dict[key] = config_dict.pop(key)
+
+ # Note: `dict.update` returns None, so collect the leftover entries explicitly.
+ unused_kwargs = {**config_dict, **kwargs}
+
+ passed_keys = set(init_dict.keys())
+ if len(expected_keys - passed_keys) > 0:
+ logger.warning(
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
+ )
+
+ return init_dict, unused_kwargs
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @property
+ def config(self) -> Dict[str, Any]:
+ return self._internal_dict
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this configuration instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+
+class FrozenDict(OrderedDict):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ for key, value in self.items():
+ setattr(self, key, value)
+
+ self.__frozen = True
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __setattr__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if hasattr(self, "__frozen") and self.__frozen:
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
+ super().__setitem__(name, value)
+
+
+def register_to_config(init):
+ r"""
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
+ automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
+
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
+ """
+
+ @functools.wraps(init)
+ def inner_init(self, *args, **kwargs):
+ # Ignore private kwargs in the init.
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
+ init(self, *args, **init_kwargs)
+ if not isinstance(self, ConfigMixin):
+ raise RuntimeError(
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+ "not inherit from `ConfigMixin`."
+ )
+
+ ignore = getattr(self, "ignore_for_config", [])
+ # Get positional arguments aligned with kwargs
+ new_kwargs = {}
+ signature = inspect.signature(init)
+ parameters = {
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
+ }
+ for arg, name in zip(args, parameters.keys()):
+ new_kwargs[name] = arg
+
+ # Then add all kwargs
+ new_kwargs.update(
+ {
+ k: init_kwargs.get(k, default)
+ for k, default in parameters.items()
+ if k not in ignore and k not in new_kwargs
+ }
+ )
+ getattr(self, "register_to_config")(**new_kwargs)
+
+ return inner_init
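+
+
+# Illustrative sketch (not part of the library): classes inheriting from `ConfigMixin`
+# typically define `config_name` and decorate their `__init__` with `@register_to_config`
+# so that constructor arguments end up in `self.config`, e.g.:
+#
+#     class MyScheduler(ConfigMixin):
+#         config_name = "scheduler_config.json"
+#
+#         @register_to_config
+#         def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
+#             pass
+#
+#     MyScheduler(num_train_timesteps=500).config  # FrozenDict holding the registered values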
diff --git a/diffusers/dependency_versions_check.py b/diffusers/dependency_versions_check.py
new file mode 100644
index 000000000..bbf863222
--- /dev/null
+++ b/diffusers/dependency_versions_check.py
@@ -0,0 +1,47 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+from .dependency_versions_table import deps
+from .utils.versions import require_version, require_version_core
+
+
+# define which module versions we always want to check at run time
+# (usually the ones defined in `install_requires` in setup.py)
+#
+# order specific notes:
+# - tqdm must be checked before tokenizers
+
+pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
+if sys.version_info < (3, 7):
+ pkgs_to_check_at_runtime.append("dataclasses")
+if sys.version_info < (3, 8):
+ pkgs_to_check_at_runtime.append("importlib_metadata")
+
+for pkg in pkgs_to_check_at_runtime:
+ if pkg in deps:
+ if pkg == "tokenizers":
+ # must be loaded here, or else tqdm check may fail
+ from .utils import is_tokenizers_available
+
+ if not is_tokenizers_available():
+ continue # not required, check version only if installed
+
+ require_version_core(deps[pkg])
+ else:
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
+
+
+def dep_version_check(pkg, hint=None):
+ require_version(deps[pkg], hint)
diff --git a/diffusers/dependency_versions_table.py b/diffusers/dependency_versions_table.py
new file mode 100644
index 000000000..74c5331e5
--- /dev/null
+++ b/diffusers/dependency_versions_table.py
@@ -0,0 +1,26 @@
+# THIS FILE HAS BEEN AUTOGENERATED. To update:
+# 1. modify the `_deps` dict in setup.py
+# 2. run `make deps_table_update`
+deps = {
+ "Pillow": "Pillow",
+ "accelerate": "accelerate>=0.11.0",
+ "black": "black==22.3",
+ "datasets": "datasets",
+ "filelock": "filelock",
+ "flake8": "flake8>=3.8.3",
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
+ "huggingface-hub": "huggingface-hub>=0.8.1",
+ "importlib_metadata": "importlib_metadata",
+ "isort": "isort>=5.5.4",
+ "modelcards": "modelcards==0.1.4",
+ "numpy": "numpy",
+ "pytest": "pytest",
+ "pytest-timeout": "pytest-timeout",
+ "pytest-xdist": "pytest-xdist",
+ "scipy": "scipy",
+ "regex": "regex!=2019.12.17",
+ "requests": "requests",
+ "tensorboard": "tensorboard",
+ "torch": "torch>=1.4",
+ "transformers": "transformers>=4.21.0",
+}
diff --git a/diffusers/dynamic_modules_utils.py b/diffusers/dynamic_modules_utils.py
new file mode 100644
index 000000000..0ebf916e7
--- /dev/null
+++ b/diffusers/dynamic_modules_utils.py
@@ -0,0 +1,335 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities to dynamically load objects from the Hub."""
+
+import importlib
+import os
+import re
+import shutil
+import sys
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from huggingface_hub import cached_download
+
+from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def init_hf_modules():
+ """
+ Creates the cache directory for modules with an init, and adds it to the Python path.
+ """
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
+ if HF_MODULES_CACHE in sys.path:
+ return
+
+ sys.path.append(HF_MODULES_CACHE)
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def create_dynamic_module(name: Union[str, os.PathLike]):
+ """
+ Creates a dynamic module in the cache directory for modules.
+ """
+ init_hf_modules()
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
+ # If the parent module does not exist yet, recursively create it.
+ if not dynamic_module_path.parent.exists():
+ create_dynamic_module(dynamic_module_path.parent)
+ os.makedirs(dynamic_module_path, exist_ok=True)
+ init_path = dynamic_module_path / "__init__.py"
+ if not init_path.exists():
+ init_path.touch()
+
+
+def get_relative_imports(module_file):
+ """
+ Get the list of modules that are relatively imported in a module file.
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ with open(module_file, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import .xxx`
+ relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from .xxx import yyy`
+ relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
+ # Unique-ify
+ return list(set(relative_imports))
+
+
+def get_relative_import_files(module_file):
+ """
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
+ imports (if a imports b and b imports c, it will return module files for b and c).
+
+ Args:
+ module_file (`str` or `os.PathLike`): The module file to inspect.
+ """
+ no_change = False
+ files_to_check = [module_file]
+ all_relative_imports = []
+
+ # Let's recurse through all relative imports
+ while not no_change:
+ new_imports = []
+ for f in files_to_check:
+ new_imports.extend(get_relative_imports(f))
+
+ module_path = Path(module_file).parent
+ new_import_files = [str(module_path / m) for m in new_imports]
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
+ files_to_check = [f"{f}.py" for f in new_import_files]
+
+ no_change = len(new_import_files) == 0
+ all_relative_imports.extend(files_to_check)
+
+ return all_relative_imports
+
+
+def check_imports(filename):
+ """
+ Check if the current Python environment contains all the libraries that are imported in a file.
+ """
+ with open(filename, "r", encoding="utf-8") as f:
+ content = f.read()
+
+ # Imports of the form `import xxx`
+ imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
+ # Imports of the form `from xxx import yyy`
+ imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
+ # Only keep the top-level module
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
+
+ # Unique-ify and test we got them all
+ imports = list(set(imports))
+ missing_packages = []
+ for imp in imports:
+ try:
+ importlib.import_module(imp)
+ except ImportError:
+ missing_packages.append(imp)
+
+ if len(missing_packages) > 0:
+ raise ImportError(
+ "This modeling file requires the following packages that were not found in your environment: "
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
+ )
+
+ return get_relative_imports(filename)
+
+
+def get_class_in_module(class_name, module_path):
+ """
+ Import a module on the cache directory for modules and extract a class from it.
+ """
+ module_path = module_path.replace(os.path.sep, ".")
+ module = importlib.import_module(module_path)
+ return getattr(module, class_name)
+
+
+def get_cached_module_file(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+):
+ """
+ Prepares and downloads a module from a local folder or a distant repo and returns its path inside the cached
+ Transformers module.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+ Returns:
+ `str`: The path to the module inside the cache.
+ """
+ # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
+ submodule = "local"
+
+ if os.path.isfile(module_file_or_url):
+ resolved_module_file = module_file_or_url
+ else:
+ try:
+ # Load from URL or cache if already cached
+ resolved_module_file = cached_download(
+ module_file_or_url,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ )
+
+ except EnvironmentError:
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
+ raise
+
+ # Check we have all the requirements in our environment
+ modules_needed = check_imports(resolved_module_file)
+
+ # Now we move the module inside our cached dynamic modules.
+ full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
+ create_dynamic_module(full_submodule)
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
+ # that hash, to only copy when there is a modification but it seems overkill for now).
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
+ shutil.copy(resolved_module_file, submodule_path / module_file)
+ for module_needed in modules_needed:
+ module_needed = f"{module_needed}.py"
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
+ return os.path.join(full_submodule, module_file)
+
+
+def get_class_from_dynamic_module(
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ module_file: str,
+ class_name: str,
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ resume_download: bool = False,
+ proxies: Optional[Dict[str, str]] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ local_files_only: bool = False,
+ **kwargs,
+):
+ """
+ Extracts a class from a module file, present in the local folder or repository of a model.
+
+
+
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
+ therefore only be called on trusted repos.
+
+
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
+ - a path to a *directory* containing a configuration file saved using the
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+
+ module_file (`str`):
+ The name of the module file containing the class to look for.
+ class_name (`str`):
+ The name of the class to import in the module.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+ cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
+ exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Attempts to resume the download if such a file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+ use_auth_token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `transformers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+ identifier allowed by git.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ If `True`, will only try to load the tokenizer configuration from local files.
+
+
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+
+
+ Returns:
+ `type`: The class, dynamically imported from the module.
+
+ Examples:
+
+ ```python
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
+ # module.
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
+ ```"""
+ # And lastly we get the class inside our newly created module
+ final_module = get_cached_module_file(
+ pretrained_model_name_or_path,
+ module_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ local_files_only=local_files_only,
+ )
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
diff --git a/diffusers/hub_utils.py b/diffusers/hub_utils.py
new file mode 100644
index 000000000..c07329e36
--- /dev/null
+++ b/diffusers/hub_utils.py
@@ -0,0 +1,197 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional
+
+from huggingface_hub import HfFolder, Repository, whoami
+
+from .pipeline_utils import DiffusionPipeline
+from .utils import is_modelcards_available, logging
+
+
+if is_modelcards_available():
+ from modelcards import CardData, ModelCard
+
+
+logger = logging.get_logger(__name__)
+
+
+MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+def init_git_repo(args, at_init: bool = False):
+ """
+ Initializes a git repo in `args.hub_model_id`.
+
+ Args:
+ at_init (`bool`, *optional*, defaults to `False`):
+ Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
+ and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
+ """
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
+ use_auth_token = True if hub_token is None else hub_token
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+ repo_name = Path(args.output_dir).absolute().name
+ else:
+ repo_name = args.hub_model_id
+ if "/" not in repo_name:
+ repo_name = get_full_repo_name(repo_name, token=hub_token)
+
+ try:
+ repo = Repository(
+ args.output_dir,
+ clone_from=repo_name,
+ use_auth_token=use_auth_token,
+ private=args.hub_private_repo,
+ )
+ except EnvironmentError:
+ if args.overwrite_output_dir and at_init:
+ # Try again after wiping output_dir
+ shutil.rmtree(args.output_dir)
+ repo = Repository(
+ args.output_dir,
+ clone_from=repo_name,
+ use_auth_token=use_auth_token,
+ )
+ else:
+ raise
+
+ repo.git_pull()
+
+ # By default, ignore the checkpoint folders
+ if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
+ with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+ writer.writelines(["checkpoint-*/"])
+
+ return repo
+
+
+def push_to_hub(
+ args,
+ pipeline: DiffusionPipeline,
+ repo: Repository,
+ commit_message: Optional[str] = "End of training",
+ blocking: bool = True,
+ **kwargs,
+) -> str:
+ """
+ Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
+
+ Parameters:
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
+ Message to commit while pushing.
+ blocking (`bool`, *optional*, defaults to `True`):
+ Whether the function should return only when the `git push` has finished.
+ kwargs:
+ Additional keyword arguments passed along to [`create_model_card`].
+ Returns:
+ The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
+ commit and an object to track the progress of the commit if `blocking=True`
+ """
+
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
+ model_name = Path(args.output_dir).name
+ else:
+ model_name = args.hub_model_id.split("/")[-1]
+
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving pipeline checkpoint to {output_dir}")
+ pipeline.save_pretrained(output_dir)
+
+ # Only push from one node.
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+
+ # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+ if (
+ blocking
+ and len(repo.command_queue) > 0
+ and repo.command_queue[-1] is not None
+ and not repo.command_queue[-1].is_done
+ ):
+ repo.command_queue[-1]._process.kill()
+
+ git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
+ # push separately the model card to be independent from the rest of the model
+ create_model_card(args, model_name=model_name)
+ try:
+ repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
+ except EnvironmentError as exc:
+ logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
+
+ return git_head_commit_url
+
+
+def create_model_card(args, model_name):
+ if not is_modelcards_available():
+ raise ValueError(
+ "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
+ " install the package with `pip install modelcards`."
+ )
+
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
+ return
+
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
+ repo_name = get_full_repo_name(model_name, token=hub_token)
+
+ model_card = ModelCard.from_template(
+ card_data=CardData( # Card metadata object that will be converted to YAML block
+ language="en",
+ license="apache-2.0",
+ library_name="diffusers",
+ tags=[],
+ datasets=args.dataset_name,
+ metrics=[],
+ ),
+ template_path=MODEL_CARD_TEMPLATE_PATH,
+ model_name=model_name,
+ repo_name=repo_name,
+ dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
+ learning_rate=args.learning_rate,
+ train_batch_size=args.train_batch_size,
+ eval_batch_size=args.eval_batch_size,
+ gradient_accumulation_steps=args.gradient_accumulation_steps
+ if hasattr(args, "gradient_accumulation_steps")
+ else None,
+ adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
+ adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
+ adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
+ adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
+ lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
+ lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
+ ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
+ ema_power=args.ema_power if hasattr(args, "ema_power") else None,
+ ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
+ mixed_precision=args.mixed_precision,
+ )
+
+ card_path = os.path.join(args.output_dir, "README.md")
+ model_card.save(card_path)
diff --git a/diffusers/modeling_utils.py b/diffusers/modeling_utils.py
new file mode 100644
index 000000000..fb613614a
--- /dev/null
+++ b/diffusers/modeling_utils.py
@@ -0,0 +1,542 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, device
+
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
+from requests import HTTPError
+
+from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
+
+
+WEIGHTS_NAME = "diffusion_pytorch_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_parameter_device(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).device
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: torch.nn.Module):
+ try:
+ return next(parameter.parameters()).dtype
+ except StopIteration:
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
+
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
+ """
+ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
+ """
+ try:
+ return torch.load(checkpoint_file, map_location="cpu")
+ except Exception as e:
+ try:
+ with open(checkpoint_file) as f:
+ if f.read().startswith("version"):
+ raise OSError(
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+ "you cloned."
+ )
+ else:
+ raise ValueError(
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+ "model. Make sure you have saved the model properly."
+ ) from e
+ except (UnicodeDecodeError, ValueError):
+ raise OSError(
+ f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
+ f"at '{checkpoint_file}'. "
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+ )
+
+
+def _load_state_dict_into_model(model_to_load, state_dict):
+ # Convert old format to new format if needed from a PyTorch state_dict
+ # copy state_dict so _load_from_state_dict can modify it
+ state_dict = state_dict.copy()
+ error_msgs = []
+
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+ # so we need to apply the function recursively.
+ def load(module: torch.nn.Module, prefix=""):
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
+ module._load_from_state_dict(*args)
+
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + ".")
+
+ load(model_to_load)
+
+ return error_msgs
+
+
+class ModelMixin(torch.nn.Module):
+ r"""
+ Base class for all models.
+
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
+ and saving models.
+
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
+ [`~modeling_utils.ModelMixin.save_pretrained`].
+ """
+ config_name = CONFIG_NAME
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+
+ def __init__(self):
+ super().__init__()
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ save_function: Callable = torch.save,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~modeling_utils.ModelMixin.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful when in distributed training like
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+ the main process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+ needs to replace `torch.save` with another method.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ model_to_save = self
+
+ # Attach architecture to the config
+ # Save the config
+ if is_main_process:
+ model_to_save.save_config(save_directory)
+
+ # Save the model
+ state_dict = model_to_save.state_dict()
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
+ os.remove(full_filename)
+
+ # Save the model
+ save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
+
+ logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+ the model, you should first set it back in training mode with `model.train()`.
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
+ `./my_model_directory/`.
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info (`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. Since models and other artifacts on huggingface.co are stored in a
+ git-based system, `revision` can be any identifier allowed by git, such as a branch name, a tag name,
+ or a commit id.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety of
+ the mirror. Please refer to the mirror site for more information.
+
+ Passing `use_auth_token=True` is required when you want to use a private model.
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ output_loading_info = kwargs.pop("output_loading_info", False)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ subfolder = kwargs.pop("subfolder", None)
+
+ user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+
+ # Load config if we don't provide a configuration
+ config_path = pretrained_model_name_or_path
+ model, unused_kwargs = cls.from_config(
+ config_path,
+ cache_dir=cache_dir,
+ return_unused_kwargs=True,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ subfolder=subfolder,
+ **kwargs,
+ )
+
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+ raise ValueError(
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+ )
+ elif torch_dtype is not None:
+ model = model.to(torch_dtype)
+
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ # Load model
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+ if os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+ # Load from a PyTorch checkpoint
+ model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+ elif subfolder is not None and os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ ):
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ model_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=WEIGHTS_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ user_agent=user_agent,
+ subfolder=subfolder,
+ revision=revision,
+ )
+
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
+ "login` and pass `use_auth_token=True`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
+ "this model name. Check the model page at "
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
+ )
+ except HTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a file named {WEIGHTS_NAME} or"
+ " \nCheckout your internet connection or see how to run the library in"
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a file named {WEIGHTS_NAME}"
+ )
+
+ # load the weights and copy them into the freshly initialized model
+ state_dict = load_state_dict(model_file)
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+ model,
+ state_dict,
+ model_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ )
+
+ # Set model in evaluation mode to deactivate DropOut modules by default
+ model.eval()
+
+ if output_loading_info:
+ loading_info = {
+ "missing_keys": missing_keys,
+ "unexpected_keys": unexpected_keys,
+ "mismatched_keys": mismatched_keys,
+ "error_msgs": error_msgs,
+ }
+ return model, loading_info
+
+ return model
+
+ @classmethod
+ def _load_pretrained_model(
+ cls,
+ model,
+ state_dict,
+ resolved_archive_file,
+ pretrained_model_name_or_path,
+ ignore_mismatched_sizes=False,
+ ):
+ # Retrieve missing & unexpected_keys
+ model_state_dict = model.state_dict()
+ loaded_keys = list(state_dict.keys())
+
+ expected_keys = list(model_state_dict.keys())
+
+ original_loaded_keys = loaded_keys
+
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+ # Make sure we are able to load base models as well as derived models (with heads)
+ model_to_load = model
+
+ def _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ loaded_keys,
+ ignore_mismatched_sizes,
+ ):
+ mismatched_keys = []
+ if ignore_mismatched_sizes:
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+
+ if (
+ model_key in model_state_dict
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+ ):
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+ if state_dict is not None:
+ # Whole checkpoint
+ mismatched_keys = _find_mismatched_keys(
+ state_dict,
+ model_state_dict,
+ original_loaded_keys,
+ ignore_mismatched_sizes,
+ )
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
+
+ if len(error_msgs) > 0:
+ error_msg = "\n\t".join(error_msgs)
+ if "size mismatch" in error_msg:
+ error_msg += (
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+ )
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
+ " identical (initializing a BertForSequenceClassification model from a"
+ " BertForSequenceClassification model)."
+ )
+ else:
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+ if len(missing_keys) > 0:
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+ )
+ elif len(mismatched_keys) == 0:
+ logger.info(
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
+ " without further training."
+ )
+ if len(mismatched_keys) > 0:
+ mismatched_warning = "\n".join(
+ [
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+ for key, shape1, shape2 in mismatched_keys
+ ]
+ )
+ logger.warning(
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
+ " able to use it for predictions and inference."
+ )
+
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
+
+ @property
+ def device(self) -> device:
+ """
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+ device).
+ """
+ return get_parameter_device(self)
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+ """
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+ Args:
+ only_trainable (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of trainable parameters.
+
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
+ Whether or not to return only the number of non-embedding parameters.
+
+ Returns:
+ `int`: The number of parameters.
+ """
+
+ if exclude_embeddings:
+ embedding_param_names = [
+ f"{name}.weight"
+ for name, module_type in self.named_modules()
+ if isinstance(module_type, torch.nn.Embedding)
+ ]
+ non_embedding_parameters = [
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+ ]
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
+ else:
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+
+
+def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
+ """
+ Recursively unwraps a model from potential containers (as used in distributed training).
+
+ Args:
+ model (`torch.nn.Module`): The model to unwrap.
+ """
+ # since there could be multiple levels of wrapping, unwrap recursively
+ if hasattr(model, "module"):
+ return unwrap_model(model.module)
+ else:
+ return model
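+
+
+# Hedged usage sketch (not part of the original module): any concrete ModelMixin subclass
+# (e.g. UNet2DModel from this package) gains the save/load round trip documented above.
+# The path and the constructor arguments below are arbitrary, illustrative assumptions.
+#
+#     >>> from diffusers import UNet2DModel
+#     >>> model = UNet2DModel(
+#     ...     sample_size=32,
+#     ...     block_out_channels=(32, 64),
+#     ...     down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+#     ...     up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+#     ... )
+#     >>> model.save_pretrained("./my_unet")            # writes the config plus the WEIGHTS_NAME file
+#     >>> reloaded = UNet2DModel.from_pretrained("./my_unet")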
diff --git a/diffusers/models/__init__.py b/diffusers/models/__init__.py
new file mode 100644
index 000000000..e0ac5c8d5
--- /dev/null
+++ b/diffusers/models/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .unet_2d import UNet2DModel
+from .unet_2d_condition import UNet2DConditionModel
+from .vae import AutoencoderKL, VQModel
diff --git a/diffusers/models/attention.py b/diffusers/models/attention.py
new file mode 100644
index 000000000..de9c92691
--- /dev/null
+++ b/diffusers/models/attention.py
@@ -0,0 +1,333 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+ to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ Uses three q, k, v linear layers to compute attention.
+
+ Parameters:
+ channels (:obj:`int`): The number of channels in the input and output.
+ num_head_channels (:obj:`int`, *optional*):
+ The number of channels in each head. If None, then `num_heads` = 1.
+ num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+ rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+ eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: Optional[int] = None,
+ num_groups: int = 32,
+ rescale_output_factor: float = 1.0,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+ self.num_head_size = num_head_channels
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
+
+ # define q,k,v as linear layers
+ self.query = nn.Linear(channels, channels)
+ self.key = nn.Linear(channels, channels)
+ self.value = nn.Linear(channels, channels)
+
+ self.rescale_output_factor = rescale_output_factor
+ self.proj_attn = nn.Linear(channels, channels, 1)
+
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+ return new_projection
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel, height, width = hidden_states.shape
+
+ # norm
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+ # proj to q, k, v
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ # transpose
+ query_states = self.transpose_for_scores(query_proj)
+ key_states = self.transpose_for_scores(key_proj)
+ value_states = self.transpose_for_scores(value_proj)
+
+ # get scores
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
+
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+
+ # compute attention output
+ hidden_states = torch.matmul(attention_probs, value_states)
+
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+ hidden_states = hidden_states.view(new_hidden_states_shape)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+ # res connect and rescale
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
+ return hidden_states
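+
+# Hedged shape check (not in the original file): the block is shape-preserving; the
+# channel and spatial sizes below are arbitrary assumptions.
+#
+#     >>> attn = AttentionBlock(channels=64, num_head_channels=32)
+#     >>> attn(torch.randn(2, 64, 16, 16)).shape
+#     torch.Size([2, 64, 16, 16])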
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
+ standard transformer action. Finally, reshape to image.
+
+ Parameters:
+ in_channels (:obj:`int`): The number of channels in the input and output.
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+ d_head (:obj:`int`): The number of channels in each head.
+ depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ n_heads: int,
+ d_head: int,
+ depth: int = 1,
+ dropout: float = 0.0,
+ context_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.n_heads = n_heads
+ self.d_head = d_head
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
+ for d in range(depth)
+ ]
+ )
+
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def _set_attention_slice(self, slice_size):
+ for block in self.transformer_blocks:
+ block._set_attention_slice(slice_size)
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+ x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+ x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
+ x = self.proj_out(x)
+ return x + x_in
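+
+# Hedged example (not in the original file): SpatialTransformer flattens the spatial
+# positions, runs the transformer blocks, and returns a tensor of the input's shape.
+# Sizes are arbitrary assumptions; note that inner_dim (n_heads * d_head) is chosen
+# equal to in_channels here, matching how the reshape in `forward` is written.
+#
+#     >>> st = SpatialTransformer(in_channels=32, n_heads=4, d_head=8)
+#     >>> st(torch.randn(1, 32, 8, 8)).shape
+#     torch.Size([1, 32, 8, 8])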
+
+
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (:obj:`int`): The number of channels in the input and output.
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
+ d_head (:obj:`int`): The number of channels in each head.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
+ gated_ff (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use a gated feed-forward network.
+ checkpoint (:obj:`bool`, *optional*, defaults to :obj:`True`): Whether to use checkpointing.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ n_heads: int,
+ d_head: int,
+ dropout=0.0,
+ context_dim: Optional[int] = None,
+ gated_ff: bool = True,
+ checkpoint: bool = True,
+ ):
+ super().__init__()
+ self.attn1 = CrossAttention(
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = CrossAttention(
+ query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def _set_attention_slice(self, slice_size):
+ self.attn1._slice_size = slice_size
+ self.attn2._slice_size = slice_size
+
+ def forward(self, x, context=None):
+ x = x.contiguous() if x.device.type == "mps" else x
+ x = self.attn1(self.norm1(x)) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (:obj:`int`): The number of channels in the query.
+ context_dim (:obj:`int`, *optional*):
+ The number of channels in the context. If not given, defaults to `query_dim`.
+ heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ def __init__(
+ self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = context_dim if context_dim is not None else query_dim
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self._slice_size = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def forward(self, x, context=None, mask=None):
+ batch_size, sequence_length, dim = x.shape
+
+ q = self.to_q(x)
+ context = context if context is not None else x
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q = self.reshape_heads_to_batch_dim(q)
+ k = self.reshape_heads_to_batch_dim(k)
+ v = self.reshape_heads_to_batch_dim(v)
+
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
+
+ # attention, what we cannot get enough of
+ hidden_states = self._attention(q, k, v, sequence_length, dim)
+
+ return self.to_out(hidden_states)
+
+ def _attention(self, query, key, value, sequence_length, dim):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+ attn_slice = (
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
+ )
+ attn_slice = attn_slice.softmax(dim=-1)
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
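+
+# Hedged sketch of sliced attention (not in the original file): setting `_slice_size`
+# (normally done through `_set_attention_slice` on the enclosing blocks) trades speed
+# for memory by computing the attention scores in chunks along the batch*heads axis.
+# Sizes are arbitrary assumptions.
+#
+#     >>> attn = CrossAttention(query_dim=64, heads=4, dim_head=16)
+#     >>> attn._slice_size = 2              # batch 2 * 4 heads = 8 rows, processed 2 at a time
+#     >>> x = torch.randn(2, 128, 64)       # (batch, tokens, query_dim)
+#     >>> attn(x).shape                     # context=None -> self-attention
+#     torch.Size([2, 128, 64])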
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (:obj:`int`): The number of channels in the input.
+ dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ """
+
+ def __init__(
+ self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+ project_in = GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+# feedforward
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (:obj:`int`): The number of channels in the input.
+ dim_out (:obj:`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
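+
+
+# Hedged example (not in the original file): GEGLU projects to twice the output width,
+# then gates one half with the GELU of the other. Dimensions are arbitrary assumptions.
+#
+#     >>> geglu = GEGLU(dim_in=32, dim_out=64)
+#     >>> geglu(torch.randn(2, 10, 32)).shape
+#     torch.Size([2, 10, 64])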
diff --git a/diffusers/models/embeddings.py b/diffusers/models/embeddings.py
new file mode 100644
index 000000000..86ac074c1
--- /dev/null
+++ b/diffusers/models/embeddings.py
@@ -0,0 +1,115 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+
+def get_timestep_embedding(
+ timesteps: torch.Tensor,
+ embedding_dim: int,
+ flip_sin_to_cos: bool = False,
+ downscale_freq_shift: float = 1,
+ scale: float = 1,
+ max_period: int = 10000,
+):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+ exponent = exponent / (half_dim - downscale_freq_shift)
+
+ emb = torch.exp(exponent).to(device=timesteps.device)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ # scale embeddings
+ emb = scale * emb
+
+ # concat sine and cosine embeddings
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
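+
+# Hedged example (not part of the original file): with the defaults, the first half of
+# the output holds the sine terms and the second half the cosine terms.
+#
+#     >>> emb = get_timestep_embedding(torch.arange(4), embedding_dim=32)
+#     >>> emb.shape
+#     torch.Size([4, 32])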
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
+ super().__init__()
+
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
+ self.act = None
+ if act_fn == "silu":
+ self.act = nn.SiLU()
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
+
+ def forward(self, sample):
+ sample = self.linear_1(sample)
+
+ if self.act is not None:
+ sample = self.act(sample)
+
+ sample = self.linear_2(sample)
+ return sample
+
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ )
+ return t_emb
+
+
+class GaussianFourierProjection(nn.Module):
+ """Gaussian Fourier embeddings for noise levels."""
+
+ def __init__(self, embedding_size: int = 256, scale: float = 1.0):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ # to delete later
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
+
+ self.weight = self.W
+
+ def forward(self, x):
+ x = torch.log(x)
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
+ return out
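+
+
+# Hedged usage sketch (not in the original file): GaussianFourierProjection expects
+# strictly positive inputs because of the `torch.log` above; sizes are arbitrary.
+#
+#     >>> proj = GaussianFourierProjection(embedding_size=16)
+#     >>> sigmas = torch.tensor([0.1, 1.0, 10.0])
+#     >>> proj(sigmas).shape
+#     torch.Size([3, 32])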
diff --git a/diffusers/models/resnet.py b/diffusers/models/resnet.py
new file mode 100644
index 000000000..27fae24f7
--- /dev/null
+++ b/diffusers/models/resnet.py
@@ -0,0 +1,483 @@
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Upsample2D(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
+ elif use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(x)
+
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if self.use_conv:
+ if self.name == "conv":
+ x = self.conv(x)
+ else:
+ x = self.Conv2d_0(x)
+
+ return x
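+
+# Hedged example (not in the original file): with `use_conv=True` the layer doubles the
+# spatial resolution via nearest-neighbor interpolation and then applies a 3x3 conv.
+# Sizes are arbitrary assumptions.
+#
+#     >>> up = Upsample2D(channels=16, use_conv=True)
+#     >>> up(torch.randn(1, 16, 8, 8)).shape
+#     torch.Size([1, 16, 16, 16])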
+
+
+class Downsample2D(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
+
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ pad = (0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+
+ assert x.shape[1] == self.channels
+ x = self.conv(x)
+
+ return x
+
+
+class FirUpsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.use_conv = use_conv
+ self.fir_kernel = fir_kernel
+ self.out_channels = out_channels
+
+ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `upsample_2d()` followed by `Conv2d()`.
+
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+ order.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+ performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+ corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
+ `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+
+ # Setup filter kernel.
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = np.asarray(kernel, dtype=np.float32)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+
+ if self.use_conv:
+ convH = weight.shape[2]
+ convW = weight.shape[3]
+ inC = weight.shape[1]
+
+ p = (kernel.shape[0] - factor) - (convW - 1)
+
+ stride = (factor, factor)
+ # Determine data dimensions.
+ stride = [1, 1, factor, factor]
+ output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
+ output_padding = (
+ output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
+ output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ inC = weight.shape[1]
+ num_groups = x.shape[1] // inC
+
+ # Transpose weights.
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
+ weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
+
+ x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
+
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
+ else:
+ p = kernel.shape[0] - factor
+ x = upfirdn2d_native(
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+ )
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
+
+ return height
+
+
+class FirDownsample2D(nn.Module):
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.fir_kernel = fir_kernel
+ self.use_conv = use_conv
+ self.out_channels = out_channels
+
+ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
+ """Fused `Conv2d()` followed by `downsample_2d()`.
+
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+ order.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+ performed by `inChannels = x.shape[0] // numGroups`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+ corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+ datatype as `x`.
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ # setup kernel
+ kernel = np.asarray(kernel, dtype=np.float32)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * gain
+
+ if self.use_conv:
+ _, _, convH, convW = weight.shape
+ p = (kernel.shape[0] - factor) + (convW - 1)
+ s = [factor, factor]
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
+ x = F.conv2d(x, weight, stride=s, padding=0)
+ else:
+ p = kernel.shape[0] - factor
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
+
+ return x
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+ x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
+ else:
+ x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
+
+ return x
+
+
+class ResnetBlock2D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ kernel=None,
+ output_scale_factor=1.0,
+ use_nin_shortcut=None,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.up = up
+ self.down = down
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.upsample = self.downsample = None
+ if self.up:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
+ else:
+ self.upsample = Upsample2D(in_channels, use_conv=False)
+ elif self.down:
+ if kernel == "fir":
+ fir_kernel = (1, 3, 3, 1)
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
+ elif kernel == "sde_vp":
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
+ else:
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
+
+ self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
+
+ self.conv_shortcut = None
+ if self.use_nin_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x, temb):
+ hidden_states = x
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ if self.upsample is not None:
+ x = self.upsample(x)
+ hidden_states = self.upsample(hidden_states)
+ elif self.downsample is not None:
+ x = self.downsample(x)
+ hidden_states = self.downsample(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
+ hidden_states = hidden_states + temb
+
+ # make sure hidden states is in float32
+ # when running in half-precision
+ hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ x = self.conv_shortcut(x)
+
+ out = (x + hidden_states) / self.output_scale_factor
+
+ return out
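+
+# Hedged example (not in the original file): when `in_channels != out_channels` the
+# block adds a learned 1x1 shortcut so the residual sum is well defined. Channel and
+# embedding sizes are arbitrary assumptions.
+#
+#     >>> block = ResnetBlock2D(in_channels=32, out_channels=64, temb_channels=128)
+#     >>> x, temb = torch.randn(2, 32, 16, 16), torch.randn(2, 128)
+#     >>> block(x, temb).shape
+#     torch.Size([2, 64, 16, 16])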
+
+
+class Mish(torch.nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(torch.nn.functional.softplus(x))
+
+
+def upsample_2d(x, kernel=None, factor=2, gain=1):
+ r"""Upsample2D a batch of 2D images with the given filter.
+
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a
+ multiple of the upsampling factor.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+ corresponds to nearest-neighbor upsampling.
+ factor: Integer upsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H * factor, W * factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = np.asarray(kernel, dtype=np.float32)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * (gain * (factor**2))
+ p = kernel.shape[0] - factor
+ return upfirdn2d_native(
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
+ )
+
+
+def downsample_2d(x, kernel=None, factor=2, gain=1):
+ r"""Downsample2D a batch of 2D images with the given filter.
+
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+ shape is a multiple of the downsampling factor.
+
+ Args:
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+ corresponds to average pooling.
+ factor: Integer downsampling factor (default: 2).
+ gain: Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ Tensor of the shape `[N, C, H // factor, W // factor]`
+ """
+
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+
+ kernel = np.asarray(kernel, dtype=np.float32)
+ if kernel.ndim == 1:
+ kernel = np.outer(kernel, kernel)
+ kernel /= np.sum(kernel)
+
+ kernel = kernel * gain
+ p = kernel.shape[0] - factor
+ return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
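+
+# Hedged example (not in the original file): with the default box kernel these helpers
+# change the spatial resolution by `factor` while keeping N and C. Sizes are arbitrary.
+#
+#     >>> x = torch.randn(1, 3, 8, 8)
+#     >>> upsample_2d(x).shape
+#     torch.Size([1, 3, 16, 16])
+#     >>> downsample_2d(x).shape
+#     torch.Size([1, 3, 4, 4])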
+
+
+def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
+ up_x = up_y = up
+ down_x = down_y = down
+ pad_x0 = pad_y0 = pad[0]
+ pad_x1 = pad_y1 = pad[1]
+
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+
+ # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
+ if input.device.type == "mps":
+ out = out.to("cpu")
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out.to(input.device) # Move back to mps if necessary
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
diff --git a/diffusers/models/unet_2d.py b/diffusers/models/unet_2d.py
new file mode 100644
index 000000000..c3ab621a2
--- /dev/null
+++ b/diffusers/models/unet_2d.py
@@ -0,0 +1,246 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class UNet2DOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states output. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet2DModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int`, *optional*):
+ Input sample size.
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip sin to cos for fourier time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
+ Tuple of block output channels.
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ center_input_sample: bool = False,
+ time_embedding_type: str = "positional",
+ freq_shift: int = 0,
+ flip_sin_to_cos: bool = True,
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
+ layers_per_block: int = 2,
+ mid_block_scale_factor: float = 1,
+ downsample_padding: int = 1,
+ act_fn: str = "silu",
+ attention_head_dim: int = 8,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ if time_embedding_type == "fourier":
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
+ timestep_input_dim = 2 * block_out_channels[0]
+ elif time_embedding_type == "positional":
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ return_dict: bool = True,
+ ) -> Union[UNet2DOutput, Tuple]:
+ """r
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
+
+ t_emb = self.time_proj(timesteps)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ skip_sample = sample
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "skip_conv"):
+ sample, res_samples, skip_sample = downsample_block(
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb)
+
+ # 5. up
+ skip_sample = None
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "skip_conv"):
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
+ else:
+ sample = upsample_block(sample, res_samples, emb)
+
+ # 6. post-process
+ # make sure hidden states is in float32
+ # when running in half-precision
+ sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if skip_sample is not None:
+ sample += skip_sample
+
+ if self.config.time_embedding_type == "fourier":
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
+ sample = sample / timesteps
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DOutput(sample=sample)
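+
+
+# Hedged usage sketch (not part of the original file): a deliberately tiny configuration
+# whose argument values are arbitrary assumptions, used only to show the expected
+# input/output shapes of the forward pass.
+#
+#     >>> unet = UNet2DModel(
+#     ...     sample_size=32,
+#     ...     in_channels=3,
+#     ...     out_channels=3,
+#     ...     block_out_channels=(32, 64),
+#     ...     down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+#     ...     up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+#     ... )
+#     >>> noisy = torch.randn(1, 3, 32, 32)
+#     >>> unet(noisy, timestep=10).sample.shape
+#     torch.Size([1, 3, 32, 32])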
diff --git a/diffusers/models/unet_2d_condition.py b/diffusers/models/unet_2d_condition.py
new file mode 100644
index 000000000..92caaca92
--- /dev/null
+++ b/diffusers/models/unet_2d_condition.py
@@ -0,0 +1,270 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .embeddings import TimestepEmbedding, Timesteps
+from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin):
+ r"""
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
+ and returns sample shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ sample_size (`int`, *optional*): The size of the input sample.
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
+ The tuple of upsample blocks to use.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: int = 8,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ downsample_padding=downsample_padding,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift="default",
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ resnet_groups=norm_num_groups,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
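+ # Channel bookkeeping for the upsampling path: `prev_output_channel` is what the
+ # previous (deeper) up block produced, while `input_channel` mirrors the matching
+ # down block so the concatenated skip connections arrive with the expected widths.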
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+ )
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+ )
+
+ for block in self.down_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ self.mid_block.set_attention_slice(slice_size)
+
+ for block in self.up_blocks:
+ if hasattr(block, "attentions") and block.attentions is not None:
+ block.set_attention_slice(slice_size)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ """r
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy input tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps.to(dtype=torch.float32)
+ timesteps = timesteps[None].to(device=sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ emb = self.time_embedding(t_emb)
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+ sample, res_samples = downsample_block(
+ hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+ # 5. up
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ else:
+ sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
+
+ # 6. post-process
+ # make sure the hidden states are in float32
+ # when running in half-precision
+ sample = self.conv_norm_out(sample.float()).type(sample.dtype)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
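+
+
+if __name__ == "__main__":
+ # Minimal smoke-test sketch, not part of the library. The tiny configuration and the
+ # 77-token dummy text embedding are illustrative assumptions, not the Stable Diffusion
+ # checkpoint configuration; the point is only that `encoder_hidden_states` is
+ # (batch, sequence_length, cross_attention_dim) while `sample` stays image-shaped.
+ model = UNet2DConditionModel(
+ sample_size=16,
+ layers_per_block=1,
+ block_out_channels=(32, 64),
+ down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+ up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
+ cross_attention_dim=32,
+ attention_head_dim=4,
+ )
+ latents = torch.randn(1, 4, 16, 16)
+ text_embeddings = torch.randn(1, 77, 32)
+ output = model(latents, timestep=torch.tensor(1), encoder_hidden_states=text_embeddings)
+ print(output.sample.shape)  # expected: torch.Size([1, 4, 16, 16])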
diff --git a/diffusers/models/unet_blocks.py b/diffusers/models/unet_blocks.py
new file mode 100644
index 000000000..9e0621653
--- /dev/null
+++ b/diffusers/models/unet_blocks.py
@@ -0,0 +1,1481 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+from torch import nn
+
+from .attention import AttentionBlock, SpatialTransformer
+from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ cross_attention_dim=None,
+ downsample_padding=None,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
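+ # Strip an optional "UNetRes" prefix, e.g. "UNetResDownBlock2D" -> "DownBlock2D".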
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ cross_attention_dim=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attn_num_head_channels=attn_num_head_channels,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ AttentionBlock(
+ in_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.attention_type == "default":
+ hidden_states = attn(hidden_states)
+ else:
+ hidden_states = attn(hidden_states, encoder_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ SpatialTransformer(
+ in_channels,
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ downsample_padding=1,
+ add_downsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
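+ # The skip branch carries the (hard-coded 3-channel) image-resolution sample;
+ # `skip_conv` projects it to this block's feature width before it is added in forward.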
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(self, hidden_states, temb=None, skip_sample=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_type="default",
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ attention_type="default",
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.attention_type = attention_type
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ out_channels,
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ depth=1,
+ context_dim=cross_attention_dim,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+ raise ValueError(
+ f"Make sure slice_size {slice_size} is a divisor of "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
+ raise ValueError(
+ f"Chunk_size {slice_size} has to be smaller or equal to "
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+ )
+
+ for attn in self.attentions:
+ attn._set_attention_slice(slice_size)
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+ for resnet, attn in zip(self.resnets, self.attentions):
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+ for resnet in self.resnets:
+
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ num_groups=resnet_groups,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(self, hidden_states):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ attention_type="default",
+ output_scale_factor=np.sqrt(2.0),
+ upsample_padding=1,
+ add_upsample=True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ self.attention_type = attention_type
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions.append(
+ AttentionBlock(
+ out_channels,
+ num_head_channels=attn_num_head_channels,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor=np.sqrt(2.0),
+ add_upsample=True,
+ upsample_padding=1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_nin_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
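+
+
+if __name__ == "__main__":
+ # Minimal smoke-test sketch, not part of the library: the channel sizes below are
+ # illustrative assumptions. It builds one residual down block through the factory
+ # above and checks that it halves the resolution and returns one residual per layer
+ # plus the downsampled output.
+ block = get_down_block(
+ "DownBlock2D",
+ num_layers=2,
+ in_channels=32,
+ out_channels=64,
+ temb_channels=128,
+ add_downsample=True,
+ resnet_eps=1e-6,
+ resnet_act_fn="swish",
+ attn_num_head_channels=None,
+ downsample_padding=1,
+ )
+ hidden, residuals = block(torch.randn(1, 32, 16, 16), temb=torch.randn(1, 128))
+ print(hidden.shape, len(residuals))  # expected: torch.Size([1, 64, 8, 8]) 3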
diff --git a/diffusers/models/vae.py b/diffusers/models/vae.py
new file mode 100644
index 000000000..82748cb5b
--- /dev/null
+++ b/diffusers/models/vae.py
@@ -0,0 +1,581 @@
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
+from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ """
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Decoded output sample of the model. Output of the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+ """
+ Output of VQModel encoding method.
+
+ Args:
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ Encoded output sample of the model. Output of the last layer of the model.
+ """
+
+ latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ down_block_types=("DownEncoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ act_fn="silu",
+ double_z=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=32,
+ temb_channels=None,
+ )
+
+ # out
+ num_groups_out = 32
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+ def forward(self, x):
+ sample = x
+ sample = self.conv_in(sample)
+
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels=3,
+ out_channels=3,
+ up_block_types=("UpDecoderBlock2D",),
+ block_out_channels=(64,),
+ layers_per_block=2,
+ act_fn="silu",
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attn_num_head_channels=None,
+ resnet_groups=32,
+ temb_channels=None,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ attn_num_head_channels=None,
+ temb_channels=None,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ num_groups_out = 32
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ def forward(self, z):
+ sample = z
+ sample = self.conv_in(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class VectorQuantizer(nn.Module):
+ """
+ Improved version of VectorQuantizer that can be used as a drop-in replacement. Mostly avoids costly matrix
+ multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug, the beta term was applied to the wrong term. For
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
+ )
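+ # d has shape (num_latent_vectors, n_e): the squared distance from every flattened latent
+ # vector to each codebook embedding.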
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ device = self.parameters.device
+ sample_device = "cpu" if device.type == "mps" else device
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+ x = self.mean + self.std * sample
+ return x
+
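+ # kl() below uses the closed form for diagonal Gaussians:
+ # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log(sigma^2)),
+ # and the general two-distribution form when `other` is given.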
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+class VQModel(ModelMixin, ConfigMixin):
+ r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
+ Kavukcuoglu.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all models (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 3,
+ sample_size: int = 32,
+ num_vq_embeddings: int = 256,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ double_z=False,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+ self.quantize = VectorQuantizer(
+ num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
+ )
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ )
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+
+ if not return_dict:
+ return (h,)
+
+ return VQEncoderOutput(latents=h)
+
+ def decode(
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ # also go through quantization layer
+ if not force_not_quantize:
+ quant, emb_loss, info = self.quantize(h)
+ else:
+ quant = h
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ h = self.encode(x).latents
+ dec = self.decode(h).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
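+# Usage sketch (illustrative only; `vq` is a hypothetical VQModel instance): for an image batch `x`
+# of shape (B, 3, H, W), `vq.encode(x).latents` returns the continuous latents before quantization and
+# `vq.decode(latents).sample` quantizes them and reconstructs the image.
+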
+
+class AutoencoderKL(ModelMixin, ConfigMixin):
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
+ and Max Welling.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all models (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): TODO
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ sample_size: int = 32,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
+
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
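+
+# Usage sketch (illustrative only; `vae` is a hypothetical AutoencoderKL instance):
+# posterior = vae.encode(x).latent_dist
+# z = posterior.sample()            # or posterior.mode() for the deterministic mean
+# reconstruction = vae.decode(z).sample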
diff --git a/diffusers/onnx_utils.py b/diffusers/onnx_utils.py
new file mode 100644
index 000000000..e840565dd
--- /dev/null
+++ b/diffusers/onnx_utils.py
@@ -0,0 +1,189 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+
+from huggingface_hub import hf_hub_download
+
+from .utils import is_onnx_available, logging
+
+
+if is_onnx_available():
+ import onnxruntime as ort
+
+
+ONNX_WEIGHTS_NAME = "model.onnx"
+
+
+logger = logging.get_logger(__name__)
+
+
+class OnnxRuntimeModel:
+ base_model_prefix = "onnx_model"
+
+ def __init__(self, model=None, **kwargs):
+ logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
+ self.model = model
+ self.model_save_dir = kwargs.get("model_save_dir", None)
+ self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
+
+ def __call__(self, **kwargs):
+ inputs = {k: np.array(v) for k, v in kwargs.items()}
+ return self.model.run(None, inputs)
+
+ @staticmethod
+ def load_model(path: Union[str, Path], provider=None):
+ """
+ Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
+
+ Arguments:
+ path (`str` or `Path`):
+ Directory from which to load
+ provider(`str`, *optional*):
+ Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
+ """
+ if provider is None:
+ logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
+ provider = "CPUExecutionProvider"
+
+ return ort.InferenceSession(path, providers=[provider])
+
+ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
+ latest_model_name.
+
+ Arguments:
+ save_directory (`str` or `Path`):
+ Directory where to save the model file.
+ file_name(`str`, *optional*):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
+ model with a different name.
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+
+ src_path = self.model_save_dir.joinpath(self.latest_model_name)
+ dst_path = Path(save_directory).joinpath(model_file_name)
+ if not src_path.samefile(dst_path):
+ shutil.copyfile(src_path, dst_path)
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ **kwargs,
+ ):
+ """
+ Save a model to a directory so that it can be re-loaded using the [`~OnnxRuntimeModel.from_pretrained`] class
+ method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # saving model weights/files
+ self._save_pretrained(save_directory, **kwargs)
+
+ @classmethod
+ def _from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ use_auth_token: Optional[Union[bool, str, None]] = None,
+ revision: Optional[Union[str, None]] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ file_name: Optional[str] = None,
+ provider: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Load a model from a directory or the HF Hub.
+
+ Arguments:
+ model_id (`str` or `Path`):
+ Directory from which to load
+ use_auth_token (`str` or `bool`):
+ The token needed to load models from a private or gated repository.
+ revision (`str`):
+ Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
+ cache_dir (`Union[str, Path]`, *optional*):
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ file_name(`str`):
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
+ different model files from the same repository or directory.
+ provider(`str`):
+ The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
+ kwargs (`Dict`, *optional*):
+ kwargs will be passed to the model during initialization
+ """
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+ # load model from local directory
+ if os.path.isdir(model_id):
+ model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
+ kwargs["model_save_dir"] = Path(model_id)
+ # load model from hub
+ else:
+ # download model
+ model_cache_path = hf_hub_download(
+ repo_id=model_id,
+ filename=model_file_name,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ )
+ kwargs["model_save_dir"] = Path(model_cache_path).parent
+ kwargs["latest_model_name"] = Path(model_cache_path).name
+ model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
+ return cls(model=model, **kwargs)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_id: Union[str, Path],
+ force_download: bool = True,
+ use_auth_token: Optional[str] = None,
+ cache_dir: Optional[str] = None,
+ **model_kwargs,
+ ):
+ revision = None
+ if len(str(model_id).split("@")) == 2:
+ model_id, revision = model_id.split("@")
+
+ return cls._from_pretrained(
+ model_id=model_id,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ use_auth_token=use_auth_token,
+ **model_kwargs,
+ )
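+
+# Usage sketch (illustrative only; the repo id and input names are assumptions that depend on the exported ONNX graph):
+# model = OnnxRuntimeModel.from_pretrained("some-user/some-onnx-model", provider="CPUExecutionProvider")
+# outputs = model(input_ids=np.zeros((1, 77), dtype=np.int64))  # kwargs are converted to numpy and fed to onnxruntime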
diff --git a/diffusers/optimization.py b/diffusers/optimization.py
new file mode 100644
index 000000000..e7b836b4a
--- /dev/null
+++ b/diffusers/optimization.py
@@ -0,0 +1,275 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for diffusion models."""
+
+import math
+from enum import Enum
+from typing import Optional, Union
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class SchedulerType(Enum):
+ LINEAR = "linear"
+ COSINE = "cosine"
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
+ POLYNOMIAL = "polynomial"
+ CONSTANT = "constant"
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+ """
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+ increases linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1.0, num_warmup_steps))
+ return 1.0
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+ """
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ return max(
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
+ )
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`float`, *optional*, defaults to 0.5):
+ The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+ following a half-cosine).
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+ """
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+ linearly between 0 and the initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ num_cycles (`int`, *optional*, defaults to 1):
+ The number of hard restarts to use.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+ """
+
+ def lr_lambda(current_step):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+ if progress >= 1.0:
+ return 0.0
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_polynomial_decay_schedule_with_warmup(
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+ """
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+ initial lr set in the optimizer.
+
+ Args:
+ optimizer ([`~torch.optim.Optimizer`]):
+ The optimizer for which to schedule the learning rate.
+ num_warmup_steps (`int`):
+ The number of steps for the warmup phase.
+ num_training_steps (`int`):
+ The total number of training steps.
+ lr_end (`float`, *optional*, defaults to 1e-7):
+ The end LR.
+ power (`float`, *optional*, defaults to 1.0):
+ Power factor.
+ last_epoch (`int`, *optional*, defaults to -1):
+ The index of the last epoch when resuming training.
+
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+ implementation at
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+ Return:
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+ """
+
+ lr_init = optimizer.defaults["lr"]
+ if not (lr_init > lr_end):
+ raise ValueError(f"lr_end ({lr_end}) must be smaller than the initial lr ({lr_init})")
+
+ def lr_lambda(current_step: int):
+ if current_step < num_warmup_steps:
+ return float(current_step) / float(max(1, num_warmup_steps))
+ elif current_step > num_training_steps:
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
+ else:
+ lr_range = lr_init - lr_end
+ decay_steps = num_training_steps - num_warmup_steps
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+ decay = lr_range * pct_remaining**power + lr_end
+ return decay / lr_init # as LambdaLR multiplies by lr_init
+
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+ SchedulerType.CONSTANT: get_constant_schedule,
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+}
+
+
+def get_scheduler(
+ name: Union[str, SchedulerType],
+ optimizer: Optimizer,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+):
+ """
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer that will be used during training.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int`, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(optimizer)
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
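+
+# Usage sketch (illustrative only; `optimizer` is any torch.optim.Optimizer):
+# lr_scheduler = get_scheduler("linear", optimizer, num_warmup_steps=500, num_training_steps=10_000)
+# then call lr_scheduler.step() once per training step.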
diff --git a/diffusers/pipeline_utils.py b/diffusers/pipeline_utils.py
new file mode 100644
index 000000000..84ee9e20f
--- /dev/null
+++ b/diffusers/pipeline_utils.py
@@ -0,0 +1,417 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import inspect
+import os
+from dataclasses import dataclass
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import diffusers
+import PIL
+from huggingface_hub import snapshot_download
+from PIL import Image
+from tqdm.auto import tqdm
+
+from .configuration_utils import ConfigMixin
+from .utils import DIFFUSERS_CACHE, BaseOutput, logging
+
+
+INDEX_FILE = "diffusion_pytorch_model.bin"
+
+
+logger = logging.get_logger(__name__)
+
+
+LOADABLE_CLASSES = {
+ "diffusers": {
+ "ModelMixin": ["save_pretrained", "from_pretrained"],
+ "SchedulerMixin": ["save_config", "from_config"],
+ "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+ },
+ "transformers": {
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
+ },
+}
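+# For each supported library, LOADABLE_CLASSES maps a base class to its [save_method, load_method]
+# pair; these method names are looked up when serializing and deserializing pipeline components.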
+
+ALL_IMPORTABLE_CLASSES = {}
+for library in LOADABLE_CLASSES:
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
+
+
+@dataclass
+class ImagePipelineOutput(BaseOutput):
+ """
+ Output class for image pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. The PIL images or numpy array represent the denoised images produced by the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+class DiffusionPipeline(ConfigMixin):
+ r"""
+ Base class for all diffusion pipelines.
+
+ [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
+ and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
+
+ - move all PyTorch modules to the device of your choice
+ - enable or disable the progress bar for the denoising iteration
+
+ Class attributes:
+
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
+ components of the diffusion pipeline.
+ """
+ config_name = "model_index.json"
+
+ def register_modules(self, **kwargs):
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ for name, module in kwargs.items():
+ # retrieve library
+ library = module.__module__.split(".")[0]
+
+ # check if the module is a pipeline module
+ pipeline_dir = module.__module__.split(".")[-2]
+ path = module.__module__.split(".")
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
+ # Or if it's a pipeline module, then the module is inside the pipeline
+ # folder so we set the library to module name.
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
+ library = pipeline_dir
+
+ # retrieve class_name
+ class_name = module.__class__.__name__
+
+ register_dict = {name: (library, class_name)}
+
+ # save model index config
+ self.register_to_config(**register_dict)
+
+ # set models
+ setattr(self, name, module)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+ """
+ Save all variables of the pipeline that can be saved and loaded, as well as the pipeline's configuration file, to
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and a loading
+ method. The pipeline can easily be re-loaded using the [`~DiffusionPipeline.from_pretrained`] class method.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to which to save. Will be created if it doesn't exist.
+ """
+ self.save_config(save_directory)
+
+ model_index_dict = dict(self.config)
+ model_index_dict.pop("_class_name")
+ model_index_dict.pop("_diffusers_version")
+ model_index_dict.pop("_module", None)
+
+ for pipeline_component_name in model_index_dict.keys():
+ sub_model = getattr(self, pipeline_component_name)
+ model_cls = sub_model.__class__
+
+ save_method_name = None
+ # search for the model's base class in LOADABLE_CLASSES
+ for library_name, library_classes in LOADABLE_CLASSES.items():
+ library = importlib.import_module(library_name)
+ for base_class, save_load_methods in library_classes.items():
+ class_candidate = getattr(library, base_class)
+ if issubclass(model_cls, class_candidate):
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
+ save_method_name = save_load_methods[0]
+ break
+ if save_method_name is not None:
+ break
+
+ save_method = getattr(sub_model, save_method_name)
+ save_method(os.path.join(save_directory, pipeline_component_name))
+
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+ if torch_device is None:
+ return self
+
+ module_names, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ module.to(torch_device)
+ return self
+
+ @property
+ def device(self) -> torch.device:
+ r"""
+ Returns:
+ `torch.device`: The torch device on which the pipeline is located.
+ """
+ module_names, _ = self.extract_init_dict(dict(self.config))
+ for name in module_names.keys():
+ module = getattr(self, name)
+ if isinstance(module, torch.nn.Module):
+ return module.device
+ return torch.device("cpu")
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+ r"""
+ Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
+
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
+
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+ task.
+
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+ weights are discarded.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
+ https://huggingface.co/. Valid repo ids have to be located under a user or organization name, like
+ `CompVis/ldm-text2im-large-256`.
+ - A path to a *directory* containing pipeline weights saved using
+ [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
+ torch_dtype (`str` or `torch.dtype`, *optional*):
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+ will be automatically derived from the model's weights.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ resume_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+ file exists.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether or not to only look at local files (i.e., do not try to download the model).
+ use_auth_token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and
+ other artifacts on huggingface.co are stored with a git-based system, `revision` can be any identifier
+ allowed by git.
+ mirror (`str`, *optional*):
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety of
+ the mirror. Please refer to the mirror site for more information.
+
+ kwargs (remaining dictionary of keyword arguments, *optional*):
+ Can be used to overwrite loadable and saveable variables, *i.e.* the pipeline components, of the
+ specific pipeline class. The overwritten components are then passed directly to the pipeline's `__init__`
+ method. See the example below for more information.
+
+ Passing `use_auth_token=True` is required when you want to use a private model, *e.g.*
+ `"CompVis/stable-diffusion-v1-4"`.
+
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
+ this method in a firewalled environment.
+
+ Examples:
+
+ ```py
+ >>> from diffusers import DiffusionPipeline
+
+ >>> # Download pipeline from huggingface.co and cache.
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+
+ >>> # Download pipeline that requires an authorization token
+ >>> # For more information on access tokens, please refer to
+ >>> # [this section of the documentation](https://huggingface.co/docs/hub/security-tokens)
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+
+ >>> # Download pipeline, but overwrite scheduler
+ >>> from diffusers import LMSDiscreteScheduler
+
+ >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+ >>> pipeline = DiffusionPipeline.from_pretrained(
+ ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
+ ... )
+ ```
+ """
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ torch_dtype = kwargs.pop("torch_dtype", None)
+ provider = kwargs.pop("provider", None)
+
+ # 1. Download the checkpoints and configs
+ # use snapshot_download here so that `from_pretrained` also works with hub repo ids
+ if not os.path.isdir(pretrained_model_name_or_path):
+ cached_folder = snapshot_download(
+ pretrained_model_name_or_path,
+ cache_dir=cache_dir,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ )
+ else:
+ cached_folder = pretrained_model_name_or_path
+
+ config_dict = cls.get_config_dict(cached_folder)
+
+ # 2. Load the pipeline class, if using custom module then load it from the hub
+ # if we load from explicit class, let's use it
+ if cls != DiffusionPipeline:
+ pipeline_class = cls
+ else:
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+
+ # some modules can be passed directly to the init
+ # in this case they are already instantiated in `kwargs`
+ # extract them here
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+ init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+ init_kwargs = {}
+
+ # import it here to avoid circular import
+ from diffusers import pipelines
+
+ # 3. Load each module in the pipeline
+ for name, (library_name, class_name) in init_dict.items():
+ is_pipeline_module = hasattr(pipelines, library_name)
+ loaded_sub_model = None
+
+ # if the model is in a pipeline module, then we load it from the pipeline
+ if name in passed_class_obj:
+ # 1. check that passed_class_obj has correct parent class
+ if not is_pipeline_module:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ expected_class_obj = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ expected_class_obj = class_candidate
+
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
+ raise ValueError(
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
+ f" {expected_class_obj}"
+ )
+ else:
+ logger.warning(
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
+ " has the correct type"
+ )
+
+ # set passed class object
+ loaded_sub_model = passed_class_obj[name]
+ elif is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = getattr(pipeline_module, class_name)
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
+ else:
+ # else we just import it from the library.
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+ importable_classes = LOADABLE_CLASSES[library_name]
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
+
+ if loaded_sub_model is None:
+ load_method_name = None
+ for class_name, class_candidate in class_candidates.items():
+ if issubclass(class_obj, class_candidate):
+ load_method_name = importable_classes[class_name][1]
+
+ load_method = getattr(class_obj, load_method_name)
+
+ loading_kwargs = {}
+ if issubclass(class_obj, torch.nn.Module):
+ loading_kwargs["torch_dtype"] = torch_dtype
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
+ loading_kwargs["provider"] = provider
+
+ # check if the module is in a subdirectory
+ if os.path.isdir(os.path.join(cached_folder, name)):
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
+ else:
+ # else load from the root directory
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+ # 4. Instantiate the pipeline
+ model = pipeline_class(**init_kwargs)
+ return model
+
+ @staticmethod
+ def numpy_to_pil(images):
+ """
+ Convert a numpy image or a batch of images to a PIL image.
+ """
+ if images.ndim == 3:
+ images = images[None, ...]
+ images = (images * 255).round().astype("uint8")
+ pil_images = [Image.fromarray(image) for image in images]
+
+ return pil_images
+
+ def progress_bar(self, iterable):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ return tqdm(iterable, **self._progress_bar_config)
+
+ def set_progress_bar_config(self, **kwargs):
+ self._progress_bar_config = kwargs
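+
+# Usage sketch (illustrative only; `pipe` is a hypothetical DiffusionPipeline instance):
+# pipe.set_progress_bar_config(disable=True)  # progress_bar() forwards these kwargs straight to tqdm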
diff --git a/diffusers/pipelines/__init__.py b/diffusers/pipelines/__init__.py
new file mode 100644
index 000000000..3e2aeb4fb
--- /dev/null
+++ b/diffusers/pipelines/__init__.py
@@ -0,0 +1,19 @@
+from ..utils import is_onnx_available, is_transformers_available
+from .ddim import DDIMPipeline
+from .ddpm import DDPMPipeline
+from .latent_diffusion_uncond import LDMPipeline
+from .pndm import PNDMPipeline
+from .score_sde_ve import ScoreSdeVePipeline
+from .stochastic_karras_ve import KarrasVePipeline
+
+
+if is_transformers_available():
+ from .latent_diffusion import LDMTextToImagePipeline
+ from .stable_diffusion import (
+ StableDiffusionImg2ImgPipeline,
+ StableDiffusionInpaintPipeline,
+ StableDiffusionPipeline,
+ )
+
+if is_transformers_available() and is_onnx_available():
+ from .stable_diffusion import StableDiffusionOnnxPipeline
diff --git a/diffusers/pipelines/ddim/__init__.py b/diffusers/pipelines/ddim/__init__.py
new file mode 100644
index 000000000..8fd31868a
--- /dev/null
+++ b/diffusers/pipelines/ddim/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_ddim import DDIMPipeline
diff --git a/diffusers/pipelines/ddim/pipeline_ddim.py b/diffusers/pipelines/ddim/pipeline_ddim.py
new file mode 100644
index 000000000..33f6064db
--- /dev/null
+++ b/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -0,0 +1,117 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DDIMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ eta (`float`, *optional*, defaults to 0.0):
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # eta corresponds to η in the DDIM paper and should be in [0, 1]
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
+ # do x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
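+
+# Usage sketch (illustrative only; the checkpoint name is an assumption):
+# pipe = DDIMPipeline.from_pretrained("some-user/some-ddim-checkpoint")
+# images = pipe(batch_size=1, num_inference_steps=50, eta=0.0).images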
diff --git a/diffusers/pipelines/ddpm/__init__.py b/diffusers/pipelines/ddpm/__init__.py
new file mode 100644
index 000000000..8889bdae1
--- /dev/null
+++ b/diffusers/pipelines/ddpm/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_ddpm import DDPMPipeline
diff --git a/diffusers/pipelines/ddpm/pipeline_ddpm.py b/diffusers/pipelines/ddpm/pipeline_ddpm.py
new file mode 100644
index 000000000..71103bbe4
--- /dev/null
+++ b/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -0,0 +1,106 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+# limitations under the License.
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+class DDPMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
+ [`DDPMScheduler`], or [`DDIMScheduler`].
+ """
+
+ def __init__(self, unet, scheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ # set step values
+ self.scheduler.set_timesteps(1000)
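+ # Note (added): DDPM sampling here runs a fixed 1000 denoising steps, one per training timestep,
+ # which is why it is slower than DDIM sampling with fewer steps.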
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # 1. predict noise model_output
+ model_output = self.unet(image, t).sample
+
+ # 2. compute previous image: x_t -> x_t-1
+ image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/pipelines/latent_diffusion/__init__.py b/diffusers/pipelines/latent_diffusion/__init__.py
new file mode 100644
index 000000000..c481b38cf
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion/__init__.py
@@ -0,0 +1,6 @@
+# flake8: noqa
+from ...utils import is_transformers_available
+
+
+if is_transformers_available():
+ from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
diff --git a/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
new file mode 100644
index 000000000..b39840f24
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -0,0 +1,705 @@
+import inspect
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging
+
+from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+
+
+class LDMTextToImagePipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ bert ([`LDMBertModel`]):
+ Text-encoder model based on the [BERT](https://huggingface.co/docs/transformers/model_doc/bert) architecture.
+ tokenizer (`transformers.BertTokenizer`):
+ Tokenizer of class
+ [BertTokenizer](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ """
+
+ def __init__(
+ self,
+ vqvae: Union[VQModel, AutoencoderKL],
+ bert: PreTrainedModel,
+ tokenizer: PreTrainedTokenizer,
+ unet: Union[UNet2DModel, UNet2DConditionModel],
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 256,
+ width: Optional[int] = 256,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 1.0,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 256):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 256):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 1.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get unconditional embeddings for classifier free guidance
+ if guidance_scale != 1.0:
+ uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+ uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+ text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
+
+ latents = torch.randn(
+ (batch_size, self.unet.in_channels, height // 8, width // 8),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ if guidance_scale == 1.0:
+ # guidance_scale of 1 means no guidance
+ latents_input = latents
+ context = text_embeddings
+ else:
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ latents_input = torch.cat([latents] * 2)
+ context = torch.cat([uncond_embeddings, text_embeddings])
+
+ # predict the noise residual
+ noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
+ # perform guidance
+ if guidance_scale != 1.0:
+ noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
+
+
+################################################################################
+# Code for the text transformer model
+################################################################################
+""" PyTorch LDMBERT model."""
+
+
+logger = logging.get_logger(__name__)
+
+LDMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "ldm-bert",
+ # See all LDMBert models at https://huggingface.co/models?filter=ldmbert
+]
+
+
+LDMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ "ldm-bert": "https://huggingface.co/ldm-bert/resolve/main/config.json",
+}
+
+
+""" LDMBERT model configuration"""
+
+
+class LDMBertConfig(PretrainedConfig):
+ model_type = "ldmbert"
+ keys_to_ignore_at_inference = ["past_key_values"]
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ max_position_embeddings=77,
+ encoder_layers=32,
+ encoder_ffn_dim=5120,
+ encoder_attention_heads=8,
+ head_dim=64,
+ encoder_layerdrop=0.0,
+ activation_function="gelu",
+ d_model=1280,
+ dropout=0.1,
+ attention_dropout=0.0,
+ activation_dropout=0.0,
+ init_std=0.02,
+ classifier_dropout=0.0,
+ scale_embedding=False,
+ use_cache=True,
+ pad_token_id=0,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.d_model = d_model
+ self.encoder_ffn_dim = encoder_ffn_dim
+ self.encoder_layers = encoder_layers
+ self.encoder_attention_heads = encoder_attention_heads
+ self.head_dim = head_dim
+ self.dropout = dropout
+ self.attention_dropout = attention_dropout
+ self.activation_dropout = activation_dropout
+ self.activation_function = activation_function
+ self.init_std = init_std
+ self.encoder_layerdrop = encoder_layerdrop
+ self.classifier_dropout = classifier_dropout
+ self.use_cache = use_cache
+ self.num_hidden_layers = encoder_layers
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
+
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
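+# Worked example for `_expand_mask` (comment sketch, not executed): for a padding mask
+# [[1, 1, 0]] with tgt_len=3, the result has shape [1, 1, 3, 3]; columns attending to the
+# padded third token are filled with torch.finfo(dtype).min (an effectively -inf additive
+# bias), while columns for real tokens stay 0.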
+
+
+# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->LDMBert
+class LDMBertAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ head_dim: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ bias: bool = False,
+ ):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = head_dim
+ self.inner_dim = head_dim * num_heads
+
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+
+ self.k_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, self.inner_dim, bias=bias)
+ self.out_proj = nn.Linear(self.inner_dim, embed_dim)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.inner_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
+class LDMBertEncoderLayer(nn.Module):
+ def __init__(self, config: LDMBertConfig):
+ super().__init__()
+ self.embed_dim = config.d_model
+ self.self_attn = LDMBertAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ head_dim=config.head_dim,
+ dropout=config.attention_dropout,
+ )
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ layer_head_mask: torch.FloatTensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+ hidden_states, attn_weights, _ = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+# Copied from transformers.models.bart.modeling_bart.BartPretrainedModel with Bart->LDMBert
+class LDMBertPreTrainedModel(PreTrainedModel):
+ config_class = LDMBertConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (LDMBertEncoder,)):
+ module.gradient_checkpointing = value
+
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ }
+ return dummy_inputs
+
+
+class LDMBertEncoder(LDMBertPreTrainedModel):
+ """
+ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
+ [`LDMBertEncoderLayer`].
+
+ Args:
+ config: LDMBertConfig
+ embed_tokens (nn.Embedding): output embedding
+ """
+
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+
+ embed_dim = config.d_model
+ self.padding_idx = config.pad_token_id
+ self.max_source_positions = config.max_position_embeddings
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
+ self.embed_positions = nn.Embedding(config.max_position_embeddings, embed_dim)
+ self.layers = nn.ModuleList([LDMBertEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layer_norm = nn.LayerNorm(embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ seq_len = input_shape[1]
+ if position_ids is None:
+ position_ids = torch.arange(seq_len, dtype=torch.long, device=inputs_embeds.device).expand((1, -1))
+ embed_pos = self.embed_positions(position_ids)
+
+ hidden_states = inputs_embeds + embed_pos
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(encoder_layer),
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class LDMBertModel(LDMBertPreTrainedModel):
+ def __init__(self, config: LDMBertConfig):
+ super().__init__(config)
+ self.model = LDMBertEncoder(config)
+ self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ return outputs
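+
+# Minimal usage sketch (kept as a comment so importing this module stays side-effect free).
+# It assumes the public "CompVis/ldm-text2im-large-256" checkpoint and that
+# `LDMTextToImagePipeline` is re-exported at the package root, as in upstream diffusers:
+#
+#     from diffusers import LDMTextToImagePipeline
+#
+#     pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
+#     out = pipe("a painting of a squirrel eating a burger",
+#                num_inference_steps=50, guidance_scale=6.0)
+#     out.images[0].save("ldm_text2im.png")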
diff --git a/diffusers/pipelines/latent_diffusion_uncond/__init__.py b/diffusers/pipelines/latent_diffusion_uncond/__init__.py
new file mode 100644
index 000000000..0826ca753
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion_uncond/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_latent_diffusion_uncond import LDMPipeline
diff --git a/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
new file mode 100644
index 000000000..4979d88fe
--- /dev/null
+++ b/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -0,0 +1,108 @@
+import inspect
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel, VQModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import DDIMScheduler
+
+
+class LDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ vqvae ([`VQModel`]):
+ Vector-quantized (VQ) Model to encode and decode images to and from latent representations.
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ [`DDIMScheduler`] is to be used in combination with `unet` to denoise the encoded image latents.
+ """
+
+ def __init__(self, vqvae: VQModel, unet: UNet2DModel, scheduler: DDIMScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ generator: Optional[torch.Generator] = None,
+ eta: float = 0.0,
+ num_inference_steps: int = 50,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ Number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ latents = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ latents = latents.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+
+ extra_kwargs = {}
+ if accepts_eta:
+ extra_kwargs["eta"] = eta
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # predict the noise residual
+ noise_prediction = self.unet(latents, t).sample
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
+
+ # decode the image latents with the VAE
+ image = self.vqvae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
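+
+# Minimal usage sketch (comment only). "CompVis/ldm-celebahq-256" is assumed to be a
+# compatible unconditional latent-diffusion checkpoint, as in upstream diffusers examples:
+#
+#     from diffusers import LDMPipeline
+#
+#     pipe = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
+#     image = pipe(num_inference_steps=200).images[0]
+#     image.save("ldm_uncond.png")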
diff --git a/diffusers/pipelines/pndm/__init__.py b/diffusers/pipelines/pndm/__init__.py
new file mode 100644
index 000000000..6fc46aaab
--- /dev/null
+++ b/diffusers/pipelines/pndm/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_pndm import PNDMPipeline
diff --git a/diffusers/pipelines/pndm/pipeline_pndm.py b/diffusers/pipelines/pndm/pipeline_pndm.py
new file mode 100644
index 000000000..f3dff1a9a
--- /dev/null
+++ b/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -0,0 +1,111 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import PNDMScheduler
+
+
+class PNDMPipeline(DiffusionPipeline):
+ r"""
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Parameters:
+ unet (`UNet2DModel`): U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ The `PNDMScheduler` to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ unet: UNet2DModel
+ scheduler: PNDMScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: PNDMScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, `optional`, defaults to 1): The number of images to generate.
+ num_inference_steps (`int`, `optional`, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ generator (`torch.Generator`, `optional`): A [torch
+ generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, `optional`, defaults to `"pil"`): The output format of the generated image. Choose
+ between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, `optional`, defaults to `True`): Whether or not to return a
+ [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ # For more information on the sampling method you can take a look at Algorithm 2 of
+ # the official paper: https://arxiv.org/pdf/2202.09778.pdf
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ # Sample gaussian noise to begin loop
+ image = torch.randn(
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
+ generator=generator,
+ )
+ image = image.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ for t in self.progress_bar(self.scheduler.timesteps):
+ model_output = self.unet(image, t).sample
+
+ image = self.scheduler.step(model_output, t, image).prev_sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
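+
+# Minimal usage sketch (comment only). The repo id below is a hypothetical placeholder for
+# any unconditional `UNet2DModel` checkpoint saved together with a `PNDMScheduler` config:
+#
+#     from diffusers import PNDMPipeline
+#
+#     pipe = PNDMPipeline.from_pretrained("<unconditional-unet-with-pndm-scheduler>")
+#     image = pipe(batch_size=1, num_inference_steps=50).images[0]
+#     image.save("pndm_sample.png")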
diff --git a/diffusers/pipelines/score_sde_ve/__init__.py b/diffusers/pipelines/score_sde_ve/__init__.py
new file mode 100644
index 000000000..000d61f6e
--- /dev/null
+++ b/diffusers/pipelines/score_sde_ve/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_score_sde_ve import ScoreSdeVePipeline
diff --git a/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
new file mode 100644
index 000000000..604e2b54c
--- /dev/null
+++ b/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import ScoreSdeVeScheduler
+
+
+class ScoreSdeVePipeline(DiffusionPipeline):
+ r"""
+ Parameters:
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. scheduler ([`SchedulerMixin`]):
+ The [`ScoreSdeVeScheduler`] scheduler to be used in combination with `unet` to denoise the encoded image.
+ """
+ unet: UNet2DModel
+ scheduler: ScoreSdeVeScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler):
+ super().__init__()
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 2000,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+ self.scheduler.set_sigmas(num_inference_steps)
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
+
+ # correction step
+ for _ in range(self.scheduler.correct_steps):
+ model_output = self.unet(sample, sigma_t).sample
+ sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
+
+ # prediction step
+ model_output = model(sample, sigma_t).sample
+ output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
+
+ sample, sample_mean = output.prev_sample, output.prev_sample_mean
+
+ sample = sample_mean.clamp(0, 1)
+ sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ sample = self.numpy_to_pil(sample)
+
+ if not return_dict:
+ return (sample,)
+
+ return ImagePipelineOutput(images=sample)
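+
+# Minimal usage sketch (comment only). "google/ncsnpp-celebahq-256" is assumed to be a
+# score-SDE (VE) checkpoint compatible with this pipeline; substitute your own weights if not:
+#
+#     from diffusers import ScoreSdeVePipeline
+#
+#     pipe = ScoreSdeVePipeline.from_pretrained("google/ncsnpp-celebahq-256")
+#     image = pipe(num_inference_steps=2000).images[0]
+#     image.save("sde_ve_sample.png")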
diff --git a/diffusers/pipelines/stable_diffusion/__init__.py b/diffusers/pipelines/stable_diffusion/__init__.py
new file mode 100644
index 000000000..5ffda93f1
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/__init__.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_onnx_available, is_transformers_available
+
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+ """
+ Output class for Stable Diffusion pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy arrays represent the denoised images of the diffusion pipeline.
+ nsfw_content_detected (`List[bool]`)
+ List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
+ nsfw_content_detected: List[bool]
+
+
+if is_transformers_available():
+ from .pipeline_stable_diffusion import StableDiffusionPipeline
+ from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
+ from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
+ from .safety_checker import StableDiffusionSafetyChecker
+
+if is_transformers_available() and is_onnx_available():
+ from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
new file mode 100644
index 000000000..81c57e216
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,410 @@
+# Modification of the original file by O. Teytaud for facilitating genetic stable diffusion.
+
+import inspect
+import os
+import numpy as np
+import random
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
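+ # Example (comment only): with `attention_head_dim == 8`, "auto" picks slice_size = 4,
+ # i.e. attention is computed in 8 // 4 = 2 slices; passing slice_size=2 would use 4 slices.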
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
+# def get_latent(self, image):
+# return self.vae.encode(image)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_device = "cpu" if self.device.type == "mps" else self.device
+ latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ latents_intermediate_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ speedup = 1
+# if len(os.environ["forcedlatent"]) < 5:
+# forcedlatent = np.random.randn(4*64*64).reshape(4,64,64)
+# for u in range(64):
+# for v in range(64):
+# if (u-32)**2 + (v-32)**2 > (32*1.2)**2:
+# forcedlatent[0][u][v] = 0
+# forcedlatent[1][u][v] = 0
+# forcedlatent[2][u][v] = 0
+# forcedlatent[3][u][v] = 0
+# os.environ["forcedlatent"] = str(list(forcedlatent.flatten()))
+
+ if latents is None:
+ latents = torch.randn(
+ latents_intermediate_shape,
+ generator=generator,
+ device=latents_device,
+ )
+ if len(os.environ["forcedlatent"]) > 10:
+ stri = os.environ["forcedlatent"]
+ print(f"we get a forcing for the latent z: {stri[:20]}.")
+ if len(eval(stri)) == 1:
+ stri = str(eval(stri)[0])
+ speedup = 1
+ latents = np.array(list(eval(stri))).flatten()
+ #latents = latents + np.exp(0.1 * np.random.randn()) * np.random.rand(len(latents))
+ #latents = np.sqrt(len(latents) / np.sum(latents ** 2)) * latents
+ #latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2))
+ print(f"As an array, this is {latents[:10]}")
+ print(f"immediately after loading latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}")
+ latents = torch.from_numpy(latents.reshape((1,4,64,64))).float().to(latents_device)
+ os.environ["forcedlatent"] = ""
+ good = eval(os.environ["good"])
+ bad = eval(os.environ["bad"])
+ print(f"{len(good)} good and {len(bad)} bad")
+ i_believe_in_evolution = len(good) > 0 and len(bad) > 10
+ print(f"I believe in evolution = {i_believe_in_evolution}")
+ if i_believe_in_evolution:
+ from sklearn import tree
+ from sklearn.neural_network import MLPClassifier
+ #from sklearn.neighbors import NearestCentroid
+ from sklearn.linear_model import LogisticRegression
+ #z = (np.random.randn(4*64*64))
+ z = latents.cpu().numpy().flatten()
+ if os.environ.get("skl", "tree") == "tree":
+ clf = tree.DecisionTreeClassifier()#min_samples_split=0.1)
+ elif os.environ.get("skl", "tree") == "logit":
+ clf = LogisticRegression()
+ else:
+ clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1)
+ #clf = NearestCentroid()
+
+
+
+ X=good + bad
+ Y = [1] * len(good) + [0] * len(bad)
+ clf = clf.fit(X,Y)
+ epsilon = 0.0001 # for astronauts
+ epsilon = 1.0
+
+ def loss(x):
+ return clf.predict_proba([x])[0][0] # for astronauts
+ #return clf.predict_proba([(1-epsilon)*z+epsilon*x])[0][0] # for astronauts
+ #return clf.predict_proba([z+epsilon*x])[0][0]
+
+
+ budget = int(os.environ.get("budget", "300"))
+ if i_believe_in_evolution and budget > 20:
+ import nevergrad as ng
+ #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget)
+ #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget)
+ optim_class = ng.optimizers.registry[os.environ.get("ngoptim", "DiscreteLenglerOnePlusOne")]
+ #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget)
+ nevergrad_optimizer = optim_class(len(z), budget)
+ #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), budget)
+# for k in range(5):
+# z1 = np.array(random.choice(good))
+# z2 = np.array(random.choice(good))
+# z3 = np.array(random.choice(good))
+# z4 = np.array(random.choice(good))
+# z5 = np.array(random.choice(good))
+# #z = 0.99 * z1 + 0.01 * (z2+z3+z4+z5)/4.
+# z = 0.2 * (z1 + z2 + z3 + z4 + z5)
+# mu = int(os.environ.get("mu", "5"))
+# parents = [z1, z2, z3, z4, z5]
+# weights = [np.exp(np.random.randn() - i * float(os.environ.get("decay", "1."))) for i in range(5)]
+# z = weights[0] * z1
+# for u in range(mu):
+# if u > 0:
+# z += weights[u] * parents[u]
+# z = (1. / sum(weights[:mu])) * z
+# z = np.sqrt(len(z)) * z / np.linalg.norm(z)
+#
+# #for u in range(len(z)):
+# # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]])
+# nevergrad_optimizer.suggest
+ if len(os.environ["forcedlatent"]) > 0:
+ print("we get a forcing for the latent z.")
+ z0 = eval(os.environ["forcedlatent"])
+ #nevergrad_optimizer.suggest(eval(os.environ["forcedlatent"]))
+ else:
+ z0 = z
+ for i in range(budget):
+ x = nevergrad_optimizer.ask()
+ z = z0 + float(os.environ.get("epsilon", "0.001")) * x.value
+ z = np.sqrt(len(z)) * z / np.linalg.norm(z)
+ l = loss(z)
+ nevergrad_optimizer.tell(x, l)
+ if np.log2(i+1) == int(np.log2(i+1)):
+ print(f"iteration {i} --> {l}")
+ print("var/variable = ", sum(z**2)/len(z))
+ #z = (1.-epsilon) * z + epsilon * x / np.sqrt(np.sum(x ** 2))
+ if l < 0.0000001 and os.environ.get("earlystop", "False") in ["true", "True"]:
+ print(f"we find proba(bad)={l}")
+ break
+ x = nevergrad_optimizer.recommend().value
+ z = z0 + float(os.environ.get("epsilon", "0.001")) * x
+ z = np.sqrt(len(z)) * z / np.linalg.norm(z)
+ latents = torch.from_numpy(z.reshape(latents_intermediate_shape)).float() #.half()
+ else:
+ if latents.shape != latents_intermediate_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_intermediate_shape}")
+ print(f"latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}")
+ print(f"latent ==> {torch.max(latents)}")
+ print(f"latent ==> {torch.min(latents)}")
+ os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy()))
+ for i in [2, 3]:
+ latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i]))
+ latents = latents.float().to(self.device)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps // speedup, **extra_set_kwargs)
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ #print(f"text_embeddings.shape={text_embeddings.shape}")
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ #os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy()))
+ latents = 1 / 0.18215 * latents
+ #os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy()))
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
new file mode 100644
index 000000000..475ceef4f
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -0,0 +1,291 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+def preprocess(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
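+
+# Usage sketch (illustrative): preprocess expects an RGB PIL image and returns a 1 x 3 x H x W float tensor in
+# [-1, 1], with H and W rounded down to multiples of 32. Assuming a file "input.png" exists:
+#
+#     init_image = preprocess(PIL.Image.open("input.png").convert("RGB"))
+#     assert init_image.ndim == 4 and init_image.min() >= -1.0 and init_image.max() <= 1.0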
+
+
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image to image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
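+
+ Example (a minimal sketch; assumes access to the "CompVis/stable-diffusion-v1-4" weights, a Hugging Face
+ token, and a PIL image `init_img`):
+
+ ```python
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+ result = pipe(prompt="a watercolor landscape", init_image=init_img, strength=0.6)
+ result.images[0].save("out.png")
+ ```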
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
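+
+ As a worked example of the rule above: with `attention_head_dim = 8`, the default `"auto"` picks
+ `slice_size = 4` (two slices per attention call), while passing `slice_size = 1` uses eight slices.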
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `set_attention_slice`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
+ `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter will be modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ offset = 0
+ if accepts_offset:
+ offset = 1
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ if not isinstance(init_image, torch.FloatTensor):
+ init_image = preprocess(init_image)
+
+ # encode the init image into latents and scale the latents
+ init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+
+ # expand init_latents for batch_size
+ init_latents = torch.cat([init_latents] * batch_size)
+
+ # get the original timestep using init_timestep
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
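+ # e.g. with num_inference_steps=50, strength=0.8 and offset=1: init_timestep = int(50 * 0.8) + 1 = 41,
+ # so denoising later starts at t_start = 50 - 41 + 1 = 10 and runs over the last 40 scheduler timesteps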
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ timesteps = torch.tensor(
+ [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
+ )
+ else:
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device)
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])):
+ t_index = t_start + i
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[t_index]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+ latent_model_input = latent_model_input.to(self.unet.dtype)
+ t = t.to(self.unet.dtype)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents.to(self.vae.dtype)).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
new file mode 100644
index 000000000..05ea84ae0
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -0,0 +1,309 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+import PIL
+from tqdm.auto import tqdm
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, PNDMScheduler
+from ...utils import logging
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__)
+
+
+def preprocess_image(image):
+ w, h = image.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.0 * image - 1.0
+
+
+def preprocess_mask(mask):
+ mask = mask.convert("L")
+ w, h = mask.size
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
+ mask = np.array(mask).astype(np.float32) / 255.0
+ mask = np.tile(mask, (4, 1, 1))
+ mask = mask[None].transpose(0, 1, 2, 3)  # identity permutation; only the batch dimension added by mask[None] matters
+ mask = 1 - mask # repaint white, keep black
+ mask = torch.from_numpy(mask)
+ return mask
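+
+# Usage sketch (illustrative): preprocess_mask downsamples the mask by 8 (the VAE latent stride), tiles it over
+# the 4 latent channels and inverts it, so white pixels in the input mask become 0 (region to repaint) and black
+# pixels become 1 (region to keep). Assuming a file "mask.png" exists:
+#
+#     mask = preprocess_mask(PIL.Image.open("mask.png"))  # shape: (1, 4, H // 8, W // 8)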
+
+
+class StableDiffusionInpaintPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `set_attention_slice`
+ self.enable_attention_slicing(None)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ init_image: Union[torch.FloatTensor, PIL.Image.Image],
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
+ strength: float = 0.8,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ init_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
+ process. This is the image whose masked region will be inpainted.
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
+ replaced by noise and therefore repainted, while black pixels will be preserved. The mask image will be
+ converted to a single channel (luminance) before use.
+ strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
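+
+ Example (a minimal sketch; assumes the "CompVis/stable-diffusion-v1-4" weights, a Hugging Face token, and two
+ PIL images `init_img` and `mask_img` of the same size):
+
+ ```python
+ pipe = StableDiffusionInpaintPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
+ result = pipe(prompt="a red brick wall", init_image=init_img, mask_image=mask_img, strength=0.75)
+ result.images[0].save("inpainted.png")
+ ```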
+ """
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ offset = 0
+ if accepts_offset:
+ offset = 1
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ # preprocess image
+ init_image = preprocess_image(init_image).to(self.device)
+
+ # encode the init image into latents and scale the latents
+ init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+
+ init_latents = 0.18215 * init_latents
+
+ # Expand init_latents for batch_size
+ init_latents = torch.cat([init_latents] * batch_size)
+ init_latents_orig = init_latents
+
+ # preprocess mask
+ mask = preprocess_mask(mask_image).to(self.device)
+ mask = torch.cat([mask] * batch_size)
+
+ # check sizes
+ if not mask.shape == init_latents.shape:
+ raise ValueError("The mask and init_image should be the same size!")
+
+ # get the original timestep using init_timestep
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+ timesteps = self.scheduler.timesteps[-init_timestep]
+ timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+ # add noise to latents using the timesteps
+ noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
+ init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ latents = init_latents
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # masking
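+ # re-noise the original latents to the current timestep, then keep them wherever mask == 1 (black in the
+ # input mask) and keep the freshly denoised latents wherever mask == 0 (the repainted region)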
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
new file mode 100644
index 000000000..7ff3ff22f
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py
@@ -0,0 +1,165 @@
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+
+from transformers import CLIPFeatureExtractor, CLIPTokenizer
+
+from ...onnx_utils import OnnxRuntimeModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+
+
+class StableDiffusionOnnxPipeline(DiffusionPipeline):
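+ r"""
+ Pipeline for text-to-image generation with Stable Diffusion, running the text encoder, U-Net, VAE decoder and
+ safety checker as ONNX Runtime models and passing intermediate results around as NumPy arrays.
+ """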
+ vae_decoder: OnnxRuntimeModel
+ text_encoder: OnnxRuntimeModel
+ tokenizer: CLIPTokenizer
+ unet: OnnxRuntimeModel
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+ safety_checker: OnnxRuntimeModel
+ feature_extractor: CLIPFeatureExtractor
+
+ def __init__(
+ self,
+ vae_decoder: OnnxRuntimeModel,
+ text_encoder: OnnxRuntimeModel,
+ tokenizer: CLIPTokenizer,
+ unet: OnnxRuntimeModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: OnnxRuntimeModel,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("np")
+ self.register_modules(
+ vae_decoder=vae_decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ latents: Optional[np.ndarray] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="np",
+ )
+ text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+ )
+ uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = np.concatenate([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+ latents_shape = (batch_size, 4, height // 8, width // 8)
+ if latents is None:
+ latents = np.random.randn(*latents_shape).astype(np.float32)
+ elif latents.shape != latents_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
+
+ # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
+ )
+ noise_pred = noise_pred[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ latents = 1 / 0.18215 * latents
+ image = self.vae_decoder(latent_sample=latents)[0]
+
+ image = np.clip(image / 2 + 0.5, 0, 1)
+ image = image.transpose((0, 2, 3, 1))
+
+ # run safety checker
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
+ image, has_nsfw_concept = self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/diffusers/pipelines/stable_diffusion/safety_checker.py b/diffusers/pipelines/stable_diffusion/safety_checker.py
new file mode 100644
index 000000000..3ebc05c91
--- /dev/null
+++ b/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -0,0 +1,106 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
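+# Note: despite the name, this returns pairwise cosine *similarities* between the L2-normalized image and concept
+# embeddings; larger values mean the image embedding is closer to the concept.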
+def cosine_distance(image_embeds, text_embeds):
+ normalized_image_embeds = nn.functional.normalize(image_embeds)
+ normalized_text_embeds = nn.functional.normalize(text_embeds)
+ return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+ config_class = CLIPConfig
+
+ def __init__(self, config: CLIPConfig):
+ super().__init__(config)
+
+ self.vision_model = CLIPVisionModel(config.vision_config)
+ self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+ self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+ self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+ self.register_buffer("concept_embeds_weights", torch.ones(17))
+ self.register_buffer("special_care_embeds_weights", torch.ones(3))
+
+ @torch.no_grad()
+ def forward(self, clip_input, images):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
+
+ result = []
+ batch_size = image_embeds.shape[0]
+ for i in range(batch_size):
+ result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
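+ # a score above 0 means the image embedding is closer to a flagged concept than that concept's threshold;
+ # once a "special care" concept fires, adjustment is bumped to 0.01, which effectively lowers the remaining
+ # thresholds and makes the filter stricter for this image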
+ for concept_idx in range(len(special_cos_dist[0])):
+ concept_cos = special_cos_dist[i][concept_idx]
+ concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+ result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["special_scores"][concept_idx] > 0:
+ result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+ adjustment = 0.01
+
+ for concept_idx in range(len(cos_dist[0])):
+ concept_cos = cos_dist[i][concept_idx]
+ concept_threshold = self.concept_embeds_weights[concept_idx].item()
+ result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+ if result_img["concept_scores"][concept_idx] > 0:
+ result_img["bad_concepts"].append(concept_idx)
+
+ result.append(result_img)
+
+ has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+
+ # for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+ #     if has_nsfw_concept:
+ #         images[idx] = np.zeros(images[idx].shape)  # black image
+ #
+ # if any(has_nsfw_concepts):
+ #     logger.warning(
+ #         "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+ #         " Try again with a different prompt and/or seed."
+ #     )
+
+ return images, has_nsfw_concepts
+
+ @torch.inference_mode()
+ def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
+ pooled_output = self.vision_model(clip_input)[1] # pooled_output
+ image_embeds = self.visual_projection(pooled_output)
+
+ special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
+ cos_dist = cosine_distance(image_embeds, self.concept_embeds)
+
+ # increase this value to create a stronger `nsfw` filter
+ # at the cost of increasing the possibility of filtering benign images
+ adjustment = 0.0
+
+ special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment
+ # special_scores = special_scores.round(decimals=3)
+ special_care = torch.any(special_scores > 0, dim=1)
+ special_adjustment = special_care * 0.01
+ special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1])
+
+ concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
+ # concept_scores = concept_scores.round(decimals=3)
+ has_nsfw_concepts = torch.any(concept_scores > 0, dim=1)
+
+ images[has_nsfw_concepts] = 0.0 # black image
+
+ return images, has_nsfw_concepts
diff --git a/diffusers/pipelines/stochastic_karras_ve/__init__.py b/diffusers/pipelines/stochastic_karras_ve/__init__.py
new file mode 100644
index 000000000..db2582043
--- /dev/null
+++ b/diffusers/pipelines/stochastic_karras_ve/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from .pipeline_stochastic_karras_ve import KarrasVePipeline
diff --git a/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
new file mode 100644
index 000000000..15266544d
--- /dev/null
+++ b/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import KarrasVeScheduler
+
+
+class KarrasVePipeline(DiffusionPipeline):
+ r"""
+ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
+ differential equations." https://arxiv.org/abs/2011.13456
+
+ Parameters:
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
+ scheduler ([`KarrasVeScheduler`]):
+ Scheduler for the diffusion process to be used in combination with `unet` to denoise the encoded image.
+ """
+
+ # add type hints for linting
+ unet: UNet2DModel
+ scheduler: KarrasVeScheduler
+
+ def __init__(self, unet: UNet2DModel, scheduler: KarrasVeScheduler):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(unet=unet, scheduler=scheduler)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ batch_size: int = 1,
+ num_inference_steps: int = 50,
+ generator: Optional[torch.Generator] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[Tuple, ImagePipelineOutput]:
+ r"""
+ Args:
+ batch_size (`int`, *optional*, defaults to 1):
+ The number of images to generate.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ img_size = self.unet.config.sample_size
+ shape = (batch_size, 3, img_size, img_size)
+
+ model = self.unet
+
+ # sample x_0 ~ N(0, sigma_0^2 * I)
+ sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+ sample = sample.to(self.device)
+
+ self.scheduler.set_timesteps(num_inference_steps)
+
+ for t in self.progress_bar(self.scheduler.timesteps):
+ # here sigma_t == t_i from the paper
+ sigma = self.scheduler.schedule[t]
+ sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
+
+ # 1. Select temporarily increased noise level sigma_hat
+ # 2. Add new noise to move from sample_i to sample_hat
+ sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
+
+ # 3. Predict the noise residual given the noise magnitude `sigma_hat`
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
+
+ # 4. Evaluate dx/dt at sigma_hat
+ # 5. Take Euler step from sigma to sigma_prev
+ step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
+
+ if sigma_prev != 0:
+ # 6. Apply 2nd order correction
+ # The model inputs and output are adjusted by following eq. (213) in [1].
+ model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
+ step_output = self.scheduler.step_correct(
+ model_output,
+ sigma_hat,
+ sigma_prev,
+ sample_hat,
+ step_output.prev_sample,
+ step_output["derivative"],
+ )
+ sample = step_output.prev_sample
+
+ sample = (sample / 2 + 0.5).clamp(0, 1)
+ image = sample.cpu().permute(0, 2, 3, 1).numpy()
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/diffusers/schedulers/__init__.py b/diffusers/schedulers/__init__.py
new file mode 100644
index 000000000..20c25f351
--- /dev/null
+++ b/diffusers/schedulers/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import is_scipy_available
+from .scheduling_ddim import DDIMScheduler
+from .scheduling_ddpm import DDPMScheduler
+from .scheduling_karras_ve import KarrasVeScheduler
+from .scheduling_pndm import PNDMScheduler
+from .scheduling_sde_ve import ScoreSdeVeScheduler
+from .scheduling_sde_vp import ScoreSdeVpScheduler
+from .scheduling_utils import SchedulerMixin
+
+
+if is_scipy_available():
+ from .scheduling_lms_discrete import LMSDiscreteScheduler
+else:
+ from ..utils.dummy_scipy_objects import * # noqa F403
diff --git a/diffusers/schedulers/scheduling_ddim.py b/diffusers/schedulers/scheduling_ddim.py
new file mode 100644
index 000000000..894d63bf2
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddim.py
@@ -0,0 +1,261 @@
+# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
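+
+# Usage sketch: betas_for_alpha_bar(1000) returns the 1000 betas realizing the squared-cosine
+# ("squaredcos_cap_v2") schedule used below, each capped at max_beta=0.999 so no single step destroys all signal.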
+
+
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional): TODO
+ timestep_values (`np.ndarray`, optional): TODO
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ set_alpha_to_one (`bool`, default `True`):
+ if alpha for final step is 1 or the final alpha of the "non-previous" one.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ timestep_values: Optional[np.ndarray] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ tensor_format: str = "pt",
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ offset (`int`): TODO
+ """
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(
+ 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+ )[::-1].copy()
+ self.timesteps += offset
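+ # e.g. num_train_timesteps=1000, num_inference_steps=50, offset=1 -> timesteps == [981, 961, ..., 21, 1]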
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): TODO
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+ # Notation (<variable name> -> <name in paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+ # 4. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the model_output is always re-derived from the clipped x_0 in Glide
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ device = model_output.device if torch.is_tensor(model_output) else "cpu"
+ noise = torch.randn(model_output.shape, generator=generator).to(device)
+ variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
+
+ if not torch.is_tensor(model_output):
+ variance = variance.numpy()
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
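+ # closed-form forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise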
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_ddpm.py b/diffusers/schedulers/scheduling_ddpm.py
new file mode 100644
index 000000000..4fbfb9038
--- /dev/null
+++ b/diffusers/schedulers/scheduling_ddpm.py
@@ -0,0 +1,264 @@
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
+
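+ # A small illustration (assumed here, not used elsewhere in this file) of how
+ # the cosine schedule is consumed: the cumulative product of the alphas decays
+ # smoothly from ~1 at the first timestep towards 0 at the last one.
+ #
+ # betas = betas_for_alpha_bar(1000)
+ # alphas_cumprod = np.cumprod(1.0 - betas)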
+
+class DDPMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion probabilistic models (DDPMs) explore the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional): TODO
+ variance_type (`str`):
+ options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
+ `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample between -1 and 1 for numerical stability.
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
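+ # A minimal sketch of the ancestral sampling loop this scheduler implements
+ # (assumed usage; `unet` and `latents` are placeholders, not defined here):
+ #
+ # scheduler = DDPMScheduler(num_train_timesteps=1000)
+ # scheduler.set_timesteps(1000)
+ # for t in scheduler.timesteps:
+ #     noise_pred = unet(latents, t).sample
+ #     latents = scheduler.step(noise_pred, t, latents).prev_sample
+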
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ variance_type: str = "fixed_small",
+ clip_sample: bool = True,
+ tensor_format: str = "pt",
+ ):
+
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+ self.one = np.array(1.0)
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ self.variance_type = variance_type
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(
+ 0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
+ )[::-1].copy()
+ self.set_format(tensor_format=self.tensor_format)
+
+ def _get_variance(self, t, predicted_variance=None, variance_type=None):
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+
+ if variance_type is None:
+ variance_type = self.config.variance_type
+
+ # hacks - were probs added for training stability
+ if variance_type == "fixed_small":
+ variance = self.clip(variance, min_value=1e-20)
+ # for rl-diffuser https://arxiv.org/abs/2205.09991
+ elif variance_type == "fixed_small_log":
+ variance = self.log(self.clip(variance, min_value=1e-20))
+ elif variance_type == "fixed_large":
+ variance = self.betas[t]
+ elif variance_type == "fixed_large_log":
+ # Glide max_log
+ variance = self.log(self.betas[t])
+ elif variance_type == "learned":
+ return predicted_variance
+ elif variance_type == "learned_range":
+ min_log = variance
+ max_log = self.betas[t]
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ predict_epsilon=True,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ predict_epsilon (`bool`):
+ optional flag to use when model predicts the samples directly instead of the noise, epsilon.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ t = timestep
+
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[t]
+ alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if predict_epsilon:
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ else:
+ pred_original_sample = model_output
+
+ # 3. Clip "predicted x_0"
+ if self.config.clip_sample:
+ pred_original_sample = self.clip(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
+ current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if t > 0:
+ noise = self.randn_like(model_output, generator=generator)
+ variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ if not return_dict:
+ return (pred_prev_sample,)
+
+ return SchedulerOutput(prev_sample=pred_prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
+
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_karras_ve.py b/diffusers/schedulers/scheduling_karras_ve.py
new file mode 100644
index 000000000..3a2370cfc
--- /dev/null
+++ b/diffusers/schedulers/scheduling_karras_ve.py
@@ -0,0 +1,208 @@
+# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class KarrasVeOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Derivative of predicted original image sample (x_0).
+ """
+
+ prev_sample: torch.FloatTensor
+ derivative: torch.FloatTensor
+
+
+class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Stochastic sampling from Karras et al. [1] tailored to Variance Exploding (VE) models [2]. Use Algorithm 2 and
+ the VE column of Table 1 from [1] for reference.
+
+ [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
+ https://arxiv.org/abs/2206.00364
+ [2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations."
+ https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
+ Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
+ optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
+
+ Args:
+ sigma_min (`float`): minimum noise magnitude
+ sigma_max (`float`): maximum noise magnitude
+ s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
+ A reasonable range is [1.000, 1.011].
+ s_churn (`float`): the parameter controlling the overall amount of stochasticity.
+ A reasonable range is [0, 100].
+ s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
+ A reasonable range is [0, 10].
+ s_max (`float`): the end value of the sigma range where we add noise.
+ A reasonable range is [0.2, 80].
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
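+ # A rough sketch of how Algorithm 2 is driven through this scheduler (an
+ # assumed usage pattern; `model_fn(x, sigma)` stands in for a denoiser that
+ # returns the model output at noise level sigma and is not defined here):
+ #
+ # scheduler.set_timesteps(num_inference_steps)
+ # for t in scheduler.timesteps:
+ #     sigma = scheduler.schedule[t]
+ #     sigma_prev = scheduler.schedule[t - 1] if t > 0 else 0.0
+ #     sample_hat, sigma_hat = scheduler.add_noise_to_input(sample, sigma)
+ #     out = scheduler.step(model_fn(sample_hat, sigma_hat), sigma_hat, sigma_prev, sample_hat)
+ #     sample = out.prev_sample
+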
+ @register_to_config
+ def __init__(
+ self,
+ sigma_min: float = 0.02,
+ sigma_max: float = 100,
+ s_noise: float = 1.007,
+ s_churn: float = 80,
+ s_min: float = 0.05,
+ s_max: float = 50,
+ tensor_format: str = "pt",
+ ):
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = None
+ self.schedule = None # sigma(t_i)
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+
+ """
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+ self.schedule = [
+ (self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
+ for i in self.timesteps
+ ]
+ self.schedule = np.array(self.schedule, dtype=np.float32)
+
+ self.set_format(tensor_format=self.tensor_format)
+
+ def add_noise_to_input(
+ self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
+ ) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
+ """
+ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
+ higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
+
+ TODO Args:
+ """
+ if self.s_min <= sigma <= self.s_max:
+ gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
+ else:
+ gamma = 0
+
+ # sample eps ~ N(0, S_noise^2 * I)
+ eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
+ sigma_hat = sigma + gamma * sigma
+ sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
+
+ return sample_hat, sigma_hat
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+ return_dict (`bool`): option for returning tuple rather than KarrasVeOutput class
+
+ Returns:
+ [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] or `tuple`:
+ [`~schedulers.scheduling_karras_ve.KarrasVeOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+
+ pred_original_sample = sample_hat + sigma_hat * model_output
+ derivative = (sample_hat - pred_original_sample) / sigma_hat
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+
+ def step_correct(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sigma_hat: float,
+ sigma_prev: float,
+ sample_hat: Union[torch.FloatTensor, np.ndarray],
+ sample_prev: Union[torch.FloatTensor, np.ndarray],
+ derivative: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[KarrasVeOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. TODO complete description
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sigma_hat (`float`): TODO
+ sigma_prev (`float`): TODO
+ sample_hat (`torch.FloatTensor` or `np.ndarray`): TODO
+ sample_prev (`torch.FloatTensor` or `np.ndarray`): TODO
+ derivative (`torch.FloatTensor` or `np.ndarray`): TODO
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ prev_sample (TODO): updated sample in the diffusion chain. derivative (TODO): TODO
+
+ """
+ pred_original_sample = sample_prev + sigma_prev * model_output
+ derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
+ sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
+
+ if not return_dict:
+ return (sample_prev, derivative)
+
+ return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
+
+ def add_noise(self, original_samples, noise, timesteps):
+ raise NotImplementedError()
diff --git a/diffusers/schedulers/scheduling_lms_discrete.py b/diffusers/schedulers/scheduling_lms_discrete.py
new file mode 100644
index 000000000..1381587fe
--- /dev/null
+++ b/diffusers/schedulers/scheduling_lms_discrete.py
@@ -0,0 +1,193 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
+ Katherine Crowson:
+ https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear` or `scaled_linear`.
+ trained_betas (`np.ndarray`, optional): TODO
+ timestep_values (`np.ndarray`, optional): TODO
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+ """
+
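+ # A minimal sketch of a sampling loop with this scheduler (assumed usage; note
+ # that `step` below is indexed by position in `self.sigmas` rather than by the
+ # training timestep, and `unet`/`latents` are placeholders):
+ #
+ # scheduler.set_timesteps(50)
+ # latents = latents * scheduler.sigmas[0]  # scale the initial noise
+ # for i, t in enumerate(scheduler.timesteps):
+ #     model_input = latents / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
+ #     noise_pred = unet(model_input, t).sample
+ #     latents = scheduler.step(noise_pred, i, latents).prev_sample
+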
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ timestep_values: Optional[np.ndarray] = None,
+ tensor_format: str = "pt",
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self.derivatives = []
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def get_lms_coefficient(self, order, t, current_order):
+ """
+ Compute a linear multistep coefficient.
+
+ Args:
+ order (TODO):
+ t (TODO):
+ current_order (TODO):
+ """
+
+ def lms_derivative(tau):
+ prod = 1.0
+ for k in range(order):
+ if current_order == k:
+ continue
+ prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
+ return prod
+
+ integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]
+
+ return integrated_coeff
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ self.num_inference_steps = num_inference_steps
+ self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
+
+ low_idx = np.floor(self.timesteps).astype(int)
+ high_idx = np.ceil(self.timesteps).astype(int)
+ frac = np.mod(self.timesteps, 1.0)
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+ sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+ self.sigmas = np.concatenate([sigmas, [0.0]])
+
+ self.derivatives = []
+
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ order: int = 4,
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ order: coefficient for multi-step inference.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ sigma = self.sigmas[timestep]
+
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+ pred_original_sample = sample - sigma * model_output
+
+ # 2. Convert to an ODE derivative
+ derivative = (sample - pred_original_sample) / sigma
+ self.derivatives.append(derivative)
+ if len(self.derivatives) > order:
+ self.derivatives.pop(0)
+
+ # 3. Compute linear multistep coefficients
+ order = min(timestep + 1, order)
+ lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+
+ # 4. Compute previous sample based on the derivatives path
+ prev_sample = sample + sum(
+ coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
+ )
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> Union[torch.FloatTensor, np.ndarray]:
+ sigmas = self.match_shape(self.sigmas[timesteps], noise)
+ noisy_samples = original_samples + noise * sigmas
+
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_pndm.py b/diffusers/schedulers/scheduling_pndm.py
new file mode 100644
index 000000000..b43d88bba
--- /dev/null
+++ b/diffusers/schedulers/scheduling_pndm.py
@@ -0,0 +1,378 @@
+# Copyright 2022 Zhejiang University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas, dtype=np.float32)
+
+
+class PNDMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE integration techniques,
+ namely the Runge-Kutta method and a linear multi-step method.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional): TODO
+ tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
+ skip_prk_steps (`bool`):
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+ before plms steps; defaults to `False`.
+
+ """
+
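+ # A minimal sketch of a PNDM sampling loop (assumed usage; `unet` and `latents`
+ # are placeholders). With `skip_prk_steps=True` every call to `step` dispatches
+ # to `step_plms`:
+ #
+ # scheduler.set_timesteps(50)
+ # for t in scheduler.timesteps:
+ #     noise_pred = unet(latents, t).sample
+ #     latents = scheduler.step(noise_pred, t, latents).prev_sample
+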
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[np.ndarray] = None,
+ tensor_format: str = "pt",
+ skip_prk_steps: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = np.asarray(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+ self.one = np.array(1.0)
+
+ # For now we only support F-PNDM, i.e. the runge-kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formula (9), (12), (13) and the Algorithm 2.
+ self.pndm_order = 4
+
+ # running values
+ self.cur_model_output = 0
+ self.counter = 0
+ self.cur_sample = None
+ self.ets = []
+
+ # setable values
+ self.num_inference_steps = None
+ self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+ self._offset = 0
+ self.prk_timesteps = None
+ self.plms_timesteps = None
+ self.timesteps = None
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ offset (`int`): TODO
+ """
+ self.num_inference_steps = num_inference_steps
+ self._timesteps = list(
+ range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+ )
+ self._offset = offset
+ self._timesteps = np.array([t + self._offset for t in self._timesteps])
+
+ if self.config.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ self.prk_timesteps = np.array([])
+ self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[
+ ::-1
+ ].copy()
+ else:
+ prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+ np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+ )
+ self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+ self.plms_timesteps = self._timesteps[:-3][
+ ::-1
+ ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+ self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+
+ self.ets = []
+ self.counter = 0
+ self.set_format(tensor_format=self.tensor_format)
+
+ def step(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
+ return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+ else:
+ return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
+
+ def step_prk(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
+ prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
+ timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+ if self.counter % 4 == 0:
+ self.cur_model_output += 1 / 6 * model_output
+ self.ets.append(model_output)
+ self.cur_sample = sample
+ elif (self.counter - 1) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 2) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 3) % 4 == 0:
+ model_output = self.cur_model_output + 1 / 6 * model_output
+ self.cur_model_output = 0
+
+ # cur_sample should not be `None`
+ cur_sample = self.cur_sample if self.cur_sample is not None else sample
+
+ prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def step_plms(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ return_dict: bool = True,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Step function propagating the sample with the linear multi-step (PLMS) method. A single new forward pass is
+ combined with the stored model outputs from previous steps to approximate the solution.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if not self.config.skip_prk_steps and len(self.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py "
+ "for more information."
+ )
+
+ prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
+
+ if self.counter != 1:
+ self.ets.append(model_output)
+ else:
+ prev_timestep = timestep
+ timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps
+
+ if len(self.ets) == 1 and self.counter == 0:
+ model_output = model_output
+ self.cur_sample = sample
+ elif len(self.ets) == 1 and self.counter == 1:
+ model_output = (model_output + self.ets[-1]) / 2
+ sample = self.cur_sample
+ self.cur_sample = None
+ elif len(self.ets) == 2:
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using the formula of (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+ # Notation (<variable name> -> <name in paper>)
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
+ alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # corresponds to (α_(t−δ) - α_t) divided by
+ # denominator of x_t in formula (9) and plus 1
+ # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
+ # sqrt(α_(t−δ)) / sqrt(α_t)
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # full formula (9)
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
+
+ def add_noise(
+ self,
+ original_samples: Union[torch.FloatTensor, np.ndarray],
+ noise: Union[torch.FloatTensor, np.ndarray],
+ timesteps: Union[torch.IntTensor, np.ndarray],
+ ) -> torch.Tensor:
+ # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+ timesteps = timesteps.to(self.alphas_cumprod.device)
+ sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
+ sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_sde_ve.py b/diffusers/schedulers/scheduling_sde_ve.py
new file mode 100644
index 000000000..e187f0796
--- /dev/null
+++ b/diffusers/schedulers/scheduling_sde_ve.py
@@ -0,0 +1,283 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+import warnings
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+@dataclass
+class SdeVeOutput(BaseOutput):
+ """
+ Output class for the ScoreSdeVeScheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+ """
+
+ prev_sample: torch.FloatTensor
+ prev_sample_mean: torch.FloatTensor
+
+
+class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance exploding stochastic differential equation (SDE) scheduler.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ Args:
+ snr (`float`):
+ coefficient weighting the step from the model_output sample (from the network) to the random noise.
+ sigma_min (`float`):
+ initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+ distribution of the data.
+ sigma_max (`float`): maximum value used for the range of continuous timesteps passed into the model.
+ sampling_eps (`float`): the end value of sampling, where timesteps decrease progressively from 1 to
+ epsilon.
+ correct_steps (`int`): number of correction steps performed on a produced sample.
+ tensor_format (`str`): "np" or "pt" for the expected format of samples passed to the Scheduler.
+ """
+
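+ # A rough sketch of predictor-corrector sampling with this scheduler (assumed
+ # usage; `score_model(x, sigma)` is a placeholder for the score network):
+ #
+ # scheduler.set_timesteps(num_inference_steps)
+ # scheduler.set_sigmas(num_inference_steps)
+ # for i, t in enumerate(scheduler.timesteps):
+ #     sigma_t = scheduler.sigmas[i]
+ #     for _ in range(scheduler.config.correct_steps):
+ #         sample = scheduler.step_correct(score_model(sample, sigma_t), sample).prev_sample
+ #     out = scheduler.step_pred(score_model(sample, sigma_t), t, sample)
+ #     sample, sample_mean = out.prev_sample, out.prev_sample_mean
+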
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 2000,
+ snr: float = 0.15,
+ sigma_min: float = 0.01,
+ sigma_max: float = 1348.0,
+ sampling_eps: float = 1e-5,
+ correct_steps: int = 1,
+ tensor_format: str = "pt",
+ ):
+ # setable values
+ self.timesteps = None
+
+ self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+
+ self.tensor_format = tensor_format
+ self.set_format(tensor_format=tensor_format)
+
+ def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+ """
+ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
+ elif tensor_format == "pt":
+ self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def set_sigmas(
+ self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
+ ):
+ """
+ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
+
+ The sigmas control the weight of the `drift` and `diffusion` components of sample update.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ sigma_min (`float`, optional):
+ initial noise scale value (overrides value given at Scheduler instantiation).
+ sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
+ sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
+
+ """
+ sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
+ sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
+ sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
+ if self.timesteps is None:
+ self.set_timesteps(num_inference_steps, sampling_eps)
+
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
+ self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
+ elif tensor_format == "pt":
+ self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
+ self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def get_adjacent_sigma(self, timesteps, t):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
+ elif tensor_format == "pt":
+ return torch.where(
+ timesteps == 0,
+ torch.zeros_like(t.to(timesteps.device)),
+ self.discrete_sigmas[timesteps - 1].to(timesteps.device),
+ )
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def set_seed(self, seed):
+ warnings.warn(
+ "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
+ " generator instead.",
+ DeprecationWarning,
+ )
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ np.random.seed(seed)
+ elif tensor_format == "pt":
+ torch.manual_seed(seed)
+ else:
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def step_pred(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ timestep: int,
+ sample: Union[torch.FloatTensor, np.ndarray],
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[SdeVeOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if "seed" in kwargs and kwargs["seed"] is not None:
+ self.set_seed(kwargs["seed"])
+
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ timestep = timestep * torch.ones(
+ sample.shape[0], device=sample.device
+ ) # torch.repeat_interleave(timestep, sample.shape[0])
+ timesteps = (timestep * (len(self.timesteps) - 1)).long()
+
+ # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+ timesteps = timesteps.to(self.discrete_sigmas.device)
+
+ sigma = self.discrete_sigmas[timesteps].to(sample.device)
+ adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
+ drift = self.zeros_like(sample)
+ diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+ # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
+ # also equation 47 shows the analog from SDE models to ancestral sampling methods
+ drift = drift - diffusion[:, None, None, None] ** 2 * model_output
+
+ # equation 6: sample noise for the diffusion term
+ noise = self.randn_like(sample, generator=generator)
+ prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
+ # TODO is the variable diffusion the correct scaling term for the noise?
+ prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
+
+ if not return_dict:
+ return (prev_sample, prev_sample_mean)
+
+ return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
+
+ def step_correct(
+ self,
+ model_output: Union[torch.FloatTensor, np.ndarray],
+ sample: Union[torch.FloatTensor, np.ndarray],
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ **kwargs,
+ ) -> Union[SchedulerOutput, Tuple]:
+ """
+ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
+ after making the prediction for the previous timestep.
+
+ Args:
+ model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+ sample (`torch.FloatTensor` or `np.ndarray`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_sde_ve.SdeVeOutput`] or `tuple`: [`~schedulers.scheduling_sde_ve.SdeVeOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ if "seed" in kwargs and kwargs["seed"] is not None:
+ self.set_seed(kwargs["seed"])
+
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+ # sample noise for correction
+ noise = self.randn_like(sample, generator=generator)
+
+ # compute step size from the model_output, the noise, and the snr
+ grad_norm = self.norm(model_output)
+ noise_norm = self.norm(noise)
+ step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
+ step_size = step_size * torch.ones(sample.shape[0]).to(sample.device)
+ # self.repeat_scalar(step_size, sample.shape[0])
+
+ # compute corrected sample: model_output term and noise term
+ prev_sample_mean = sample + step_size[:, None, None, None] * model_output
+ prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_sde_vp.py b/diffusers/schedulers/scheduling_sde_vp.py
new file mode 100644
index 000000000..66e6ec661
--- /dev/null
+++ b/diffusers/schedulers/scheduling_sde_vp.py
@@ -0,0 +1,81 @@
+# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
+
+ # TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin
+
+
+class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
+ """
+ The variance preserving stochastic differential equation (SDE) scheduler.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+ [`~ConfigMixin.from_config`] functions.
+
+ For more information, see the original paper: https://arxiv.org/abs/2011.13456
+
+ UNDER CONSTRUCTION
+
+ """
+
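+ # A rough sketch of the intended usage (assumed; the class is still under
+ # construction and `score_fn` is a placeholder):
+ #
+ # scheduler.set_timesteps(num_inference_steps)
+ # for t in scheduler.timesteps:
+ #     t_batch = t * torch.ones(x.shape[0])
+ #     x, x_mean = scheduler.step_pred(score_fn(x, t_batch), x, t_batch)
+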
+ @register_to_config
+ def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
+
+ self.sigmas = None
+ self.discrete_sigmas = None
+ self.timesteps = None
+
+ def set_timesteps(self, num_inference_steps):
+ self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+
+ def step_pred(self, score, x, t):
+ if self.timesteps is None:
+ raise ValueError(
+ "`self.timesteps` is not set, you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # TODO(Patrick) better comments + non-PyTorch
+ # postprocess model score
+ log_mean_coeff = (
+ -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
+ )
+ std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
+ score = -score / std[:, None, None, None]
+
+ # compute
+ dt = -1.0 / len(self.timesteps)
+
+ beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
+ drift = -0.5 * beta_t[:, None, None, None] * x
+ diffusion = torch.sqrt(beta_t)
+ drift = drift - diffusion[:, None, None, None] ** 2 * score
+ x_mean = x + drift * dt
+
+ # add noise
+ noise = torch.randn_like(x)
+ x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
+
+ return x, x_mean
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/diffusers/schedulers/scheduling_utils.py b/diffusers/schedulers/scheduling_utils.py
new file mode 100644
index 000000000..f2bcd73ac
--- /dev/null
+++ b/diffusers/schedulers/scheduling_utils.py
@@ -0,0 +1,125 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Union
+
+import numpy as np
+import torch
+
+from ..utils import BaseOutput
+
+
+SCHEDULER_CONFIG_NAME = "scheduler_config.json"
+
+
+@dataclass
+class SchedulerOutput(BaseOutput):
+ """
+ Base class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class SchedulerMixin:
+ """
+ Mixin containing common functions for the schedulers.
+ """
+
+ config_name = SCHEDULER_CONFIG_NAME
+ ignore_for_config = ["tensor_format"]
+
+ def set_format(self, tensor_format="pt"):
+ self.tensor_format = tensor_format
+ if tensor_format == "pt":
+ for key, value in vars(self).items():
+ if isinstance(value, np.ndarray):
+ setattr(self, key, torch.from_numpy(value))
+
+ return self
+
+ def clip(self, tensor, min_value=None, max_value=None):
+ tensor_format = getattr(self, "tensor_format", "pt")
+
+ if tensor_format == "np":
+ return np.clip(tensor, min_value, max_value)
+ elif tensor_format == "pt":
+ return torch.clamp(tensor, min_value, max_value)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def log(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+
+ if tensor_format == "np":
+ return np.log(tensor)
+ elif tensor_format == "pt":
+ return torch.log(tensor)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]):
+ """
+ Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
+
+ Args:
+ values: an array or tensor of values to extract.
+ broadcast_array: an array with a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ Returns:
+ a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+
+ tensor_format = getattr(self, "tensor_format", "pt")
+ values = values.flatten()
+
+ while len(values.shape) < len(broadcast_array.shape):
+ values = values[..., None]
+ if tensor_format == "pt":
+ values = values.to(broadcast_array.device)
+
+ return values
+
+ def norm(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.linalg.norm(tensor)
+ elif tensor_format == "pt":
+ return torch.norm(tensor.reshape(tensor.shape[0], -1), dim=-1).mean()
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def randn_like(self, tensor, generator=None):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.random.randn(*np.shape(tensor))
+ elif tensor_format == "pt":
+ # return torch.randn_like(tensor)
+ return torch.randn(tensor.shape, layout=tensor.layout, generator=generator).to(tensor.device)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+ def zeros_like(self, tensor):
+ tensor_format = getattr(self, "tensor_format", "pt")
+ if tensor_format == "np":
+ return np.zeros_like(tensor)
+ elif tensor_format == "pt":
+ return torch.zeros_like(tensor)
+
+ raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
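+
+# Illustrative example of `match_shape` (assuming `scheduler` is any SchedulerMixin subclass in
+# "pt" format): per-timestep scalars are reshaped so they broadcast over image-shaped tensors.
+#
+#     alphas = torch.tensor([0.1, 0.2])               # shape (2,)
+#     sample = torch.zeros(2, 3, 64, 64)
+#     scheduler.match_shape(alphas, sample).shape     # -> torch.Size([2, 1, 1, 1])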
diff --git a/diffusers/testing_utils.py b/diffusers/testing_utils.py
new file mode 100644
index 000000000..ff8b6aa9b
--- /dev/null
+++ b/diffusers/testing_utils.py
@@ -0,0 +1,61 @@
+import os
+import random
+import unittest
+from distutils.util import strtobool
+
+import torch
+
+from packaging import version
+
+
+global_rng = random.Random()
+torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
+
+if is_torch_higher_equal_than_1_12:
+ torch_device = "mps" if torch.backends.mps.is_available() else torch_device
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = strtobool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+ """Creates a random float32 tensor"""
+ if rng is None:
+ rng = global_rng
+
+ total_dims = 1
+ for dim in shape:
+ total_dims *= dim
+
+ values = []
+ for _ in range(total_dims):
+ values.append(rng.random() * scale)
+
+ return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
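+
+# Illustrative usage: decorate a test so it only runs when RUN_SLOW is set to a truthy value,
+# e.g. `RUN_SLOW=1 python -m pytest tests/` (the test name below is hypothetical).
+#
+#     @slow
+#     def test_full_inference(self):
+#         ...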
diff --git a/diffusers/training_utils.py b/diffusers/training_utils.py
new file mode 100644
index 000000000..fa1694161
--- /dev/null
+++ b/diffusers/training_utils.py
@@ -0,0 +1,125 @@
+import copy
+import os
+import random
+
+import numpy as np
+import torch
+
+
+def enable_full_determinism(seed: int):
+ """
+ Helper function for reproducible behavior during distributed training. See
+ https://pytorch.org/docs/stable/notes/randomness.html for details on PyTorch determinism.
+ """
+ # set seed first
+ set_seed(seed)
+
+ # Enable PyTorch deterministic mode. This potentially requires either the environment
+ # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+ # depending on the CUDA version, so we set them both here
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+ torch.use_deterministic_algorithms(True)
+
+ # Enable CUDNN deterministic mode
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def set_seed(seed: int):
+ """
+ Helper function for reproducible behavior: sets the seed in `random`, `numpy` and `torch`.
+
+ Args:
+ seed (`int`): The seed to set.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ # ^^ safe to call this function even if cuda is not available
+
+
+class EMAModel:
+ """
+ Exponential Moving Average of models weights
+ """
+
+ def __init__(
+ self,
+ model,
+ update_after_step=0,
+ inv_gamma=1.0,
+ power=2 / 3,
+ min_value=0.0,
+ max_value=0.9999,
+ device=None,
+ ):
+ """
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ Args:
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
+ min_value (float): The minimum EMA decay rate. Default: 0.
+ """
+
+ self.averaged_model = copy.deepcopy(model).eval()
+ self.averaged_model.requires_grad_(False)
+
+ self.update_after_step = update_after_step
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.min_value = min_value
+ self.max_value = max_value
+
+ if device is not None:
+ self.averaged_model = self.averaged_model.to(device=device)
+
+ self.decay = 0.0
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+ value = 1 - (1 + step / self.inv_gamma) ** -self.power
+
+ if step <= 0:
+ return 0.0
+
+ return max(self.min_value, min(value, self.max_value))
+
+ @torch.no_grad()
+ def step(self, new_model):
+ ema_state_dict = {}
+ ema_params = self.averaged_model.state_dict()
+
+ self.decay = self.get_decay(self.optimization_step)
+
+ for key, param in new_model.named_parameters():
+ if isinstance(param, dict):
+ continue
+ try:
+ ema_param = ema_params[key]
+ except KeyError:
+ ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
+ ema_params[key] = ema_param
+
+ if not param.requires_grad:
+ ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
+ ema_param = ema_params[key]
+ else:
+ ema_param.mul_(self.decay)
+ ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
+
+ ema_state_dict[key] = ema_param
+
+ for key, param in new_model.named_buffers():
+ ema_state_dict[key] = param
+
+ self.averaged_model.load_state_dict(ema_state_dict, strict=False)
+ self.optimization_step += 1
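+
+# Minimal training-loop sketch (illustrative; `model`, `optimizer` and `dataloader` are assumed
+# to exist and are not defined here):
+#
+#     ema = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+#     for batch in dataloader:
+#         loss = model(batch)
+#         loss.backward()
+#         optimizer.step()
+#         optimizer.zero_grad()
+#         ema.step(model)            # refresh the averaged copy after each optimizer step
+#     # ema.averaged_model then holds the exponentially averaged weights for evaluation.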
diff --git a/diffusers/utils/__init__.py b/diffusers/utils/__init__.py
new file mode 100644
index 000000000..c00a28e10
--- /dev/null
+++ b/diffusers/utils/__init__.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+
+from .import_utils import (
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
+ ENV_VARS_TRUE_VALUES,
+ USE_JAX,
+ USE_TF,
+ USE_TORCH,
+ DummyObject,
+ is_flax_available,
+ is_inflect_available,
+ is_modelcards_available,
+ is_onnx_available,
+ is_scipy_available,
+ is_tf_available,
+ is_torch_available,
+ is_transformers_available,
+ is_unidecode_available,
+ requires_backends,
+)
+from .logging import get_logger
+from .outputs import BaseOutput
+
+
+logger = get_logger(__name__)
+
+
+hf_cache_home = os.path.expanduser(
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
+)
+default_cache_path = os.path.join(hf_cache_home, "diffusers")
+
+
+CONFIG_NAME = "config.json"
+HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
+DIFFUSERS_CACHE = default_cache_path
+DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
diff --git a/diffusers/utils/dummy_scipy_objects.py b/diffusers/utils/dummy_scipy_objects.py
new file mode 100644
index 000000000..3706c5754
--- /dev/null
+++ b/diffusers/utils/dummy_scipy_objects.py
@@ -0,0 +1,11 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class LMSDiscreteScheduler(metaclass=DummyObject):
+ _backends = ["scipy"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["scipy"])
diff --git a/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
new file mode 100644
index 000000000..8c2aec218
--- /dev/null
+++ b/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py
@@ -0,0 +1,10 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+from ..utils import DummyObject, requires_backends
+
+
+class GradTTSPipeline(metaclass=DummyObject):
+ _backends = ["transformers", "inflect", "unidecode"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers", "inflect", "unidecode"])
diff --git a/diffusers/utils/dummy_transformers_and_onnx_objects.py b/diffusers/utils/dummy_transformers_and_onnx_objects.py
new file mode 100644
index 000000000..2e34b5ce0
--- /dev/null
+++ b/diffusers/utils/dummy_transformers_and_onnx_objects.py
@@ -0,0 +1,11 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class StableDiffusionOnnxPipeline(metaclass=DummyObject):
+ _backends = ["transformers", "onnx"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers", "onnx"])
diff --git a/diffusers/utils/dummy_transformers_objects.py b/diffusers/utils/dummy_transformers_objects.py
new file mode 100644
index 000000000..e05eb814d
--- /dev/null
+++ b/diffusers/utils/dummy_transformers_objects.py
@@ -0,0 +1,32 @@
+# This file is autogenerated by the command `make fix-copies`, do not edit.
+# flake8: noqa
+
+from ..utils import DummyObject, requires_backends
+
+
+class LDMTextToImagePipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionImg2ImgPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionInpaintPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
+
+
+class StableDiffusionPipeline(metaclass=DummyObject):
+ _backends = ["transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["transformers"])
diff --git a/diffusers/utils/import_utils.py b/diffusers/utils/import_utils.py
new file mode 100644
index 000000000..1f5e95ada
--- /dev/null
+++ b/diffusers/utils/import_utils.py
@@ -0,0 +1,274 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Import utilities: Utilities related to imports and our lazy inits.
+"""
+import importlib.util
+import os
+import sys
+from collections import OrderedDict
+
+from packaging import version
+
+from . import logging
+
+
+# The package importlib_metadata is in a different place, depending on the python version.
+if sys.version_info < (3, 8):
+ import importlib_metadata
+else:
+ import importlib.metadata as importlib_metadata
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
+ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"})
+
+USE_TF = os.environ.get("USE_TF", "AUTO").upper()
+USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()
+USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper()
+
+_torch_version = "N/A"
+if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES:
+ _torch_available = importlib.util.find_spec("torch") is not None
+ if _torch_available:
+ try:
+ _torch_version = importlib_metadata.version("torch")
+ logger.info(f"PyTorch version {_torch_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _torch_available = False
+else:
+ logger.info("Disabling PyTorch because USE_TF is set")
+ _torch_available = False
+
+
+_tf_version = "N/A"
+if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES:
+ _tf_available = importlib.util.find_spec("tensorflow") is not None
+ if _tf_available:
+ candidates = (
+ "tensorflow",
+ "tensorflow-cpu",
+ "tensorflow-gpu",
+ "tf-nightly",
+ "tf-nightly-cpu",
+ "tf-nightly-gpu",
+ "intel-tensorflow",
+ "intel-tensorflow-avx512",
+ "tensorflow-rocm",
+ "tensorflow-macos",
+ "tensorflow-aarch64",
+ )
+ _tf_version = None
+ # For the metadata, we have to look for both tensorflow and tensorflow-cpu
+ for pkg in candidates:
+ try:
+ _tf_version = importlib_metadata.version(pkg)
+ break
+ except importlib_metadata.PackageNotFoundError:
+ pass
+ _tf_available = _tf_version is not None
+ if _tf_available:
+ if version.parse(_tf_version) < version.parse("2"):
+ logger.info(f"TensorFlow found but with version {_tf_version}. Diffusers requires version 2 minimum.")
+ _tf_available = False
+ else:
+ logger.info(f"TensorFlow version {_tf_version} available.")
+else:
+ logger.info("Disabling Tensorflow because USE_TORCH is set")
+ _tf_available = False
+
+
+if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
+ _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
+ if _flax_available:
+ try:
+ _jax_version = importlib_metadata.version("jax")
+ _flax_version = importlib_metadata.version("flax")
+ logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.")
+ except importlib_metadata.PackageNotFoundError:
+ _flax_available = False
+else:
+ _flax_available = False
+
+
+_transformers_available = importlib.util.find_spec("transformers") is not None
+try:
+ _transformers_version = importlib_metadata.version("transformers")
+ logger.debug(f"Successfully imported transformers version {_transformers_version}")
+except importlib_metadata.PackageNotFoundError:
+ _transformers_available = False
+
+
+_inflect_available = importlib.util.find_spec("inflect") is not None
+try:
+ _inflect_version = importlib_metadata.version("inflect")
+ logger.debug(f"Successfully imported inflect version {_inflect_version}")
+except importlib_metadata.PackageNotFoundError:
+ _inflect_available = False
+
+
+_unidecode_available = importlib.util.find_spec("unidecode") is not None
+try:
+ _unidecode_version = importlib_metadata.version("unidecode")
+ logger.debug(f"Successfully imported unidecode version {_unidecode_version}")
+except importlib_metadata.PackageNotFoundError:
+ _unidecode_available = False
+
+
+_modelcards_available = importlib.util.find_spec("modelcards") is not None
+try:
+ _modelcards_version = importlib_metadata.version("modelcards")
+ logger.debug(f"Successfully imported modelcards version {_modelcards_version}")
+except importlib_metadata.PackageNotFoundError:
+ _modelcards_available = False
+
+
+_onnx_available = importlib.util.find_spec("onnxruntime") is not None
+try:
+ _onnxruntime_version = importlib_metadata.version("onnxruntime")
+ logger.debug(f"Successfully imported onnxruntime version {_onnxruntime_version}")
+except importlib_metadata.PackageNotFoundError:
+ _onnx_available = False
+
+
+_scipy_available = importlib.util.find_spec("scipy") is not None
+try:
+ _scipy_version = importlib_metadata.version("scipy")
+ logger.debug(f"Successfully imported transformers version {_scipy_version}")
+except importlib_metadata.PackageNotFoundError:
+ _scipy_available = False
+
+
+def is_torch_available():
+ return _torch_available
+
+
+def is_tf_available():
+ return _tf_available
+
+
+def is_flax_available():
+ return _flax_available
+
+
+def is_transformers_available():
+ return _transformers_available
+
+
+def is_inflect_available():
+ return _inflect_available
+
+
+def is_unidecode_available():
+ return _unidecode_available
+
+
+def is_modelcards_available():
+ return _modelcards_available
+
+
+def is_onnx_available():
+ return _onnx_available
+
+
+def is_scipy_available():
+ return _scipy_available
+
+
+# docstyle-ignore
+FLAX_IMPORT_ERROR = """
+{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
+installation page: https://github.com/google/flax and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+INFLECT_IMPORT_ERROR = """
+{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install
+inflect`
+"""
+
+# docstyle-ignore
+PYTORCH_IMPORT_ERROR = """
+{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
+installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+ONNX_IMPORT_ERROR = """
+{0} requires the onnxruntime library but it was not found in your environment. You can install it with pip: `pip
+install onnxruntime`
+"""
+
+# docstyle-ignore
+SCIPY_IMPORT_ERROR = """
+{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
+scipy`
+"""
+
+# docstyle-ignore
+TENSORFLOW_IMPORT_ERROR = """
+{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the
+installation page: https://www.tensorflow.org/install and follow the ones that match your environment.
+"""
+
+# docstyle-ignore
+TRANSFORMERS_IMPORT_ERROR = """
+{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip
+install transformers`
+"""
+
+# docstyle-ignore
+UNIDECODE_IMPORT_ERROR = """
+{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install
+Unidecode`
+"""
+
+
+BACKENDS_MAPPING = OrderedDict(
+ [
+ ("flax", (is_flax_available, FLAX_IMPORT_ERROR)),
+ ("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
+ ("onnx", (is_onnx_available, ONNX_IMPORT_ERROR)),
+ ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
+ ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)),
+ ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)),
+ ("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
+ ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
+ ]
+)
+
+
+def requires_backends(obj, backends):
+ if not isinstance(backends, (list, tuple)):
+ backends = [backends]
+
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
+ failed = [msg.format(name) for available, msg in checks if not available()]
+ if failed:
+ raise ImportError("".join(failed))
+
+
+class DummyObject(type):
+ """
+ Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
+ `requires_backend` each time a user tries to access any method of that class.
+ """
+
+ def __getattr__(cls, key):
+ if key.startswith("_"):
+ return super().__getattr__(cls, key)
+ requires_backends(cls, cls._backends)
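+
+# Illustrative behavior: the `dummy_*_objects.py` modules in this package define placeholders via
+# this metaclass, so touching them without the backend installed raises a helpful ImportError.
+#
+#     class LMSDiscreteScheduler(metaclass=DummyObject):
+#         _backends = ["scipy"]
+#
+#     LMSDiscreteScheduler()     # raises ImportError with SCIPY_IMPORT_ERROR if scipy is missing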
diff --git a/diffusers/utils/logging.py b/diffusers/utils/logging.py
new file mode 100644
index 000000000..1f2d0227b
--- /dev/null
+++ b/diffusers/utils/logging.py
@@ -0,0 +1,344 @@
+# coding=utf-8
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Logging utilities."""
+
+import logging
+import os
+import sys
+import threading
+from logging import CRITICAL # NOQA
+from logging import DEBUG # NOQA
+from logging import ERROR # NOQA
+from logging import FATAL # NOQA
+from logging import INFO # NOQA
+from logging import NOTSET # NOQA
+from logging import WARN # NOQA
+from logging import WARNING # NOQA
+from typing import Optional
+
+from tqdm import auto as tqdm_lib
+
+
+_lock = threading.Lock()
+_default_handler: Optional[logging.Handler] = None
+
+log_levels = {
+ "debug": logging.DEBUG,
+ "info": logging.INFO,
+ "warning": logging.WARNING,
+ "error": logging.ERROR,
+ "critical": logging.CRITICAL,
+}
+
+_default_log_level = logging.WARNING
+
+_tqdm_active = True
+
+
+def _get_default_logging_level():
+ """
+ If DIFFUSERS_VERBOSITY env var is set to one of the valid choices return that as the new default level. If it is
+ not - fall back to `_default_log_level`
+ """
+ env_level_str = os.getenv("DIFFUSERS_VERBOSITY", None)
+ if env_level_str:
+ if env_level_str in log_levels:
+ return log_levels[env_level_str]
+ else:
+ logging.getLogger().warning(
+ f"Unknown option DIFFUSERS_VERBOSITY={env_level_str}, "
+ f"has to be one of: { ', '.join(log_levels.keys()) }"
+ )
+ return _default_log_level
+
+
+def _get_library_name() -> str:
+
+ return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> logging.Logger:
+
+ return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+
+ global _default_handler
+
+ with _lock:
+ if _default_handler:
+ # This library has already configured the library root logger.
+ return
+ _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
+ _default_handler.flush = sys.stderr.flush
+
+ # Apply our default configuration to the library root logger.
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_get_default_logging_level())
+ library_root_logger.propagate = False
+
+
+def _reset_library_root_logger() -> None:
+
+ global _default_handler
+
+ with _lock:
+ if not _default_handler:
+ return
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.removeHandler(_default_handler)
+ library_root_logger.setLevel(logging.NOTSET)
+ _default_handler = None
+
+
+def get_log_levels_dict():
+ return log_levels
+
+
+def get_logger(name: Optional[str] = None) -> logging.Logger:
+ """
+ Return a logger with the specified name.
+
+ This function is not supposed to be directly accessed unless you are writing a custom diffusers module.
+ """
+
+ if name is None:
+ name = _get_library_name()
+
+ _configure_library_root_logger()
+ return logging.getLogger(name)
+
+
+def get_verbosity() -> int:
+ """
+ Return the current level for the 🤗 Diffusers' root logger as an int.
+
+ Returns:
+ `int`: The logging level.
+
+
+
+ 🤗 Diffusers has following logging levels:
+
+ - 50: `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - 40: `diffusers.logging.ERROR`
+ - 30: `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - 20: `diffusers.logging.INFO`
+ - 10: `diffusers.logging.DEBUG`
+
+ """
+
+ _configure_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ """
+ Set the verbosity level for the 🤗 Diffusers' root logger.
+
+ Args:
+ verbosity (`int`):
+ Logging level, e.g., one of:
+
+ - `diffusers.logging.CRITICAL` or `diffusers.logging.FATAL`
+ - `diffusers.logging.ERROR`
+ - `diffusers.logging.WARNING` or `diffusers.logging.WARN`
+ - `diffusers.logging.INFO`
+ - `diffusers.logging.DEBUG`
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_info():
+ """Set the verbosity to the `INFO` level."""
+ return set_verbosity(INFO)
+
+
+def set_verbosity_warning():
+ """Set the verbosity to the `WARNING` level."""
+ return set_verbosity(WARNING)
+
+
+def set_verbosity_debug():
+ """Set the verbosity to the `DEBUG` level."""
+ return set_verbosity(DEBUG)
+
+
+def set_verbosity_error():
+ """Set the verbosity to the `ERROR` level."""
+ return set_verbosity(ERROR)
+
+
+def disable_default_handler() -> None:
+ """Disable the default handler of the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().removeHandler(_default_handler)
+
+
+def enable_default_handler() -> None:
+ """Enable the default handler of the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert _default_handler is not None
+ _get_library_root_logger().addHandler(_default_handler)
+
+
+def add_handler(handler: logging.Handler) -> None:
+ """adds a handler to the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None
+ _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+ """removes given handler from the HuggingFace Diffusers' root logger."""
+
+ _configure_library_root_logger()
+
+ assert handler is not None and handler not in _get_library_root_logger().handlers
+ _get_library_root_logger().removeHandler(handler)
+
+
+def disable_propagation() -> None:
+ """
+ Disable propagation of the library log outputs. Note that log propagation is disabled by default.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = False
+
+
+def enable_propagation() -> None:
+ """
+ Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent
+ double logging if the root logger has been configured.
+ """
+
+ _configure_library_root_logger()
+ _get_library_root_logger().propagate = True
+
+
+def enable_explicit_format() -> None:
+ """
+ Enable explicit formatting for every HuggingFace Diffusers' logger. The explicit formatter is as follows:
+ ```
+ [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
+ ```
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
+ handler.setFormatter(formatter)
+
+
+def reset_format() -> None:
+ """
+ Resets the formatting for HuggingFace Diffusers' loggers.
+
+ All handlers currently bound to the root logger are affected by this method.
+ """
+ handlers = _get_library_root_logger().handlers
+
+ for handler in handlers:
+ handler.setFormatter(None)
+
+
+def warning_advice(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but if the env var DIFFUSERS_NO_ADVISORY_WARNINGS=1 is set, this
+ warning will not be printed
+ """
+ no_advisory_warnings = os.getenv("DIFFUSERS_NO_ADVISORY_WARNINGS", False)
+ if no_advisory_warnings:
+ return
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_advice = warning_advice
+
+
+class EmptyTqdm:
+ """Dummy tqdm which doesn't do anything."""
+
+ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
+ self._iterator = args[0] if args else None
+
+ def __iter__(self):
+ return iter(self._iterator)
+
+ def __getattr__(self, _):
+ """Return empty function."""
+
+ def empty_fn(*args, **kwargs): # pylint: disable=unused-argument
+ return
+
+ return empty_fn
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ return
+
+
+class _tqdm_cls:
+ def __call__(self, *args, **kwargs):
+ if _tqdm_active:
+ return tqdm_lib.tqdm(*args, **kwargs)
+ else:
+ return EmptyTqdm(*args, **kwargs)
+
+ def set_lock(self, *args, **kwargs):
+ self._lock = None
+ if _tqdm_active:
+ return tqdm_lib.tqdm.set_lock(*args, **kwargs)
+
+ def get_lock(self):
+ if _tqdm_active:
+ return tqdm_lib.tqdm.get_lock()
+
+
+tqdm = _tqdm_cls()
+
+
+def is_progress_bar_enabled() -> bool:
+ """Return a boolean indicating whether tqdm progress bars are enabled."""
+ global _tqdm_active
+ return bool(_tqdm_active)
+
+
+def enable_progress_bar():
+ """Enable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = True
+
+
+def disable_progress_bar():
+ """Disable tqdm progress bar."""
+ global _tqdm_active
+ _tqdm_active = False
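+
+# Typical usage sketch (illustrative):
+#
+#     from diffusers.utils import logging
+#
+#     logging.set_verbosity_info()
+#     logger = logging.get_logger(__name__)
+#     logger.info("visible at INFO level")
+#     logging.disable_progress_bar()     # silences tqdm bars created through `logging.tqdm`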
diff --git a/diffusers/utils/outputs.py b/diffusers/utils/outputs.py
new file mode 100644
index 000000000..b02f62d02
--- /dev/null
+++ b/diffusers/utils/outputs.py
@@ -0,0 +1,109 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Generic utilities
+"""
+
+import warnings
+from collections import OrderedDict
+from dataclasses import fields
+from typing import Any, Tuple
+
+import numpy as np
+
+from .import_utils import is_torch_available
+
+
+def is_tensor(x):
+ """
+ Tests if `x` is a `torch.Tensor` or `np.ndarray`.
+ """
+ if is_torch_available():
+ import torch
+
+ if isinstance(x, torch.Tensor):
+ return True
+
+ return isinstance(x, np.ndarray)
+
+
+class BaseOutput(OrderedDict):
+ """
+ Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
+ tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+ python dictionary.
+
+
+
+ You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
+ before.
+
+
+ """
+
+ def __post_init__(self):
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if not len(class_fields):
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = {k: v for (k, v) in self.items()}
+ if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
+ warnings.warn(
+ "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
+ " `'images'` instead.",
+ DeprecationWarning,
+ )
+ return inner_dict["images"]
+ return inner_dict[k]
+ else:
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(self[k] for k in self.keys())
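+
+# Illustrative: subclasses such as `SchedulerOutput` can be read in three equivalent ways
+# (assuming `tensor` is a torch.FloatTensor):
+#
+#     out = SchedulerOutput(prev_sample=tensor)
+#     out.prev_sample            # attribute access
+#     out["prev_sample"]         # dict-style access
+#     out.to_tuple()[0]          # tuple-style access; call to_tuple() before unpacking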
diff --git a/edit.sh b/edit.sh
new file mode 100755
index 000000000..50234090c
--- /dev/null
+++ b/edit.sh
@@ -0,0 +1,3 @@
+vim diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+#vim /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+#cp /opt/miniconda3/envs/sd_gpu/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py .
diff --git a/geneticsd.py b/geneticsd.py
new file mode 100644
index 000000000..e6f44136c
--- /dev/null
+++ b/geneticsd.py
@@ -0,0 +1,773 @@
+# A ton of imports.
+from gfpgan.utils import GFPGANer
+import cv2
+import random
+import os
+import time
+import torch
+import numpy as np
+import shutil
+import PIL
+from PIL import Image
+from einops import rearrange, repeat
+from torch import autocast
+from diffusers import StableDiffusionPipeline
+import webbrowser
+from deep_translator import GoogleTranslator
+from langdetect import detect
+from joblib import Parallel, delayed
+from RealESRGAN import RealESRGAN
+import pyttsx3
+import pyfiglet
+import pygame
+from os import listdir
+from os.path import isfile, join
+
+# Let's parametrize a few things.
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+model_id = "CompVis/stable-diffusion-v1-4"
+device = "mps" #torch.device("mps")
+
+white = (255, 255, 255)
+green = (0, 255, 0)
+darkgreen = (0, 128, 0)
+red = (255, 0, 0)
+blue = (0, 0, 128)
+black = (0, 0, 0)
+
+os.environ["skl"] = "nn"
+os.environ["epsilon"] = "0.005"
+os.environ["decay"] = "0."
+os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne"
+os.environ["forcedlatent"] = ""
+latent_forcing = ""
+os.environ["good"] = "[]"
+os.environ["bad"] = "[]"
+num_iterations = 50
+gs = 7.5
+sentinel = str(random.randint(0,100000)) + "XX" + str(random.randint(0,100000))
+all_files = []
+llambda = 15
+
+# Creating the voice engine.
+noise = pyttsx3.init()
+noise.setProperty("rate", 240)
+def speak(text):
+ noise.say(text)
+ noise.runAndWait()
+
+
+# Initialization.
+all_selected = [] # List of all selected images, over all the run.
+all_selected_latent = [] # The corresponding latent variables.
+final_selection = [] # Selection of files during the final iteration.
+final_selection_latent = [] # Selection of files during the final iteration.
+forcedlatents = [] # Latent variables that we want to see soon.
+forcedgs = [] # forcedgs[i] is the guidance strength that we want to see for image number i.
+assert llambda < 16, "lambda < 16 for convenience in pygame."
+bad = []
+five_best = []
+latent = []
+images = []
+onlyfiles = []
+
+# Creating the main pipeline.
+pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=os.environ.get("HF_TOKEN", True))  # set HF_TOKEN to your HuggingFace token (see README), or log in with huggingface-cli
+pipe = pipe.to(device)
+
+
+# A ton of prompts, for fun.
+prompt = "a photo of an astronaut riding a horse on mars"
+prompt = "a photo of a red panda with a hat playing table tennis"
+prompt = "a photorealistic portrait of " + random.choice(["Mary Cury", "Scarlett Johansson", "Marilyn Monroe", "Poison Ivy", "Black Widow", "Medusa", "Batman", "Albert Einstein", "Louis XIV", "Tarzan"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"])
+prompt = "a photorealistic portrait of " + random.choice(["Nelson Mandela", "Superman", "Superwoman", "Volodymyr Zelenskyy", "Tsai Ing-Wen", "Lzzy Hale", "Meg Myers"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"])
+prompt = random.choice(["A woman with three eyes", "Meg Myers", "The rock band Ankor", "Miley Cyrus", "The man named Rahan", "A murder", "Rambo playing table tennis"])
+prompt = "Photo of a female Terminator."
+prompt = random.choice([
+ "Photo of Tarzan as a lawyer with a tie",
+ "Photo of Scarlett Johansson as a sumo-tori",
+ "Photo of the little mermaid as a young black girl",
+ "Photo of Schwarzy with tentacles",
+ "Photo of Meg Myers with an Egyptian dress",
+ "Photo of Schwarzy as a ballet dancer",
+ ])
+
+
+name = random.choice(["Mark Zuckerbeg", "Zendaya", "Yann LeCun", "Scarlett Johansson", "Superman", "Meg Myers"])
+prompt = f"Photo of {name} as a sumo-tori."
+
+prompt = "Full length portrait of Mark Zuckerberg as a Sumo-Tori."
+prompt = "Full length portrait of Scarlett Johansson as a Sumo-Tori."
+prompt = "A close up photographic portrait of a young woman with uniformly colored hair."
+prompt = "Zombies raising and worshipping a flying human."
+prompt = "Zombies trying to kill Meg Myers."
+prompt = "Meg Myers with an Egyptian dress killing a vampire with a gun."
+prompt = "Meg Myers grabbing a vampire by the scruff of the neck."
+prompt = "Mark Zuckerberg chokes a vampire to death."
+prompt = "Mark Zuckerberg riding an animal."
+prompt = "A giant cute animal worshipped by zombies."
+prompt = "Several faces."
+prompt = "An armoured Yann LeCun fighting tentacles in the jungle."
+prompt = "Tentacles everywhere."
+prompt = "A photo of a smiling Medusa."
+prompt = "Medusa."
+prompt = "Meg Myers in bloody armor fending off tentacles with a sword."
+prompt = "A red-haired woman with red hair. Her head is tilted."
+prompt = "A bloody heavy-metal zombie with a chainsaw."
+prompt = "Tentacles attacking a bloody Meg Myers in Eyptian dress. Meg Myers has a chainsaw."
+prompt = "Bizarre art."
+prompt = "Beautiful bizarre woman."
+prompt = "Yann LeCun as the grim reaper: bizarre art."
+prompt = "Un chat en sang et en armure joue de la batterie."
+prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber."
+prompt = "A ferocious cyborg bear."
+prompt = "Photo of Mark Zuckerberg killing Cthulhu with a light saber."
+prompt = "A bear with horns and blood and big teeth."
+prompt = "A photo of a bear and Yoda, good friends."
+prompt = "A photo of Yoda on the left, a blue octopus on the right, an explosion in the center."
+prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background."
+prompt = "A flying white owl above 4 colored pots with fire. The owl has a hat."
+prompt = "A flying white owl above 4 colored pots with fire."
+prompt = "Yann LeCun rides a dragon which spits fire on a cherry on a cake."
+prompt = "An armored Mark Zuckerberg fighting off a monster with bloody tentacles in the jungle with a light saber."
+prompt = "Cute woman, portrait, photo, red hair, green eyes, smiling."
+prompt = "Photo of Tarzan as a lawyer with a tie and an octopus on his head."
+prompt = "An armored bloody Yann Lecun has a lightsabar and fights a red tentacular monster."
+prompt = "Photo of a giant armored insect attacking a building. The building is broken. There are flames."
+prompt = "Photo of Meg Myers, on the left, in Egyptian dress, fights Cthulhu (on the right) with a light saber. They stare at each other."
+prompt = "Photo of a cute red panda."
+prompt = "Photo of a cute smiling white-haired woman with pink eyes."
+prompt = "A muscular Jesus with and assault rifle, a cap and and a light saber."
+prompt = "A portrait of a cute smiling woman."
+prompt = "A woman with black skin, red hair, egyptian dress, yellow eyes."
+prompt = "Photo of a red haired man with tilted head."
+prompt = "A photo of Cleopatra with Egyptian Dress kissing Yoda."
+prompt = "A photo of Yoda fighting Meg Myers with light sabers."
+prompt = "A photo of Meg Myers, laughing, pulling Gandalf's hair."
+prompt = "A photo of Meg Myers laughing and pulling Gandalf's hair. Gandalf is stooping."
+prompt = "A star with flashy colors."
+prompt = "Portrait of a green haired woman with blue eyes."
+prompt = "Portrait of a female kung-fu master."
+prompt = "In a dark cave, in the middle of computers, a bearded red-haired geek with squared glasses meets the devil."
+prompt = "Photo of the devil, with horns. There are flames in the background."
+prompt = "Yann LeCun fighting Pinocchio with light sabers."
+prompt = "Yann LeCun attacks a triceratops with a lightsaber."
+prompt = "A cyberpunk man next to a cyberpunk woman."
+prompt = "A smiling woman with a Katana and electronic patches."
+prompt = "Photo of a bearded, long-haired man with glasses and a blonde-haired woman. Both are smiling. Cats and drums and computers on shelves in the background."
+print(f"The prompt is {prompt}")
+
+
+print(pyfiglet.figlet_format("Welcome in Genetic Stable Diffusion !"))
+print(pyfiglet.figlet_format("First, let us choose the text :-)!"))
+
+
+
+print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n")
+speak("Hey!")
+user_prompt = input(f"English: Enter a new prompt if you prefer something else than << {prompt} >>.\n")
+if len(user_prompt) > 2:
+ prompt = user_prompt
+
+# On the fly translation.
+language = detect(prompt)
+english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
+def to_native(stri):
+ return GoogleTranslator(source='en', target=language).translate(stri)
+
+def pretty_print(stri):
+ print(pyfiglet.figlet_format(to_native(stri)))
+
+print(f"{to_native('Working on')} {english_prompt}, a.k.a {prompt}.")
+
+
+# Converting a latent var to an image.
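+# The latent is handed to the pipeline through the "forcedlatent" environment variable, and the
+# pipeline is expected to write the latent it actually used back into "latent_sd"; this assumes
+# the bundled diffusers pipeline has been patched accordingly (see edit.sh).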
+def latent_to_image(latent):
+ os.environ["forcedlatent"] = str(list(latent.flatten())) #str(list(forcedlatents[k].flatten()))
+ with autocast("cuda"):
+ image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0]
+ os.environ["forcedlatent"] = "[]"
+ return image
+
+# Creating the super-resolution stuff. RealESRGAN is fantastic!
+sr_device = torch.device('cpu') #device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+esrmodel = RealESRGAN(sr_device, scale=4)
+esrmodel.load_weights('weights/RealESRGAN_x4.pth', download=True)
+esrmodel2 = RealESRGAN(sr_device, scale=2)
+esrmodel2.load_weights('weights/RealESRGAN_x2.pth', download=True)
+
+def fe(path):
+ fe = GFPGANer(model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2)
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ _, _, output = fe.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
+ cv2.imwrite(path, output)
+
+def singleeg(path_to_image):
+ image = Image.open(path_to_image).convert('RGB')
+ sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Type before SR = {type(image)}")
+ sr_image = esrmodel.predict(image)
+ print(f"Type after SR = {type(sr_image)}")
+ output_filename = path_to_image + ".SR.png"
+ sr_image.save(output_filename)
+ fe(output_filename)
+ return output_filename
+
+# A version with x2.
+def singleeg2(path_to_image):
+ image = Image.open(path_to_image).convert('RGB')
+ sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ print(f"Type before SR = {type(image)}")
+ sr_image = esrmodel2.predict(image)
+ print(f"Type after SR = {type(sr_image)}")
+ output_filename = path_to_image + ".SR.png"
+ sr_image.save(output_filename)
+ fe(output_filename)
+ return output_filename
+
+
+# realESRGan applied to many files.
+def eg(list_of_files, last_list_of_files):
+ pretty_print("Should I convert images below to high resolution ?")
+ print(list_of_files)
+ print("Last iteration:")
+ print(last_list_of_files)
+ speak("Go to the text window!")
+ answer = input(" [y]es / [n]o / [j]ust the ones in last iteration ?")
+ if "y" in answer or "Y" in answer or "j" in answer or "J" in answer:
+ if "j" in answer or "J" in answer:
+ list_of_files = last_list_of_files
+ #images = Parallel(n_jobs=12)(delayed(singleeg)(image) for image in list_of_files)
+ #print(to_native(f"Created the super-resolution files {images}"))
+ for path_to_image in list_of_files:
+ output_filename = singleeg(path_to_image)
+ print(to_native(f"Created the super-resolution file {output_filename}"))
+
+# When we stop the run, we offer super-resolution and/or animations on the selected images.
+def stop_all(list_of_files, list_of_latent, last_list_of_files, last_list_of_latent):
+ print(to_native("Your selected images and the last generation:"))
+ print(list_of_files)
+ eg(list_of_files, last_list_of_files)
+ pretty_print("Should we create animations ?")
+ answer = input(" [y]es or [n]o or [j]ust the selection on the last panel ?")
+ if "y" in answer or "Y" in answer or "j" in answer or "J" in answer:
+ assert len(list_of_files) == len(list_of_latent)
+ if "j" in answer or "J" in answer:
+ list_of_latent = last_list_of_latent
+ pretty_print("Let us create animations!")
+ for c in sorted([0.0025, 0.005, 0.01, 0.02]):
+ for idx in range(len(list_of_files)):
+ images = []
+ l = list_of_latent[idx].reshape(1,4,64,64)
+ l = np.sqrt(len(l.flatten()) / np.sum(l**2)) * l
+ l1 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64)
+ l1 = np.sqrt(len(l1.flatten()) / np.sum(l1**2)) * l1
+ l2 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64)
+ l2 = np.sqrt(len(l2.flatten()) / np.sum(l2**2)) * l2
+ num_animation_steps = 13
+ index = 0
+ for u in np.linspace(0., 2*3.14159 * (1-1/30), 30):
+ cc = np.cos(u)
+ ss = np.sin(u*2)
+ index += 1
+ image = latent_to_image(l + cc * (l1 - l) + ss * (l2 - l))
+ image_name = f"imgA{index}.png"
+ image.save(image_name)
+ fe(image_name)
+ images += [image_name]
+
+ print(to_native(f"Base images created for perturbation={c} and file {list_of_files[idx]}"))
+ images = Parallel(n_jobs=10)(delayed(singleeg2)(image) for image in images)
+ frames = [Image.open(image) for image in images]
+ frame_one = frames[0]
+ gif_name = list_of_files[idx] + "_" + str(c) + ".gif"
+ frame_one.save(gif_name, format="GIF", append_images=frames,
+ save_all=True, duration=100, loop=0)
+ webbrowser.open(os.environ["PWD"] + "/" + gif_name)
+
+ pretty_print("Should we create a meme ?")
+ answer = input(" [y]es or [n]o ?")
+ if "y" in answer or "Y" in answer:
+ url = 'https://imgflip.com/memegenerator'
+ webbrowser.open(url)
+ pretty_print("Good bye!")
+ exit()
+
+
+
+
+
+pretty_print("Now let us choose (if you want) an image as a start.")
+image_name = input(to_native("Name of image for starting ? (enter if no start image)"))
+
+# activate the pygame library .
+pygame.init()
+X = 2000 # > 1500 = buttons
+Y = 900
+scrn = pygame.display.set_mode((1700, Y + 100))
+font = pygame.font.Font('freesansbold.ttf', 22)
+minifont = pygame.font.Font('freesansbold.ttf', 11)
+bigfont = pygame.font.Font('freesansbold.ttf', 44)
+
+def load_img(path):
+ image = Image.open(path).convert("RGB")
+ w, h = image.size
+ print(to_native(f"loaded input image of size ({w}, {h}) from {path}"))
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((512, 512), resample=PIL.Image.LANCZOS)
+ #image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image[None].transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image)
+ return 2.*image - 1.
+
+model = pipe.vae
+
+def img_to_latent(path):
+ #init_image = 1.8 * load_img(path).to(device)
+ init_image = load_img(path).to(device)
+ init_image = repeat(init_image, '1 ... -> b ...', b=1)
+ forced_latent = model.encode(init_image.to(device)).latent_dist.sample()
+ new_fl = forced_latent.cpu().detach().numpy().flatten()
+ new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2))
+ return new_fl
+
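+# Randomized variant of img_to_latent: the image is scaled by a random factor, encoded with the
+# VAE, blended with Gaussian noise (weight epsilon) and renormalized onto a sphere of radius
+# scale * sqrt(d) in latent space, so repeated calls yield diverse latents near the input image.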
+def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=None):
+ base_init_image = load_img(image_name).to(device)
+ new_base_init_image = base_init_image
+ c = np.exp(np.random.randn()) if c is None else c
+ f = np.exp(-3. * np.random.rand()) if f is None else f
+ init_image_shape = base_init_image.cpu().numpy().shape
+ init_image = c * new_base_init_image
+ init_image = repeat(init_image, '1 ... -> b ...', b=1)
+ forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample()
+ new_fl = forced_latent.cpu().detach().numpy().flatten()
+ basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl
+ basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl
+ epsilon = 0.1 * np.exp(-3 * np.random.rand()) if epsilon is None else epsilon
+ new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64)
+ scale = 2.8 + 3.6 * np.random.rand() if scale is None else scale
+ new_fl = scale * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2))
+ #image = latent_to_image(np.asarray(new_fl)) #eval(os.environ["forcedlatent"])))
+ #image.save(f"rebuild_{f}_{scale}_{epsilon}_{c}.png")
+ #gs=7.5, f=0.12, scale=3.7, epsilon=0.01,1 c=2.05
+ return new_fl
+
+# In case the user wants to start from a given image.
+if len(image_name) > 0:
+ pretty_print("Importing an image !")
+ try:
+ init_image = load_img(image_name).to(device)
+ except:
+ pretty_print("Try again!")
+ pretty_print("Loading failed!!")
+ image_name = input(to_native("Name of image for starting ? (enter if no start image)"))
+
+ base_init_image = load_img(image_name).to(device)
+ speak("Image loaded!")
+ print(base_init_image.shape)
+ print(np.max(base_init_image.cpu().detach().numpy().flatten()))
+ print(np.min(base_init_image.cpu().detach().numpy().flatten()))
+
+ forcedlatents = []
+ try:
+ latent_file = image_name + ".latent.txt"
+ print(to_native(f"Trying to load latent variables in {latent_file}."))
+ f = open(latent_file, "r")
+ print(to_native("File opened."))
+ latent_str = f.read()
+ print("Latent string read.")
+ latent_found = True
+ for i in range(llambda):
+ basic_new_fl = np.asarray(eval(latent_str))
+ if i > 0:
+ basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl
+ epsilon = .7 * ((i-1)/(llambda-1)) #1.0 / 2**(2 + (llambda - i) / 6)
+ #print(f"{i} -- {i % 7} {c} {f} {epsilon}")
+ new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64)
+ else:
+ new_fl = basic_new_fl
+ new_fl = 6. * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2))
+ forcedlatents += [new_fl]
+ except:
+ print(to_native("No latent file: guessing."))
+ for i in range(llambda):
+ forcedlatents += [randomized_image_to_latent(image_name)] #img_to_latent(voronoi_name)
+
+# We start the big time consuming loop!
+for iteration in range(3000): # Kind of an infinite loop.
+ latent = [latent[f] for f in five_best]
+ images = [images[f] for f in five_best]
+ onlyfiles = [onlyfiles[f] for f in five_best]
+ early_stop = []
+ speak("Wait!")
+ final_selection = []
+ final_selection_latent = []
+ for k in range(llambda):
+ if len(early_stop) > 0:
+ break
+ max_created_index = k
+ if k < len(forcedlatents):
+ latent_forcing = str(list(forcedlatents[k].flatten()))
+ print(f"We play with {latent_forcing[:20]}")
+ if k < len(five_best):
+ imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300))
+ scrn.blit(imp, (300 * (k // 3), 300 * (k % 3)))
+ pygame.display.flip()
+ continue
+ pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100))
+ pygame.draw.rect(scrn, black, pygame.Rect(1500, 0, 2000, Y+100))
+ text0 = bigfont.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue)
+ scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4))
+ text0 = font.render(to_native(f'Or, for an early stopping,'), True, green, blue)
+ scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8))
+ text0 = font.render(to_native(f'click and WAIT a bit'), True, green, blue)
+ scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2))
+ text0 = font.render(to_native(f'... ... ... '), True, green, blue)
+ scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8))
+
+ text1 = minifont.render(to_native('Undo: click for '), True, green, blue)
+ text1 = pygame.transform.rotate(text1, 90)
+ scrn.blit(text1, (X*3/4+X/16+X/64 - X/32, Y/12))
+ text1 = font.render(to_native('resetting your clicks.'), True, green, blue)
+ text1 = pygame.transform.rotate(text1, 90)
+ scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, Y/12))
+ # Button for quitting and effects
+ text2 = font.render(to_native(f'Total: {len(all_selected)} chosen images! '), True, green, blue)
+ text2 = pygame.transform.rotate(text2, 90)
+ scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3))
+ text2 = font.render(to_native('Click for stopping,'), True, green, blue)
+ text2 = pygame.transform.rotate(text2, 90)
+ scrn.blit(text2, (X*3/4+X/16+X/64 - X/32, Y/3))
+ text2 = font.render(to_native('and get the effects.'), True, green, blue)
+ text2 = pygame.transform.rotate(text2, 90)
+ scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3))
+
+ pygame.display.flip()
+ os.environ["earlystop"] = "False" if k > len(five_best) else "True"
+ os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda)
+ os.environ["budget"] = str(np.random.randint(400) if k > len(five_best) else 2)
+ os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3]
+ previous_gs = gs
+ if k < len(forcedgs):
+ gs = forcedgs[k]
+ image = latent_to_image(np.asarray(latent_forcing)) #eval(os.environ["forcedlatent"])))
+ gs = previous_gs
+
+ images += [image]
+ filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png"
+ image.save(filename)
+ fe(filename)
+ onlyfiles += [filename]
+ imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300))
+ scrn.blit(imp, (300 * (k // 3), 300 * (k % 3)))
+ pygame.display.flip()
+ print('\a') # beep!
+ str_latent = eval(os.environ["latent_sd"])
+ array_latent = np.array(str_latent).reshape(4, 64, 64)
+ print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}")
+ latent += [array_latent]
+ with open(filename + ".latent.txt", 'w') as f:
+ f.write(f"{str_latent}")
+
+ # Check whether the user clicked to stop this generation early.
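+ # A click on an already rendered thumbnail aborts the remaining renders of this
+ # generation; a click in the right-hand stop band calls stop_all() and exits.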
+ first_event = True
+ for i in pygame.event.get():
+ if i.type == pygame.MOUSEBUTTONUP:
+ if first_event:
+ speak("Ok I stop!")
+ first_event = False
+ pos = pygame.mouse.get_pos()
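+ # Thumbnails are 300x300 and laid out three per column: map the click position to an image index.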
+ index = 3 * (pos[0] // 300) + (pos[1] // 300)
+ if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3:
+ stop_all(all_selected, all_selected_latent, final_selection, final_selection_latent)
+ exit()
+ if index <= k:
+ pretty_print(("You clicked for requesting an early stopping."))
+ early_stop = [pos]
+ break
+ early_stop = [(1,1)]
+ status = False
+ forcedgs = []
+
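+ # Selection phase: redisplay the thumbnails, draw the side buttons, and wait for the user's clicks.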
+ speak("Please choose!")
+ pretty_print("Please choose your images.")
+ text0 = bigfont.render(to_native(f'Choose your favorite images !!!========='), True, green, blue)
+ scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4))
+ text0 = font.render(to_native(f'=================================='), True, green, blue)
+ scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8))
+ text0 = font.render(to_native(f'=================================='), True, green, blue)
+ scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2))
+ # Add rectangles
+ pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 0, X*3/4+X/16+X/32, Y/3), 2)
+ pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, Y/3, X*3/4+X/16+X/32, 2*Y/3), 2)
+ pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2)
+ pygame.draw.rect(scrn, red, pygame.Rect(0, Y, X/2, Y+100), 2)
+
+ # Button for loading a starting point
+ #text1 = font.render('Manually edit an image.', True, green, blue)
+ #text1 = pygame.transform.rotate(text1, 90)
+ #scrn.blit(text1, (X*3/4+X/16 - X/32, 0))
+ #text1 = font.render('& latent ', True, green, blue)
+ #text1 = pygame.transform.rotate(text1, 90)
+ #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0))
+
+ # Button for stopping now.
+ text2 = font.render(to_native('Click here'), True, green, blue)
+ text2 = pygame.transform.rotate(text2, 90)
+ scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3+10))
+ text2 = font.render(to_native('to finish with effects,'), True, green, blue)
+ text2 = pygame.transform.rotate(text2, 90)
+ scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3+10))
+ text2 = font.render(to_native('or to manually edit.'), True, green, blue)
+ text2 = pygame.transform.rotate(text2, 90)
+ scrn.blit(text2, (X*3/4+X/16+X/32 , Y/3+10))
+
+ # Button for new generation
+ text3 = font.render(to_native(f"I don't want to select images"), True, green, blue)
+ text3 = pygame.transform.rotate(text3, 90)
+ scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3+10))
+ text3 = font.render(to_native(f"Just rerun."), True, green, blue)
+ text3 = pygame.transform.rotate(text3, 90)
+ scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3+10))
+ text4 = font.render(to_native(f"Modify parameters or text!"), True, green, blue)
+ scrn.blit(text4, (300, Y + 30))
+ pygame.display.flip()
+
+ for idx in range(max_created_index + 1):
+ # set the pygame window name
+ pygame.display.set_caption(prompt)
+ print(to_native(f"Pasting image {onlyfiles[idx]}..."))
+ imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300))
+ scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3)))
+
+ # paint screen one time
+ pygame.display.flip()
+ status = True
+ indices = []
+ good = []
+ five_best = []
+ for i in pygame.event.get():
+ if i.type == pygame.MOUSEBUTTONUP:
+ print(to_native(".... too early for clicking !!!!"))
+
+
+ pretty_print("Please click on your favorite elements!")
+ print(to_native("You might just click on one image and we will provide variations."))
+ print(to_native("Or you can click on the top of an image and the bottom of another one."))
+ print(to_native("Click on the << new generation >> when you're done."))
+ while (status):
+
+ # iterate over the list of Event objects
+ # that was returned by pygame.event.get() method.
+ for i in pygame.event.get():
+ if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP:
+ pos = pygame.mouse.get_pos()
+ pretty_print(f"Detected! Click at {pos}")
+ if pos[1] > Y:
+ pretty_print("Let us update parameters!")
+ text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue)
+ scrn.blit(text4, (300, Y + 30))
+ pygame.display.flip()
+ try:
+ num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n")))
+ except:
+ num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n")))
+ gs = float(input(to_native(f"Guidance scale ? (current = {gs})\n")))
+ print(to_native(f"The current text is << {prompt} >>."))
+ print(to_native("Start your answer with a symbol << + >> if this is an edit and not a new text."))
+ new_prompt = str(input(to_native(f"Enter a text if you want to change from ") + prompt))
+ if len(new_prompt) > 2:
+ if new_prompt[0] == "+":
+ prompt += new_prompt[1:]
+ else:
+ prompt = new_prompt
+ language = detect(prompt)
+ english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
+ pretty_print("Ok! Parameters updated.")
+ pretty_print("==> go back to the window!")
+ text4 = font.render(to_native(f"Ok! parameters changed!"), True, green, blue)
+ scrn.blit(text4, (300, Y + 30))
+ pygame.display.flip()
+ elif pos[0] > 1500: # Not in the images.
+ if pos[1] < Y/3:
+ indices = []
+ good = []
+ final_selection = []
+ final_selection_latent = []
+ #filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n"))
+ #status = False
+ #with open(filename, 'r') as f:
+ # latent = f.read()
+ #break
+ #pretty_print("Easy! I exit now, you edit the file and you save it.")
+ #pretty_print("Then just relaunch me and provide the text and the image.")
+ #exit()
+ if pos[1] < 2*Y/3:
+ #onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))]
+ #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)]
+ assert len(onlyfiles) == len(latent)
+ assert len(all_selected) == len(all_selected_latent)
+ stop_all(all_selected, all_selected_latent, final_selection, final_selection_latent) # + onlyfiles, all_selected_latent + latent)
+ exit()
+ status = False
+ break
+ index = 3 * (pos[0] // 300) + (pos[1] // 300)
+ pygame.draw.circle(scrn, red, [pos[0], pos[1]], 13, 0)
+ if index <= max_created_index:
+ selected_filename = to_native("Selected") + onlyfiles[index]
+ shutil.copyfile(onlyfiles[index], selected_filename)
+ assert len(onlyfiles) == len(latent), f"{len(onlyfiles)} != {len(latent)}"
+ all_selected += [selected_filename]
+ all_selected_latent += [latent[index]]
+ final_selection += [selected_filename]
+ final_selection_latent += [latent[index]]
+ text2 = font.render(to_native(f'==> {len(all_selected)} chosen images! '), True, green, blue)
+ text2 = pygame.transform.rotate(text2, 90)
+ scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3))
+ if index not in five_best and len(five_best) < 5:
+ five_best += [index]
+ indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]]
+ # Update the button for new generation.
+ pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y))
+ pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2)
+ text3 = font.render(to_native(f" You have chosen {len(indices)} images:"), True, green, blue)
+ text3 = pygame.transform.rotate(text3, 90)
+ scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3))
+ text3 = font.render(to_native(f" Click for new generation!"), True, green, blue)
+ text3 = pygame.transform.rotate(text3, 90)
+ scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3))
+ pygame.display.flip()
+ #text3Rect = text3.get_rect()
+ #text3Rect.center = (750+750*3/4, 1000)
+ good += [list(latent[index].flatten())]
+ else:
+ speak("Bad click ! Click on an image.")
+ pretty_print("Bad click! Click on image.")
+
+ if i.type == pygame.QUIT:
+ status = False
+
+ # Covering old images with full circles.
+ for _ in range(123):
+ x = np.random.randint(1500)
+ y = np.random.randint(900)
+ pygame.draw.circle(scrn, darkgreen, [x, y], 17, 0)
+ pygame.display.update()
+ if len(indices) == 0:
+ print("The user did not like anything! Rerun :-(")
+ continue
+ print(f"Clicks at {indices}")
+ os.environ["mu"] = str(len(indices))
+ forcedlatents = []
+ bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]
+ #sauron = 0 * latent[0]
+ #for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]:
+ # sauron += latent[u]
+ #sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron
+ if len(bad) > 500:
+ bad = bad[-500:]
+ print(to_native(f"{len(indices)} indices are selected."))
+ #print(f"indices = {indices}")
+ os.environ["good"] = str(good)
+ os.environ["bad"] = str(bad)
+ coefficients = np.zeros(len(indices))
+ numpy_images = [np.array(image) for image in images]
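+ # Build the forced latents of the next generation by Voronoi crossover: each cell
+ # is copied from the clicked parent whose click position is closest (distances are
+ # weighted by a random per-parent coefficient), and some children are additionally
+ # blended with Gaussian noise for exploration.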
+ for a in range(llambda):
+ voronoi_in_images = False #(a % 2 == 1) and len(good) > 1
+ if voronoi_in_images:
+ image = np.array(numpy_images[0])
+ print(f"Voronoi in the image space! {a} / {llambda}")
+ for i in range(len(indices)):
+ coefficients[i] = np.exp(np.random.randn())
+ # Creating a forcedlatent.
+ for i in range(512):
+ x = i / 511.
+ for j in range(512):
+ y = j / 511
+ mindistances = 10000000000.
+ for u in range(len(indices)):
+ distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) )
+ if distance < mindistances:
+ mindistances = distance
+ uu = indices[u][0]
+ image[i][j][:] = numpy_images[uu][i][j][:]
+ # Conversion before using img2latent
+ pil_image = Image.fromarray(image)
+ voronoi_name = f"voronoi{a}_iteration{iteration}.png"
+ pil_image.save(voronoi_name)
+ #timage = np.array([image]).astype(np.float32) / 255.0
+ #timage = timage.transpose(0, 3, 1, 2)
+ #timage = torch.from_numpy(timage).to(device)
+ #timage = repeat(timage, '1 ... -> b ...', b=1)
+ #timage = 2.*timage - 1.
+ #forcedlatent = model.encode(timage).latent_dist.sample().cpu().detach().numpy().flatten()
+ #basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent
+ basic_new_fl = randomized_image_to_latent(voronoi_name) #img_to_latent(voronoi_name)
+ basic_new_fl = np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl
+ #basic_new_fl = 0.8 * np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl
+ if len(good) > 1:
+ print("Directly copying latent vars !!!")
+ #forcedlatents += [4.6 * basic_new_fl]
+ forcedlatents += [basic_new_fl]
+ else:
+ epsilon = 1.0 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2)
+ forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64)
+ forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent
+ forcedlatents += [forcedlatent]
+ #forcedlatents += [4.6 * forcedlatent]
+ else:
+ print(f"Voronoi in the latent space! {a} / {llambda}")
+ forcedlatent = np.zeros((4, 64, 64))
+ #print(type(numpy_image))
+ #print(numpy_image.shape)
+ #print(np.max(numpy_image))
+ #print(np.min(numpy_image))
+ #assert False
+ for i in range(len(indices)):
+ coefficients[i] = np.exp(np.random.randn())
+ for i in range(64):
+ x = i / 63.
+ for j in range(64):
+ y = j / 63
+ mindistances = 10000000000.
+ for u in range(len(indices)):
+ #print(a, i, x, j, y, u)
+ #print(indices[u][1])
+ #print(indices[u][2])
+ #print(f" {coefficients[u]}* np.linalg.norm({np.array((x, y))}-{np.array((indices[u][1], indices[u][2]))}")
+ distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) )
+ if distance < mindistances:
+ mindistances = distance
+ uu = indices[u][0]
+ for k in range(4):
+ assert k < len(forcedlatent), k
+ assert i < len(forcedlatent[k]), i
+ assert j < len(forcedlatent[k][i]), j
+ assert uu < len(latent)
+ assert k < len(latent[uu]), k
+ assert i < len(latent[uu][k]), i
+ assert j < len(latent[uu][k][i]), j
+ forcedlatent[k][i][j] = float(latent[uu][k][i][j])
+ #if a % 2 == 0:
+ # forcedlatent -= np.random.rand() * sauron
+ forcedlatent = forcedlatent.flatten()
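+ # Rescale so that the mean squared latent component equals 1 (the scale of a standard normal latent).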
+ basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent
+ if len(good) > 1 or len(forcedlatents) < len(good) + 1:
+ forcedlatents += [basic_new_fl]
+ else:
+ epsilon = (a + .5 - len(good)) / (llambda - len(good) - 1)
+ forcedlatent = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(4*64*64)
+ coef = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2))
+ forcedlatent = coef * forcedlatent
+ print("we get ", sum(forcedlatent) ** 2)
+ forcedlatents += [forcedlatent]
+ #for uu in range(len(latent)):
+ # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}")
+ os.environ["good"] = "[]"
+ os.environ["bad"] = "[]"
+
+pygame.quit()
diff --git a/inoculate_evo_sd.sh b/inoculate_evo_sd.sh
new file mode 100755
index 000000000..26919fa5e
--- /dev/null
+++ b/inoculate_evo_sd.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+echo deprecated
diff --git a/maketgz.sh b/maketgz.sh
new file mode 100755
index 000000000..d0b91883a
--- /dev/null
+++ b/maketgz.sh
@@ -0,0 +1 @@
+echo deprecated
diff --git a/minisd.py b/minisd.py
new file mode 100644
index 000000000..83699367c
--- /dev/null
+++ b/minisd.py
@@ -0,0 +1,830 @@
+assert False, "Deprecated! Use geneticsd.py instead."
+############### DEPRECATED: see geneticsd.py import random
+############### DEPRECATED: see geneticsd.py import os
+############### DEPRECATED: see geneticsd.py import time
+############### DEPRECATED: see geneticsd.py import torch
+############### DEPRECATED: see geneticsd.py import numpy as np
+############### DEPRECATED: see geneticsd.py import shutil
+############### DEPRECATED: see geneticsd.py import PIL
+############### DEPRECATED: see geneticsd.py from PIL import Image
+############### DEPRECATED: see geneticsd.py from einops import rearrange, repeat
+############### DEPRECATED: see geneticsd.py from torch import autocast
+############### DEPRECATED: see geneticsd.py from diffusers import StableDiffusionPipeline
+############### DEPRECATED: see geneticsd.py import webbrowser
+############### DEPRECATED: see geneticsd.py from deep_translator import GoogleTranslator
+############### DEPRECATED: see geneticsd.py from langdetect import detect
+############### DEPRECATED: see geneticsd.py from joblib import Parallel, delayed
+############### DEPRECATED: see geneticsd.py import torch
+############### DEPRECATED: see geneticsd.py from PIL import Image
+############### DEPRECATED: see geneticsd.py from RealESRGAN import RealESRGAN
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+############### DEPRECATED: see geneticsd.py model_id = "CompVis/stable-diffusion-v1-4"
+############### DEPRECATED: see geneticsd.py #device = "cuda"
+############### DEPRECATED: see geneticsd.py device = "mps" #torch.device("mps")
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py white = (255, 255, 255)
+############### DEPRECATED: see geneticsd.py green = (0, 255, 0)
+############### DEPRECATED: see geneticsd.py darkgreen = (0, 128, 0)
+############### DEPRECATED: see geneticsd.py red = (255, 0, 0)
+############### DEPRECATED: see geneticsd.py blue = (0, 0, 128)
+############### DEPRECATED: see geneticsd.py black = (0, 0, 0)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py os.environ["skl"] = "nn"
+############### DEPRECATED: see geneticsd.py os.environ["epsilon"] = "0.005"
+############### DEPRECATED: see geneticsd.py os.environ["decay"] = "0."
+############### DEPRECATED: see geneticsd.py os.environ["ngoptim"] = "DiscreteLenglerOnePlusOne"
+############### DEPRECATED: see geneticsd.py os.environ["forcedlatent"] = ""
+############### DEPRECATED: see geneticsd.py latent_forcing = ""
+############### DEPRECATED: see geneticsd.py #os.environ["enforcedlatent"] = ""
+############### DEPRECATED: see geneticsd.py os.environ["good"] = "[]"
+############### DEPRECATED: see geneticsd.py os.environ["bad"] = "[]"
+############### DEPRECATED: see geneticsd.py num_iterations = 50
+############### DEPRECATED: see geneticsd.py gs = 7.5
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py import pyttsx3
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py noise = pyttsx3.init()
+############### DEPRECATED: see geneticsd.py noise.setProperty("rate", 240)
+############### DEPRECATED: see geneticsd.py noise.setProperty('voice', 'mb-us1')
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py #voice = noise.getProperty('voices')
+############### DEPRECATED: see geneticsd.py #for v in voice:
+############### DEPRECATED: see geneticsd.py # if v.name == "Kyoko":
+############### DEPRECATED: see geneticsd.py # noise.setProperty('voice', v.id)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py all_selected = []
+############### DEPRECATED: see geneticsd.py all_selected_latent = []
+############### DEPRECATED: see geneticsd.py final_selection = []
+############### DEPRECATED: see geneticsd.py final_selection_latent = []
+############### DEPRECATED: see geneticsd.py forcedlatents = []
+############### DEPRECATED: see geneticsd.py forcedgs = []
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="<YOUR_HUGGINGFACE_TOKEN>")
+############### DEPRECATED: see geneticsd.py pipe = pipe.to(device)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py prompt = "a photo of an astronaut riding a horse on mars"
+############### DEPRECATED: see geneticsd.py prompt = "a photo of a red panda with a hat playing table tennis"
+############### DEPRECATED: see geneticsd.py prompt = "a photorealistic portrait of " + random.choice(["Mary Cury", "Scarlett Johansson", "Marilyn Monroe", "Poison Ivy", "Black Widow", "Medusa", "Batman", "Albert Einstein", "Louis XIV", "Tarzan"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"])
+############### DEPRECATED: see geneticsd.py prompt = "a photorealistic portrait of " + random.choice(["Nelson Mandela", "Superman", "Superwoman", "Volodymyr Zelenskyy", "Tsai Ing-Wen", "Lzzy Hale", "Meg Myers"]) + random.choice([" with glasses", " with a hat", " with a cigarette", "with a scarf"])
+############### DEPRECATED: see geneticsd.py prompt = random.choice(["A woman with three eyes", "Meg Myers", "The rock band Ankor", "Miley Cyrus", "The man named Rahan", "A murder", "Rambo playing table tennis"])
+############### DEPRECATED: see geneticsd.py prompt = "Photo of a female Terminator."
+############### DEPRECATED: see geneticsd.py prompt = random.choice([
+############### DEPRECATED: see geneticsd.py "Photo of Tarzan as a lawyer with a tie",
+############### DEPRECATED: see geneticsd.py "Photo of Scarlett Johansson as a sumo-tori",
+############### DEPRECATED: see geneticsd.py "Photo of the little mermaid as a young black girl",
+############### DEPRECATED: see geneticsd.py "Photo of Schwarzy with tentacles",
+############### DEPRECATED: see geneticsd.py "Photo of Meg Myers with an Egyptian dress",
+############### DEPRECATED: see geneticsd.py "Photo of Schwarzy as a ballet dancer",
+############### DEPRECATED: see geneticsd.py ])
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py name = random.choice(["Mark Zuckerbeg", "Zendaya", "Yann LeCun", "Scarlett Johansson", "Superman", "Meg Myers"])
+############### DEPRECATED: see geneticsd.py name = "Zendaya"
+############### DEPRECATED: see geneticsd.py prompt = f"Photo of {name} as a sumo-tori."
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py prompt = "Full length portrait of Mark Zuckerberg as a Sumo-Tori."
+############### DEPRECATED: see geneticsd.py prompt = "Full length portrait of Scarlett Johansson as a Sumo-Tori."
+############### DEPRECATED: see geneticsd.py prompt = "A close up photographic portrait of a young woman with uniformly colored hair."
+############### DEPRECATED: see geneticsd.py prompt = "Zombies raising and worshipping a flying human."
+############### DEPRECATED: see geneticsd.py prompt = "Zombies trying to kill Meg Myers."
+############### DEPRECATED: see geneticsd.py prompt = "Meg Myers with an Egyptian dress killing a vampire with a gun."
+############### DEPRECATED: see geneticsd.py prompt = "Meg Myers grabbing a vampire by the scruff of the neck."
+############### DEPRECATED: see geneticsd.py prompt = "Mark Zuckerberg chokes a vampire to death."
+############### DEPRECATED: see geneticsd.py prompt = "Mark Zuckerberg riding an animal."
+############### DEPRECATED: see geneticsd.py prompt = "A giant cute animal worshipped by zombies."
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py prompt = "Several faces."
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py prompt = "An armoured Yann LeCun fighting tentacles in the jungle."
+############### DEPRECATED: see geneticsd.py prompt = "Tentacles everywhere."
+############### DEPRECATED: see geneticsd.py prompt = "A photo of a smiling Medusa."
+############### DEPRECATED: see geneticsd.py prompt = "Medusa."
+############### DEPRECATED: see geneticsd.py prompt = "Meg Myers in bloody armor fending off tentacles with a sword."
+############### DEPRECATED: see geneticsd.py prompt = "A red-haired woman with red hair. Her head is tilted."
+############### DEPRECATED: see geneticsd.py prompt = "A bloody heavy-metal zombie with a chainsaw."
+############### DEPRECATED: see geneticsd.py prompt = "Tentacles attacking a bloody Meg Myers in Eyptian dress. Meg Myers has a chainsaw."
+############### DEPRECATED: see geneticsd.py prompt = "Bizarre art."
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py prompt = "Beautiful bizarre woman."
+############### DEPRECATED: see geneticsd.py prompt = "Yann LeCun as the grim reaper: bizarre art."
+############### DEPRECATED: see geneticsd.py prompt = "Un chat en sang et en armure joue de la batterie."
+############### DEPRECATED: see geneticsd.py prompt = "Photo of a cyberpunk Mark Zuckerberg killing Cthulhu with a light saber."
+############### DEPRECATED: see geneticsd.py prompt = "A ferocious cyborg bear."
+############### DEPRECATED: see geneticsd.py prompt = "Photo of Mark Zuckerberg killing Cthulhu with a light saber."
+############### DEPRECATED: see geneticsd.py prompt = "A bear with horns and blood and big teeth."
+############### DEPRECATED: see geneticsd.py prompt = "A photo of a bear and Yoda, good friends."
+############### DEPRECATED: see geneticsd.py prompt = "A photo of Yoda on the left, a blue octopus on the right, an explosion in the center."
+############### DEPRECATED: see geneticsd.py prompt = "A bird is on a hippo. They fight a black and red octopus. Jungle in the background."
+############### DEPRECATED: see geneticsd.py prompt = "A flying white owl above 4 colored pots with fire. The owl has a hat."
+############### DEPRECATED: see geneticsd.py prompt = "A flying white owl above 4 colored pots with fire."
+############### DEPRECATED: see geneticsd.py prompt = "Yann LeCun rides a dragon which spits fire on a cherry on a cake."
+############### DEPRECATED: see geneticsd.py prompt = "An armored Mark Zuckerberg fighting off a monster with bloody tentacles in the jungle with a light saber."
+############### DEPRECATED: see geneticsd.py prompt = "Cute woman, portrait, photo, red hair, green eyes, smiling."
+############### DEPRECATED: see geneticsd.py prompt = "Photo of Tarzan as a lawyer with a tie and an octopus on his head."
+############### DEPRECATED: see geneticsd.py prompt = "An armored bloody Yann Lecun has a lightsabar and fights a red tentacular monster."
+############### DEPRECATED: see geneticsd.py prompt = "Photo of a giant armored insect attacking a building. The building is broken. There are flames."
+############### DEPRECATED: see geneticsd.py prompt = "Photo of Meg Myers, on the left, in Egyptian dress, fights Cthulhu (on the right) with a light saber. They stare at each other."
+############### DEPRECATED: see geneticsd.py prompt = "Photo of a cute red panda."
+############### DEPRECATED: see geneticsd.py prompt = "Photo of a cute smiling white-haired woman with pink eyes."
+############### DEPRECATED: see geneticsd.py prompt = "A muscular Jesus with and assault rifle, a cap and and a light saber."
+############### DEPRECATED: see geneticsd.py prompt = "A portrait of a cute smiling woman."
+############### DEPRECATED: see geneticsd.py prompt = "A woman with black skin, red hair, egyptian dress, yellow eyes."
+############### DEPRECATED: see geneticsd.py prompt = "Photo of a red haired man with tilted head."
+############### DEPRECATED: see geneticsd.py prompt = "A photo of Cleopatra with Egyptian Dress kissing Yoda."
+############### DEPRECATED: see geneticsd.py prompt = "A photo of Yoda fighting Meg Myers with light sabers."
+############### DEPRECATED: see geneticsd.py prompt = "A photo of Meg Myers, laughing, pulling Gandalf's hair."
+############### DEPRECATED: see geneticsd.py prompt = "A photo of Meg Myers laughing and pulling Gandalf's hair. Gandalf is stooping."
+############### DEPRECATED: see geneticsd.py prompt = "A star with flashy colors."
+############### DEPRECATED: see geneticsd.py prompt = "Portrait of a green haired woman with blue eyes."
+############### DEPRECATED: see geneticsd.py prompt = "Portrait of a female kung-fu master."
+############### DEPRECATED: see geneticsd.py prompt = "In a dark cave, in the middle of computers, a geek meets the devil."
+############### DEPRECATED: see geneticsd.py print(f"The prompt is {prompt}")
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py import pyfiglet
+############### DEPRECATED: see geneticsd.py print(pyfiglet.figlet_format("Welcome in Genetic Stable Diffusion !"))
+############### DEPRECATED: see geneticsd.py print(pyfiglet.figlet_format("First, let us choose the text :-)!"))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py print(f"Francais: Proposez un nouveau texte si vous ne voulez pas dessiner << {prompt} >>.\n")
+############### DEPRECATED: see geneticsd.py noise.say("Hey!")
+############### DEPRECATED: see geneticsd.py noise.runAndWait()
+############### DEPRECATED: see geneticsd.py user_prompt = input(f"English: Enter a new prompt if you prefer something else than << {prompt} >>.\n")
+############### DEPRECATED: see geneticsd.py if len(user_prompt) > 2:
+############### DEPRECATED: see geneticsd.py prompt = user_prompt
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # On the fly translation.
+############### DEPRECATED: see geneticsd.py language = detect(prompt)
+############### DEPRECATED: see geneticsd.py english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def to_native(stri):
+############### DEPRECATED: see geneticsd.py return GoogleTranslator(source='en', target=language).translate(stri)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def pretty_print(stri):
+############### DEPRECATED: see geneticsd.py print(pyfiglet.figlet_format(to_native(stri)))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py print(f"{to_native('Working on')} {english_prompt}, a.k.a {prompt}.")
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def latent_to_image(latent):
+############### DEPRECATED: see geneticsd.py os.environ["forcedlatent"] = str(list(latent.flatten())) #str(list(forcedlatents[k].flatten()))
+############### DEPRECATED: see geneticsd.py with autocast("cuda"):
+############### DEPRECATED: see geneticsd.py image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0]
+############### DEPRECATED: see geneticsd.py os.environ["forcedlatent"] = "[]"
+############### DEPRECATED: see geneticsd.py return image
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py sr_device = torch.device('cpu') #device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+############### DEPRECATED: see geneticsd.py esrmodel = RealESRGAN(sr_device, scale=4)
+############### DEPRECATED: see geneticsd.py esrmodel.load_weights('weights/RealESRGAN_x4.pth', download=True)
+############### DEPRECATED: see geneticsd.py esrmodel2 = RealESRGAN(sr_device, scale=2)
+############### DEPRECATED: see geneticsd.py esrmodel2.load_weights('weights/RealESRGAN_x2.pth', download=True)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def singleeg(path_to_image):
+############### DEPRECATED: see geneticsd.py image = Image.open(path_to_image).convert('RGB')
+############### DEPRECATED: see geneticsd.py sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+############### DEPRECATED: see geneticsd.py print(f"Type before SR = {type(image)}")
+############### DEPRECATED: see geneticsd.py sr_image = esrmodel.predict(image)
+############### DEPRECATED: see geneticsd.py print(f"Type after SR = {type(sr_image)}")
+############### DEPRECATED: see geneticsd.py output_filename = path_to_image + ".SR.png"
+############### DEPRECATED: see geneticsd.py sr_image.save(output_filename)
+############### DEPRECATED: see geneticsd.py return output_filename
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def singleeg2(path_to_image):
+############### DEPRECATED: see geneticsd.py time.sleep(0.5*np.random.rand())
+############### DEPRECATED: see geneticsd.py image = Image.open(path_to_image).convert('RGB')
+############### DEPRECATED: see geneticsd.py sr_device = device #('mps') #torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+############### DEPRECATED: see geneticsd.py print(f"Type before SR = {type(image)}")
+############### DEPRECATED: see geneticsd.py sr_image = esrmodel2.predict(image)
+############### DEPRECATED: see geneticsd.py print(f"Type after SR = {type(sr_image)}")
+############### DEPRECATED: see geneticsd.py output_filename = path_to_image + ".SR.png"
+############### DEPRECATED: see geneticsd.py sr_image.save(output_filename)
+############### DEPRECATED: see geneticsd.py return output_filename
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def eg(list_of_files, last_list_of_files):
+############### DEPRECATED: see geneticsd.py pretty_print("Should I convert images below to high resolution ?")
+############### DEPRECATED: see geneticsd.py print(list_of_files)
+############### DEPRECATED: see geneticsd.py noise.say("Go to the text window!")
+############### DEPRECATED: see geneticsd.py noise.runAndWait()
+############### DEPRECATED: see geneticsd.py answer = input(" [y]es / [n]o / [j]ust the last batch of {len(last_list_of_files)} images ?")
+############### DEPRECATED: see geneticsd.py if "y" in answer or "Y" in answer or "j" in answer or "J" in answer:
+############### DEPRECATED: see geneticsd.py if j in answer or "J" in answer:
+############### DEPRECATED: see geneticsd.py list_of_files = last_list_of_files
+############### DEPRECATED: see geneticsd.py #images = Parallel(n_jobs=12)(delayed(singleeg)(image) for image in list_of_files)
+############### DEPRECATED: see geneticsd.py #print(to_native(f"Created the super-resolution files {images}"))
+############### DEPRECATED: see geneticsd.py for path_to_image in list_of_files:
+############### DEPRECATED: see geneticsd.py output_filename = singleeg(path_to_image)
+############### DEPRECATED: see geneticsd.py print(to_native(f"Created the super-resolution file {output_filename}"))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def stop_all(list_of_files, list_of_latent, last_list_of_files, last_list_of_latent):
+############### DEPRECATED: see geneticsd.py print(to_native("Your selected images and the last generation:"))
+############### DEPRECATED: see geneticsd.py print(list_of_files)
+############### DEPRECATED: see geneticsd.py eg(list_of_files, last_list_of_files)
+############### DEPRECATED: see geneticsd.py pretty_print("Should we create animations ?")
+############### DEPRECATED: see geneticsd.py answer = input(" [y]es or [n]o or [j]ust the selection on the last panel ?")
+############### DEPRECATED: see geneticsd.py if "y" in answer or "Y" in answer or "j" in answer or "J" in answer:
+############### DEPRECATED: see geneticsd.py assert len(list_of_files) == len(list_of_latent)
+############### DEPRECATED: see geneticsd.py if "j" in answer or "J" in answer:
+############### DEPRECATED: see geneticsd.py list_of_latent = last_list_of_latent
+############### DEPRECATED: see geneticsd.py pretty_print("Let us create animations!")
+############### DEPRECATED: see geneticsd.py for c in sorted([0.05, 0.04,0.03,0.02,0.01]):
+############### DEPRECATED: see geneticsd.py for idx in range(len(list_of_files)):
+############### DEPRECATED: see geneticsd.py images = []
+############### DEPRECATED: see geneticsd.py l = list_of_latent[idx].reshape(1,4,64,64)
+############### DEPRECATED: see geneticsd.py l = np.sqrt(len(l.flatten()) / np.sum(l**2)) * l
+############### DEPRECATED: see geneticsd.py l1 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64)
+############### DEPRECATED: see geneticsd.py l1 = np.sqrt(len(l1.flatten()) / np.sum(l1**2)) * l1
+############### DEPRECATED: see geneticsd.py l2 = l + c * np.random.randn(len(l.flatten())).reshape(1,4,64,64)
+############### DEPRECATED: see geneticsd.py l2 = np.sqrt(len(l2.flatten()) / np.sum(l2**2)) * l2
+############### DEPRECATED: see geneticsd.py num_animation_steps = 13
+############### DEPRECATED: see geneticsd.py index = 0
+############### DEPRECATED: see geneticsd.py for u in np.linspace(0., 2*3.14159 * (1-1/30), 30):
+############### DEPRECATED: see geneticsd.py cc = np.cos(u)
+############### DEPRECATED: see geneticsd.py ss = np.sin(u*2)
+############### DEPRECATED: see geneticsd.py index += 1
+############### DEPRECATED: see geneticsd.py image = latent_to_image(l + cc * (l1 - l) + ss * (l2 - l))
+############### DEPRECATED: see geneticsd.py image_name = f"imgA{index}.png"
+############### DEPRECATED: see geneticsd.py image.save(image_name)
+############### DEPRECATED: see geneticsd.py images += [image_name]
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # for u in np.linspace(0., 1., num_animation_steps):
+############### DEPRECATED: see geneticsd.py # index += 1
+############### DEPRECATED: see geneticsd.py # image = latent_to_image(u*l1 + (1-u)*l)
+############### DEPRECATED: see geneticsd.py # image_name = f"imgA{index}.png"
+############### DEPRECATED: see geneticsd.py # image.save(image_name)
+############### DEPRECATED: see geneticsd.py # images += [image_name]
+############### DEPRECATED: see geneticsd.py # for u in np.linspace(0., 1., num_animation_steps):
+############### DEPRECATED: see geneticsd.py # index += 1
+############### DEPRECATED: see geneticsd.py # image = latent_to_image(u*l2 + (1-u)*l1)
+############### DEPRECATED: see geneticsd.py # image_name = f"imgB{index}.png"
+############### DEPRECATED: see geneticsd.py # image.save(image_name)
+############### DEPRECATED: see geneticsd.py # images += [image_name]
+############### DEPRECATED: see geneticsd.py # for u in np.linspace(0., 1.,num_animation_steps):
+############### DEPRECATED: see geneticsd.py # index += 1
+############### DEPRECATED: see geneticsd.py # image = latent_to_image(u*l + (1-u)*l2)
+############### DEPRECATED: see geneticsd.py # image_name = f"imgC{index}.png"
+############### DEPRECATED: see geneticsd.py # image.save(image_name)
+############### DEPRECATED: see geneticsd.py # images += [image_name]
+############### DEPRECATED: see geneticsd.py print(to_native(f"Base images created for perturbation={c} and file {list_of_files[idx]}"))
+############### DEPRECATED: see geneticsd.py #images = Parallel(n_jobs=8)(delayed(process)(i) for i in range(10))
+############### DEPRECATED: see geneticsd.py images = Parallel(n_jobs=10)(delayed(singleeg2)(image) for image in images)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py frames = [Image.open(image) for image in images]
+############### DEPRECATED: see geneticsd.py frame_one = frames[0]
+############### DEPRECATED: see geneticsd.py gif_name = list_of_files[idx] + "_" + str(c) + ".gif"
+############### DEPRECATED: see geneticsd.py frame_one.save(gif_name, format="GIF", append_images=frames,
+############### DEPRECATED: see geneticsd.py save_all=True, duration=100, loop=0)
+############### DEPRECATED: see geneticsd.py webbrowser.open(os.environ["PWD"] + "/" + gif_name)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py pretty_print("Should we create a meme ?")
+############### DEPRECATED: see geneticsd.py answer = input(" [y]es or [n]o ?")
+############### DEPRECATED: see geneticsd.py if "y" in answer or "Y" in answer:
+############### DEPRECATED: see geneticsd.py url = 'https://imgflip.com/memegenerator'
+############### DEPRECATED: see geneticsd.py webbrowser.open(url)
+############### DEPRECATED: see geneticsd.py pretty_print("Good bye!")
+############### DEPRECATED: see geneticsd.py exit()
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py import os
+############### DEPRECATED: see geneticsd.py import pygame
+############### DEPRECATED: see geneticsd.py from os import listdir
+############### DEPRECATED: see geneticsd.py from os.path import isfile, join
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py sentinel = str(random.randint(0,100000)) + "XX" + str(random.randint(0,100000))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py all_files = []
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py llambda = 15
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py assert llambda < 16, "lambda < 16 for convenience in pygame."
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py bad = []
+############### DEPRECATED: see geneticsd.py five_best = []
+############### DEPRECATED: see geneticsd.py latent = []
+############### DEPRECATED: see geneticsd.py images = []
+############### DEPRECATED: see geneticsd.py onlyfiles = []
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py pretty_print("Now let us choose (if you want) an image as a start.")
+############### DEPRECATED: see geneticsd.py image_name = input(to_native("Name of image for starting ? (enter if no start image)"))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # activate the pygame library .
+############### DEPRECATED: see geneticsd.py pygame.init()
+############### DEPRECATED: see geneticsd.py X = 2000 # > 1500 = buttons
+############### DEPRECATED: see geneticsd.py Y = 900
+############### DEPRECATED: see geneticsd.py scrn = pygame.display.set_mode((1700, Y + 100))
+############### DEPRECATED: see geneticsd.py font = pygame.font.Font('freesansbold.ttf', 22)
+############### DEPRECATED: see geneticsd.py bigfont = pygame.font.Font('freesansbold.ttf', 44)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def load_img(path):
+############### DEPRECATED: see geneticsd.py image = Image.open(path).convert("RGB")
+############### DEPRECATED: see geneticsd.py w, h = image.size
+############### DEPRECATED: see geneticsd.py print(to_native(f"loaded input image of size ({w}, {h}) from {path}"))
+############### DEPRECATED: see geneticsd.py w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+############### DEPRECATED: see geneticsd.py image = image.resize((512, 512), resample=PIL.Image.LANCZOS)
+############### DEPRECATED: see geneticsd.py #image = image.resize((w, h), resample=PIL.Image.LANCZOS)
+############### DEPRECATED: see geneticsd.py image = np.array(image).astype(np.float32) / 255.0
+############### DEPRECATED: see geneticsd.py image = image[None].transpose(0, 3, 1, 2)
+############### DEPRECATED: see geneticsd.py image = torch.from_numpy(image)
+############### DEPRECATED: see geneticsd.py return 2.*image - 1.
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py model = pipe.vae
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def img_to_latent(path):
+############### DEPRECATED: see geneticsd.py #init_image = 1.8 * load_img(path).to(device)
+############### DEPRECATED: see geneticsd.py init_image = load_img(path).to(device)
+############### DEPRECATED: see geneticsd.py init_image = repeat(init_image, '1 ... -> b ...', b=1)
+############### DEPRECATED: see geneticsd.py forced_latent = model.encode(init_image.to(device)).latent_dist.sample()
+############### DEPRECATED: see geneticsd.py new_fl = forced_latent.cpu().detach().numpy().flatten()
+############### DEPRECATED: see geneticsd.py new_fl = np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2))
+############### DEPRECATED: see geneticsd.py return new_fl
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py def randomized_image_to_latent(image_name, scale=None, epsilon=None, c=None, f=None):
+############### DEPRECATED: see geneticsd.py base_init_image = load_img(image_name).to(device)
+############### DEPRECATED: see geneticsd.py new_base_init_image = base_init_image
+############### DEPRECATED: see geneticsd.py c = np.exp(np.random.randn()) if c is None else c
+############### DEPRECATED: see geneticsd.py f = np.exp(-3. * np.random.rand()) if f is None else f
+############### DEPRECATED: see geneticsd.py init_image_shape = base_init_image.cpu().numpy().shape
+############### DEPRECATED: see geneticsd.py init_image = c * new_base_init_image
+############### DEPRECATED: see geneticsd.py init_image = repeat(init_image, '1 ... -> b ...', b=1)
+############### DEPRECATED: see geneticsd.py forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample()
+############### DEPRECATED: see geneticsd.py new_fl = forced_latent.cpu().detach().numpy().flatten()
+############### DEPRECATED: see geneticsd.py basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl
+############### DEPRECATED: see geneticsd.py basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl
+############### DEPRECATED: see geneticsd.py epsilon = 0.1 * np.exp(-3 * np.random.rand()) if epsilon is None else epsilon
+############### DEPRECATED: see geneticsd.py new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64)
+############### DEPRECATED: see geneticsd.py scale = 2.8 + 3.6 * np.random.rand() if scale is None else scale
+############### DEPRECATED: see geneticsd.py new_fl = scale * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2))
+############### DEPRECATED: see geneticsd.py #image = latent_to_image(np.asarray(new_fl)) #eval(os.environ["forcedlatent"])))
+############### DEPRECATED: see geneticsd.py #image.save(f"rebuild_{f}_{scale}_{epsilon}_{c}.png")
+############### DEPRECATED: see geneticsd.py return new_fl
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py if len(image_name) > 0:
+############### DEPRECATED: see geneticsd.py pretty_print("Importing an image !")
+############### DEPRECATED: see geneticsd.py try:
+############### DEPRECATED: see geneticsd.py init_image = load_img(image_name).to(device)
+############### DEPRECATED: see geneticsd.py except:
+############### DEPRECATED: see geneticsd.py pretty_print("Try again!")
+############### DEPRECATED: see geneticsd.py pretty_print("Loading failed!!")
+############### DEPRECATED: see geneticsd.py image_name = input(to_native("Name of image for starting ? (enter if no start image)"))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py base_init_image = load_img(image_name).to(device)
+############### DEPRECATED: see geneticsd.py noise.say("Image loaded")
+############### DEPRECATED: see geneticsd.py noise.runAndWait()
+############### DEPRECATED: see geneticsd.py print(base_init_image.shape)
+############### DEPRECATED: see geneticsd.py print(np.max(base_init_image.cpu().detach().numpy().flatten()))
+############### DEPRECATED: see geneticsd.py print(np.min(base_init_image.cpu().detach().numpy().flatten()))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py forcedlatents = []
+############### DEPRECATED: see geneticsd.py divider = 1.5
+############### DEPRECATED: see geneticsd.py latent_found = False
+############### DEPRECATED: see geneticsd.py try:
+############### DEPRECATED: see geneticsd.py latent_file = image_name + ".latent.txt"
+############### DEPRECATED: see geneticsd.py print(to_native(f"Trying to load latent variables in {latent_file}."))
+############### DEPRECATED: see geneticsd.py f = open(latent_file, "r")
+############### DEPRECATED: see geneticsd.py print(to_native("File opened."))
+############### DEPRECATED: see geneticsd.py latent_str = f.read()
+############### DEPRECATED: see geneticsd.py print("Latent string read.")
+############### DEPRECATED: see geneticsd.py latent_found = True
+############### DEPRECATED: see geneticsd.py except:
+############### DEPRECATED: see geneticsd.py print(to_native("No latent file: guessing."))
+############### DEPRECATED: see geneticsd.py for i in range(llambda):
+############### DEPRECATED: see geneticsd.py new_base_init_image = base_init_image
+############### DEPRECATED: see geneticsd.py if not latent_found: # In case of latent vars we need less exploration.
+############### DEPRECATED: see geneticsd.py if (i % 7) == 1:
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,0,:,:] /= divider
+############### DEPRECATED: see geneticsd.py if (i % 7) == 2:
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,1,:,:] /= divider
+############### DEPRECATED: see geneticsd.py if (i % 7) == 3:
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,2,:,:] /= divider
+############### DEPRECATED: see geneticsd.py if (i % 7) == 4:
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,0,:,:] /= divider
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,1,:,:] /= divider
+############### DEPRECATED: see geneticsd.py if (i % 7) == 5:
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,1,:,:] /= divider
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,2,:,:] /= divider
+############### DEPRECATED: see geneticsd.py if (i % 7) == 6:
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,0,:,:] /= divider
+############### DEPRECATED: see geneticsd.py new_base_init_image[0,2,:,:] /= divider
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py c = np.exp(np.random.randn() - 5)
+############### DEPRECATED: see geneticsd.py f = np.exp(-3. * np.random.rand())
+############### DEPRECATED: see geneticsd.py init_image_shape = base_init_image.cpu().numpy().shape
+############### DEPRECATED: see geneticsd.py if i > 0 and not latent_found:
+############### DEPRECATED: see geneticsd.py init_image = new_base_init_image + torch.from_numpy(c * np.random.randn(np.prod(init_image_shape))).reshape(init_image_shape).float().to(device)
+############### DEPRECATED: see geneticsd.py else:
+############### DEPRECATED: see geneticsd.py init_image = new_base_init_image
+############### DEPRECATED: see geneticsd.py init_image = repeat(init_image, '1 ... -> b ...', b=1)
+############### DEPRECATED: see geneticsd.py if latent_found:
+############### DEPRECATED: see geneticsd.py new_fl = np.asarray(eval(latent_str))
+############### DEPRECATED: see geneticsd.py assert len(new_fl) > 1
+############### DEPRECATED: see geneticsd.py else:
+############### DEPRECATED: see geneticsd.py forced_latent = 1. * model.encode(init_image.to(device)).latent_dist.sample()
+############### DEPRECATED: see geneticsd.py new_fl = forced_latent.cpu().detach().numpy().flatten()
+############### DEPRECATED: see geneticsd.py basic_new_fl = new_fl #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl
+############### DEPRECATED: see geneticsd.py #new_fl = forced_latent + (1. / 1.1**(llambda-i)) * torch.from_numpy(np.random.randn(1*4*64*64).reshape(1,4,64,64)).float().to(device)
+############### DEPRECATED: see geneticsd.py #forcedlatents += [new_fl.cpu().detach().numpy()]
+############### DEPRECATED: see geneticsd.py if i > 0:
+############### DEPRECATED: see geneticsd.py #epsilon = 0.3 / 1.1**i
+############### DEPRECATED: see geneticsd.py basic_new_fl = f * np.sqrt(len(new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl
+############### DEPRECATED: see geneticsd.py epsilon = .7 * ((i-1)/(llambda-1)) #1.0 / 2**(2 + (llambda - i) / 6)
+############### DEPRECATED: see geneticsd.py print(f"{i} -- {i % 7} {c} {f} {epsilon}")
+############### DEPRECATED: see geneticsd.py # 1 -- 1 0.050020045300292804 0.0790648688521246 0.0
+############### DEPRECATED: see geneticsd.py new_fl = (1. - epsilon) * basic_new_fl + epsilon * np.random.randn(1*4*64*64)
+############### DEPRECATED: see geneticsd.py else:
+############### DEPRECATED: see geneticsd.py new_fl = basic_new_fl
+############### DEPRECATED: see geneticsd.py new_fl = 6. * np.sqrt(len(new_fl)) * new_fl / np.sqrt(np.sum(new_fl ** 2))
+############### DEPRECATED: see geneticsd.py forcedlatents += [new_fl] #np.clip(new_fl, -3., 3.)] #np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl]
+############### DEPRECATED: see geneticsd.py forcedgs += [7.5] #np.random.choice([7.5, 15.0, 30.0, 60.0])] TODO
+############### DEPRECATED: see geneticsd.py #forcedlatents += [np.sqrt(len(new_fl) / sum(new_fl ** 2)) * new_fl]
+############### DEPRECATED: see geneticsd.py #print(f"{i} --> {forcedlatents[i][:10]}")
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # We start the big time consuming loop!
+############### DEPRECATED: see geneticsd.py for iteration in range(30):
+############### DEPRECATED: see geneticsd.py latent = [latent[f] for f in five_best]
+############### DEPRECATED: see geneticsd.py images = [images[f] for f in five_best]
+############### DEPRECATED: see geneticsd.py onlyfiles = [onlyfiles[f] for f in five_best]
+############### DEPRECATED: see geneticsd.py early_stop = []
+############### DEPRECATED: see geneticsd.py noise.say("WAIT!")
+############### DEPRECATED: see geneticsd.py noise.runAndWait()
+############### DEPRECATED: see geneticsd.py final_selection = []
+############### DEPRECATED: see geneticsd.py final_selection_latent = []
+############### DEPRECATED: see geneticsd.py for k in range(llambda):
+############### DEPRECATED: see geneticsd.py if len(early_stop) > 0:
+############### DEPRECATED: see geneticsd.py break
+############### DEPRECATED: see geneticsd.py max_created_index = k
+############### DEPRECATED: see geneticsd.py if len(forcedlatents) > 0 and k < len(forcedlatents):
+############### DEPRECATED: see geneticsd.py #os.environ["forcedlatent"] = str(list(forcedlatents[k].flatten()))
+############### DEPRECATED: see geneticsd.py latent_forcing = str(list(forcedlatents[k].flatten()))
+############### DEPRECATED: see geneticsd.py print(f"We play with {latent_forcing[:20]}")
+############### DEPRECATED: see geneticsd.py if k < len(five_best):
+############### DEPRECATED: see geneticsd.py imp = pygame.transform.scale(pygame.image.load(onlyfiles[k]).convert(), (300, 300))
+############### DEPRECATED: see geneticsd.py # Using blit to copy content from one surface to other
+############### DEPRECATED: see geneticsd.py scrn.blit(imp, (300 * (k // 3), 300 * (k % 3)))
+############### DEPRECATED: see geneticsd.py pygame.display.flip()
+############### DEPRECATED: see geneticsd.py continue
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, black, pygame.Rect(0, Y, 1700, Y+100))
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, black, pygame.Rect(1500, 0, 2000, Y+100))
+############### DEPRECATED: see geneticsd.py text0 = bigfont.render(to_native(f'Please wait !!! {k} / {llambda}'), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4))
+############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'Or, for an early stopping,'), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8))
+############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'click and WAIT a bit'), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2))
+############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'... ... ... '), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2+Y/8))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # Button for early stopping
+############### DEPRECATED: see geneticsd.py text2 = font.render(to_native(f'Total: {len(all_selected)} chosen images! '), True, green, blue)
+############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3))
+############### DEPRECATED: see geneticsd.py text2 = font.render(to_native('Click for stopping,'), True, green, blue)
+############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16+X/64 - X/32, Y/3))
+############### DEPRECATED: see geneticsd.py text2 = font.render(to_native('and get the effects.'), True, green, blue)
+############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py pygame.display.flip()
+############### DEPRECATED: see geneticsd.py os.environ["earlystop"] = "False" if k > len(five_best) else "True"
+############### DEPRECATED: see geneticsd.py os.environ["epsilon"] = str(0. if k == len(five_best) else (k - len(five_best)) / llambda)
+############### DEPRECATED: see geneticsd.py os.environ["budget"] = str(np.random.randint(400) if k > len(five_best) else 2)
+############### DEPRECATED: see geneticsd.py os.environ["skl"] = {0: "nn", 1: "tree", 2: "logit"}[k % 3]
+############### DEPRECATED: see geneticsd.py #enforcedlatent = os.environ.get("enforcedlatent", "")
+############### DEPRECATED: see geneticsd.py #if len(enforcedlatent) > 2:
+############### DEPRECATED: see geneticsd.py # os.environ["forcedlatent"] = enforcedlatent
+############### DEPRECATED: see geneticsd.py # os.environ["enforcedlatent"] = ""
+############### DEPRECATED: see geneticsd.py #with autocast("cuda"):
+############### DEPRECATED: see geneticsd.py # image = pipe(english_prompt, guidance_scale=gs, num_inference_steps=num_iterations)["sample"][0]
+############### DEPRECATED: see geneticsd.py previous_gs = gs
+############### DEPRECATED: see geneticsd.py if k < len(forcedgs):
+############### DEPRECATED: see geneticsd.py gs = forcedgs[k]
+############### DEPRECATED: see geneticsd.py image = latent_to_image(np.asarray(latent_forcing)) #eval(os.environ["forcedlatent"])))
+############### DEPRECATED: see geneticsd.py gs = previous_gs
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py images += [image]
+############### DEPRECATED: see geneticsd.py filename = f"SD_{prompt.replace(' ','_')}_image_{sentinel}_{iteration:05d}_{k:05d}.png"
+############### DEPRECATED: see geneticsd.py image.save(filename)
+############### DEPRECATED: see geneticsd.py onlyfiles += [filename]
+############### DEPRECATED: see geneticsd.py imp = pygame.transform.scale(pygame.image.load(onlyfiles[-1]).convert(), (300, 300))
+############### DEPRECATED: see geneticsd.py # Using blit to copy content from one surface to other
+############### DEPRECATED: see geneticsd.py scrn.blit(imp, (300 * (k // 3), 300 * (k % 3)))
+############### DEPRECATED: see geneticsd.py pygame.display.flip()
+############### DEPRECATED: see geneticsd.py #noise.say("Dong")
+############### DEPRECATED: see geneticsd.py #noise.runAndWait()
+############### DEPRECATED: see geneticsd.py print('\a')
+############### DEPRECATED: see geneticsd.py str_latent = eval((os.environ["latent_sd"]))
+############### DEPRECATED: see geneticsd.py array_latent = eval(f"np.array(str_latent).reshape(4, 64, 64)")
+############### DEPRECATED: see geneticsd.py print(f"Debug info: array_latent sumsq/var {sum(array_latent.flatten() ** 2) / len(array_latent.flatten())}")
+############### DEPRECATED: see geneticsd.py latent += [array_latent]
+############### DEPRECATED: see geneticsd.py with open(filename + ".latent.txt", 'w') as f:
+############### DEPRECATED: see geneticsd.py f.write(f"{str_latent}")
+############### DEPRECATED: see geneticsd.py # In case of early stopping.
+############### DEPRECATED: see geneticsd.py first_event = True
+############### DEPRECATED: see geneticsd.py for i in pygame.event.get():
+############### DEPRECATED: see geneticsd.py if i.type == pygame.MOUSEBUTTONUP:
+############### DEPRECATED: see geneticsd.py if first_event:
+############### DEPRECATED: see geneticsd.py noise.say("Ok I stop")
+############### DEPRECATED: see geneticsd.py noise.runAndWait()
+############### DEPRECATED: see geneticsd.py first_event = False
+############### DEPRECATED: see geneticsd.py pos = pygame.mouse.get_pos()
+############### DEPRECATED: see geneticsd.py index = 3 * (pos[0] // 300) + (pos[1] // 300)
+############### DEPRECATED: see geneticsd.py if pos[0] > X and pos[1] > Y /3 and pos[1] < 2*Y/3:
+############### DEPRECATED: see geneticsd.py stop_all(all_selected, all_selected_latent, final_selection, final_selection_latent)
+############### DEPRECATED: see geneticsd.py exit()
+############### DEPRECATED: see geneticsd.py if index <= k:
+############### DEPRECATED: see geneticsd.py pretty_print(("You clicked for requesting an early stopping."))
+############### DEPRECATED: see geneticsd.py early_stop = [pos]
+############### DEPRECATED: see geneticsd.py break
+############### DEPRECATED: see geneticsd.py early_stop = [(1,1)]
+############### DEPRECATED: see geneticsd.py satus = False
+############### DEPRECATED: see geneticsd.py forcedgs = []
+############### DEPRECATED: see geneticsd.py # Stop the forcing from disk!
+############### DEPRECATED: see geneticsd.py #os.environ["enforcedlatent"] = ""
+############### DEPRECATED: see geneticsd.py # importing required library
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py #mypath = "./"
+############### DEPRECATED: see geneticsd.py #onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))]
+############### DEPRECATED: see geneticsd.py #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)]
+############### DEPRECATED: see geneticsd.py #print()
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # create the display surface object
+############### DEPRECATED: see geneticsd.py # of specific dimension..e(X, Y).
+############### DEPRECATED: see geneticsd.py noise.say("Ok I'm ready! Choose")
+############### DEPRECATED: see geneticsd.py noise.runAndWait()
+############### DEPRECATED: see geneticsd.py pretty_print("Please choose your images.")
+############### DEPRECATED: see geneticsd.py text0 = bigfont.render(to_native(f'Choose your favorite images !!!========='), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/2 - X/32, Y/2-Y/4))
+############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'=================================='), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2-Y/8))
+############### DEPRECATED: see geneticsd.py text0 = font.render(to_native(f'=================================='), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text0, ((X*3/4)/3 - X/32, Y/2))
+############### DEPRECATED: see geneticsd.py # Add rectangles
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 0, X*3/4+X/16+X/32, Y/3), 2)
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, Y/3, X*3/4+X/16+X/32, 2*Y/3), 2)
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2)
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(0, Y, X/2, Y+100), 2)
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # Button for loading a starting point
+############### DEPRECATED: see geneticsd.py text1 = font.render('Manually edit an image.', True, green, blue)
+############### DEPRECATED: see geneticsd.py text1 = pygame.transform.rotate(text1, 90)
+############### DEPRECATED: see geneticsd.py #scrn.blit(text1, (X*3/4+X/16 - X/32, 0))
+############### DEPRECATED: see geneticsd.py #text1 = font.render('& latent ', True, green, blue)
+############### DEPRECATED: see geneticsd.py #text1 = pygame.transform.rotate(text1, 90)
+############### DEPRECATED: see geneticsd.py #scrn.blit(text1, (X*3/4+X/16+X/32 - X/32, 0))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # Button for creating a meme
+############### DEPRECATED: see geneticsd.py text2 = font.render(to_native('Click ,'), True, green, blue)
+############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3+10))
+############### DEPRECATED: see geneticsd.py text2 = font.render(to_native('for finishing with effects.'), True, green, blue)
+############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16+X/32 - X/32, Y/3+10))
+############### DEPRECATED: see geneticsd.py # Button for new generation
+############### DEPRECATED: see geneticsd.py text3 = font.render(to_native(f"I don't want to select images"), True, green, blue)
+############### DEPRECATED: see geneticsd.py text3 = pygame.transform.rotate(text3, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3+10))
+############### DEPRECATED: see geneticsd.py text3 = font.render(to_native(f"Just rerun."), True, green, blue)
+############### DEPRECATED: see geneticsd.py text3 = pygame.transform.rotate(text3, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3+10))
+############### DEPRECATED: see geneticsd.py text4 = font.render(to_native(f"Modify parameters or text!"), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text4, (300, Y + 30))
+############### DEPRECATED: see geneticsd.py pygame.display.flip()
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py for idx in range(max_created_index + 1):
+############### DEPRECATED: see geneticsd.py # set the pygame window name
+############### DEPRECATED: see geneticsd.py pygame.display.set_caption(prompt)
+############### DEPRECATED: see geneticsd.py print(to_native(f"Pasting image {onlyfiles[idx]}..."))
+############### DEPRECATED: see geneticsd.py imp = pygame.transform.scale(pygame.image.load(onlyfiles[idx]).convert(), (300, 300))
+############### DEPRECATED: see geneticsd.py scrn.blit(imp, (300 * (idx // 3), 300 * (idx % 3)))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # paint screen one time
+############### DEPRECATED: see geneticsd.py pygame.display.flip()
+############### DEPRECATED: see geneticsd.py status = True
+############### DEPRECATED: see geneticsd.py indices = []
+############### DEPRECATED: see geneticsd.py good = []
+############### DEPRECATED: see geneticsd.py five_best = []
+############### DEPRECATED: see geneticsd.py for i in pygame.event.get():
+############### DEPRECATED: see geneticsd.py if i.type == pygame.MOUSEBUTTONUP:
+############### DEPRECATED: see geneticsd.py print(to_native(".... too early for clicking !!!!"))
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py pretty_print("Please click on your favorite elements!")
+############### DEPRECATED: see geneticsd.py print(to_native("You might just click on one image and we will provide variations."))
+############### DEPRECATED: see geneticsd.py print(to_native("Or you can click on the top of an image and the bottom of another one."))
+############### DEPRECATED: see geneticsd.py print(to_native("Click on the << new generation >> when you're done."))
+############### DEPRECATED: see geneticsd.py while (status):
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # iterate over the list of Event objects
+############### DEPRECATED: see geneticsd.py # that was returned by pygame.event.get() method.
+############### DEPRECATED: see geneticsd.py for i in pygame.event.get():
+############### DEPRECATED: see geneticsd.py if hasattr(i, "type") and i.type == pygame.MOUSEBUTTONUP:
+############### DEPRECATED: see geneticsd.py pos = pygame.mouse.get_pos()
+############### DEPRECATED: see geneticsd.py pretty_print(f"Detected! Click at {pos}")
+############### DEPRECATED: see geneticsd.py if pos[1] > Y:
+############### DEPRECATED: see geneticsd.py pretty_print("Let us update parameters!")
+############### DEPRECATED: see geneticsd.py text4 = font.render(to_native(f"ok, go to text window!"), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text4, (300, Y + 30))
+############### DEPRECATED: see geneticsd.py pygame.display.flip()
+############### DEPRECATED: see geneticsd.py try:
+############### DEPRECATED: see geneticsd.py num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n")))
+############### DEPRECATED: see geneticsd.py except:
+############### DEPRECATED: see geneticsd.py num_iterations = int(input(to_native(f"Number of iterations ? (current = {num_iterations})\n")))
+############### DEPRECATED: see geneticsd.py gs = float(input(to_native(f"Guidance scale ? (current = {gs})\n")))
+############### DEPRECATED: see geneticsd.py print(to_native(f"The current text is << {prompt} >>."))
+############### DEPRECATED: see geneticsd.py print(to_native("Start your answer with a symbol << + >> if this is an edit and not a new text."))
+############### DEPRECATED: see geneticsd.py new_prompt = str(input(to_native(f"Enter a text if you want to change from ") + prompt))
+############### DEPRECATED: see geneticsd.py if len(new_prompt) > 2:
+############### DEPRECATED: see geneticsd.py if new_prompt[0] == "+":
+############### DEPRECATED: see geneticsd.py prompt += new_prompt[1:]
+############### DEPRECATED: see geneticsd.py else:
+############### DEPRECATED: see geneticsd.py prompt = new_prompt
+############### DEPRECATED: see geneticsd.py language = detect(prompt)
+############### DEPRECATED: see geneticsd.py english_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
+############### DEPRECATED: see geneticsd.py pretty_print("Ok! Parameters updated.")
+############### DEPRECATED: see geneticsd.py pretty_print("==> go back to the window!")
+############### DEPRECATED: see geneticsd.py text4 = font.render(to_native(f"Ok! parameters changed!"), True, green, blue)
+############### DEPRECATED: see geneticsd.py scrn.blit(text4, (300, Y + 30))
+############### DEPRECATED: see geneticsd.py pygame.display.flip()
+############### DEPRECATED: see geneticsd.py elif pos[0] > 1500: # Not in the images.
+############### DEPRECATED: see geneticsd.py if pos[1] < Y/3:
+############### DEPRECATED: see geneticsd.py #filename = input(to_native("Filename (please provide the latent file, of the format SD*latent*.txt) ?\n"))
+############### DEPRECATED: see geneticsd.py #status = False
+############### DEPRECATED: see geneticsd.py #with open(filename, 'r') as f:
+############### DEPRECATED: see geneticsd.py # latent = f.read()
+############### DEPRECATED: see geneticsd.py #break
+############### DEPRECATED: see geneticsd.py pretty_print("Easy! I exit now, you edit the file and you save it.")
+############### DEPRECATED: see geneticsd.py pretty_print("Then just relaunch me and provide the text and the image.")
+############### DEPRECATED: see geneticsd.py exit()
+############### DEPRECATED: see geneticsd.py if pos[1] < 2*Y/3:
+############### DEPRECATED: see geneticsd.py #onlyfiles = [f for f in listdir(".") if isfile(join(mypath, f))]
+############### DEPRECATED: see geneticsd.py #onlyfiles = [str(f) for f in onlyfiles if "SD_" in str(f) and ".png" in str(f) and str(f) not in all_files and sentinel in str(f)]
+############### DEPRECATED: see geneticsd.py assert len(onlyfiles) == len(latent)
+############### DEPRECATED: see geneticsd.py assert len(all_selected) == len(all_selected_latent)
+############### DEPRECATED: see geneticsd.py stop_all(all_selected, all_selected_latent, final_selection, final_selection_latent) # + onlyfiles, all_selected_latent + latent)
+############### DEPRECATED: see geneticsd.py exit()
+############### DEPRECATED: see geneticsd.py status = False
+############### DEPRECATED: see geneticsd.py break
+############### DEPRECATED: see geneticsd.py index = 3 * (pos[0] // 300) + (pos[1] // 300)
+############### DEPRECATED: see geneticsd.py pygame.draw.circle(scrn, red, [pos[0], pos[1]], 13, 0)
+############### DEPRECATED: see geneticsd.py if index <= max_created_index:
+############### DEPRECATED: see geneticsd.py selected_filename = to_native("Selected") + onlyfiles[index]
+############### DEPRECATED: see geneticsd.py shutil.copyfile(onlyfiles[index], selected_filename)
+############### DEPRECATED: see geneticsd.py assert len(onlyfiles) == len(latent), f"{len(onlyfiles)} != {len(latent)}"
+############### DEPRECATED: see geneticsd.py all_selected += [selected_filename]
+############### DEPRECATED: see geneticsd.py all_selected_latent += [latent[index]]
+############### DEPRECATED: see geneticsd.py final_selection += [selected_filename]
+############### DEPRECATED: see geneticsd.py final_selection_latent += [latent[index]]
+############### DEPRECATED: see geneticsd.py text2 = font.render(to_native(f'==> {len(all_selected)} chosen images! '), True, green, blue)
+############### DEPRECATED: see geneticsd.py text2 = pygame.transform.rotate(text2, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text2, (X*3/4+X/16 - X/32, Y/3))
+############### DEPRECATED: see geneticsd.py if index not in five_best and len(five_best) < 5:
+############### DEPRECATED: see geneticsd.py five_best += [index]
+############### DEPRECATED: see geneticsd.py indices += [[index, (pos[0] - (pos[0] // 300) * 300) / 300, (pos[1] - (pos[1] // 300) * 300) / 300]]
+############### DEPRECATED: see geneticsd.py # Update the button for new generation.
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, black, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y))
+############### DEPRECATED: see geneticsd.py pygame.draw.rect(scrn, red, pygame.Rect(X*3/4, 2*Y/3, X*3/4+X/16+X/32, Y), 2)
+############### DEPRECATED: see geneticsd.py text3 = font.render(to_native(f" You have chosen {len(indices)} images:"), True, green, blue)
+############### DEPRECATED: see geneticsd.py text3 = pygame.transform.rotate(text3, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text3, (X*3/4+X/16 - X/32, Y*2/3))
+############### DEPRECATED: see geneticsd.py text3 = font.render(to_native(f" Click for new generation!"), True, green, blue)
+############### DEPRECATED: see geneticsd.py text3 = pygame.transform.rotate(text3, 90)
+############### DEPRECATED: see geneticsd.py scrn.blit(text3, (X*3/4+X/16+X/32 - X/32, Y*2/3))
+############### DEPRECATED: see geneticsd.py pygame.display.flip()
+############### DEPRECATED: see geneticsd.py #text3Rect = text3.get_rect()
+############### DEPRECATED: see geneticsd.py #text3Rect.center = (750+750*3/4, 1000)
+############### DEPRECATED: see geneticsd.py good += [list(latent[index].flatten())]
+############### DEPRECATED: see geneticsd.py else:
+############### DEPRECATED: see geneticsd.py noise.say("Bad click! Click on image.")
+############### DEPRECATED: see geneticsd.py noise.runAndWait()
+############### DEPRECATED: see geneticsd.py pretty_print("Bad click! Click on image.")
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py if i.type == pygame.QUIT:
+############### DEPRECATED: see geneticsd.py status = False
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py # Covering old images with full circles.
+############### DEPRECATED: see geneticsd.py for _ in range(123):
+############### DEPRECATED: see geneticsd.py x = np.random.randint(1500)
+############### DEPRECATED: see geneticsd.py y = np.random.randint(900)
+############### DEPRECATED: see geneticsd.py pygame.draw.circle(scrn, darkgreen,
+############### DEPRECATED: see geneticsd.py [x, y], 17, 0)
+############### DEPRECATED: see geneticsd.py pygame.display.update()
+############### DEPRECATED: see geneticsd.py if len(indices) == 0:
+############### DEPRECATED: see geneticsd.py print("The user did not like anything! Rerun :-(")
+############### DEPRECATED: see geneticsd.py continue
+############### DEPRECATED: see geneticsd.py print(f"Clicks at {indices}")
+############### DEPRECATED: see geneticsd.py os.environ["mu"] = str(len(indices))
+############### DEPRECATED: see geneticsd.py forcedlatents = []
+############### DEPRECATED: see geneticsd.py bad += [list(latent[u].flatten()) for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]
+############### DEPRECATED: see geneticsd.py #sauron = 0 * latent[0]
+############### DEPRECATED: see geneticsd.py #for u in [u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]]:
+############### DEPRECATED: see geneticsd.py # sauron += latent[u]
+############### DEPRECATED: see geneticsd.py #sauron = (1 / len([u for u in range(len(onlyfiles)) if u not in [i[0] for i in indices]])) * sauron
+############### DEPRECATED: see geneticsd.py if len(bad) > 500:
+############### DEPRECATED: see geneticsd.py bad = bad[(len(bad) - 500):]
+############### DEPRECATED: see geneticsd.py print(to_native(f"{len(indices)} indices are selected."))
+############### DEPRECATED: see geneticsd.py #print(f"indices = {indices}")
+############### DEPRECATED: see geneticsd.py os.environ["good"] = str(good)
+############### DEPRECATED: see geneticsd.py os.environ["bad"] = str(bad)
+############### DEPRECATED: see geneticsd.py coefficients = np.zeros(len(indices))
+############### DEPRECATED: see geneticsd.py numpy_images = [np.array(image) for image in images]
+############### DEPRECATED: see geneticsd.py for a in range(llambda):
+############### DEPRECATED: see geneticsd.py voronoi_in_images = False #(a % 2 == 1) and len(good) > 1
+############### DEPRECATED: see geneticsd.py if voronoi_in_images:
+############### DEPRECATED: see geneticsd.py image = np.array(numpy_images[0])
+############### DEPRECATED: see geneticsd.py print(f"Voronoi in the image space! {a} / {llambda}")
+############### DEPRECATED: see geneticsd.py for i in range(len(indices)):
+############### DEPRECATED: see geneticsd.py coefficients[i] = np.exp(np.random.randn())
+############### DEPRECATED: see geneticsd.py # Creating a forcedlatent.
+############### DEPRECATED: see geneticsd.py for i in range(512):
+############### DEPRECATED: see geneticsd.py x = i / 511.
+############### DEPRECATED: see geneticsd.py for j in range(512):
+############### DEPRECATED: see geneticsd.py y = j / 511
+############### DEPRECATED: see geneticsd.py mindistances = 10000000000.
+############### DEPRECATED: see geneticsd.py for u in range(len(indices)):
+############### DEPRECATED: see geneticsd.py distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) )
+############### DEPRECATED: see geneticsd.py if distance < mindistances:
+############### DEPRECATED: see geneticsd.py mindistances = distance
+############### DEPRECATED: see geneticsd.py uu = indices[u][0]
+############### DEPRECATED: see geneticsd.py image[i][j][:] = numpy_images[uu][i][j][:]
+############### DEPRECATED: see geneticsd.py # Conversion before using img2latent
+############### DEPRECATED: see geneticsd.py pil_image = Image.fromarray(image)
+############### DEPRECATED: see geneticsd.py voronoi_name = f"voronoi{a}_iteration{iteration}.png"
+############### DEPRECATED: see geneticsd.py pil_image.save(voronoi_name)
+############### DEPRECATED: see geneticsd.py #timage = np.array([image]).astype(np.float32) / 255.0
+############### DEPRECATED: see geneticsd.py #timage = timage.transpose(0, 3, 1, 2)
+############### DEPRECATED: see geneticsd.py #timage = torch.from_numpy(timage).to(device)
+############### DEPRECATED: see geneticsd.py #timage = repeat(timage, '1 ... -> b ...', b=1)
+############### DEPRECATED: see geneticsd.py #timage = 2.*timage - 1.
+############### DEPRECATED: see geneticsd.py #forcedlatent = model.encode(timage).latent_dist.sample().cpu().detach().numpy().flatten()
+############### DEPRECATED: see geneticsd.py #basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent
+############### DEPRECATED: see geneticsd.py basic_new_fl = randomized_image_to_latent(voronoi_name) #img_to_latent(voronoi_name)
+############### DEPRECATED: see geneticsd.py basic_new_fl = np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl
+############### DEPRECATED: see geneticsd.py #basic_new_fl = 0.8 * np.sqrt(len(basic_new_fl) / np.sum(basic_new_fl**2)) * basic_new_fl
+############### DEPRECATED: see geneticsd.py if len(good) > 1:
+############### DEPRECATED: see geneticsd.py print("Directly copying latent vars !!!")
+############### DEPRECATED: see geneticsd.py #forcedlatents += [4.6 * basic_new_fl]
+############### DEPRECATED: see geneticsd.py forcedlatents += [basic_new_fl]
+############### DEPRECATED: see geneticsd.py else:
+############### DEPRECATED: see geneticsd.py epsilon = 1.0 * (((a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2)
+############### DEPRECATED: see geneticsd.py forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64)
+############### DEPRECATED: see geneticsd.py forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent
+############### DEPRECATED: see geneticsd.py forcedlatents += [forcedlatent]
+############### DEPRECATED: see geneticsd.py #forcedlatents += [4.6 * forcedlatent]
+############### DEPRECATED: see geneticsd.py else:
+############### DEPRECATED: see geneticsd.py print(f"Voronoi in the latent space! {a} / {llambda}")
+############### DEPRECATED: see geneticsd.py forcedlatent = np.zeros((4, 64, 64))
+############### DEPRECATED: see geneticsd.py #print(type(numpy_image))
+############### DEPRECATED: see geneticsd.py #print(numpy_image.shape)
+############### DEPRECATED: see geneticsd.py #print(np.max(numpy_image))
+############### DEPRECATED: see geneticsd.py #print(np.min(numpy_image))
+############### DEPRECATED: see geneticsd.py #assert False
+############### DEPRECATED: see geneticsd.py for i in range(len(indices)):
+############### DEPRECATED: see geneticsd.py coefficients[i] = np.exp(np.random.randn())
+############### DEPRECATED: see geneticsd.py for i in range(64):
+############### DEPRECATED: see geneticsd.py x = i / 63.
+############### DEPRECATED: see geneticsd.py for j in range(64):
+############### DEPRECATED: see geneticsd.py y = j / 63
+############### DEPRECATED: see geneticsd.py mindistances = 10000000000.
+############### DEPRECATED: see geneticsd.py for u in range(len(indices)):
+############### DEPRECATED: see geneticsd.py #print(a, i, x, j, y, u)
+############### DEPRECATED: see geneticsd.py #print(indices[u][1])
+############### DEPRECATED: see geneticsd.py #print(indices[u][2])
+############### DEPRECATED: see geneticsd.py #print(f" {coefficients[u]}* np.linalg.norm({np.array((x, y))}-{np.array((indices[u][1], indices[u][2]))}")
+############### DEPRECATED: see geneticsd.py distance = coefficients[u] * np.linalg.norm( np.array((x, y)) - np.array((indices[u][2], indices[u][1])) )
+############### DEPRECATED: see geneticsd.py if distance < mindistances:
+############### DEPRECATED: see geneticsd.py mindistances = distance
+############### DEPRECATED: see geneticsd.py uu = indices[u][0]
+############### DEPRECATED: see geneticsd.py for k in range(4):
+############### DEPRECATED: see geneticsd.py assert k < len(forcedlatent), k
+############### DEPRECATED: see geneticsd.py assert i < len(forcedlatent[k]), i
+############### DEPRECATED: see geneticsd.py assert j < len(forcedlatent[k][i]), j
+############### DEPRECATED: see geneticsd.py assert uu < len(latent)
+############### DEPRECATED: see geneticsd.py assert k < len(latent[uu]), k
+############### DEPRECATED: see geneticsd.py assert i < len(latent[uu][k]), i
+############### DEPRECATED: see geneticsd.py assert j < len(latent[uu][k][i]), j
+############### DEPRECATED: see geneticsd.py forcedlatent[k][i][j] = float(latent[uu][k][i][j])
+############### DEPRECATED: see geneticsd.py #if a % 2 == 0:
+############### DEPRECATED: see geneticsd.py # forcedlatent -= np.random.rand() * sauron
+############### DEPRECATED: see geneticsd.py forcedlatent = forcedlatent.flatten()
+############### DEPRECATED: see geneticsd.py basic_new_fl = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent
+############### DEPRECATED: see geneticsd.py if len(good) > 1 or len(forcedlatents) < len(good) + 1:
+############### DEPRECATED: see geneticsd.py forcedlatents += [basic_new_fl]
+############### DEPRECATED: see geneticsd.py else:
+############### DEPRECATED: see geneticsd.py epsilon = ((0.5 * (a + .5 - len(good)) / (llambda - len(good) - 1)) ** 2)
+############### DEPRECATED: see geneticsd.py forcedlatent = (1. - epsilon) * basic_new_fl.flatten() + epsilon * np.random.randn(4*64*64)
+############### DEPRECATED: see geneticsd.py #forcedlatent = np.sqrt(len(forcedlatent) / np.sum(forcedlatent**2)) * forcedlatent
+############### DEPRECATED: see geneticsd.py forcedlatents += [forcedlatent]
+############### DEPRECATED: see geneticsd.py #for uu in range(len(latent)):
+############### DEPRECATED: see geneticsd.py # print(f"--> latent[{uu}] sum of sq / variable = {np.sum(latent[uu].flatten()**2) / len(latent[uu].flatten())}")
+############### DEPRECATED: see geneticsd.py os.environ["good"] = "[]"
+############### DEPRECATED: see geneticsd.py os.environ["bad"] = "[]"
+############### DEPRECATED: see geneticsd.py
+############### DEPRECATED: see geneticsd.py pygame.quit()
diff --git a/minisd.sh b/minisd.sh
new file mode 100755
index 000000000..26919fa5e
--- /dev/null
+++ b/minisd.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+echo deprecated
diff --git a/multi_minisd.sh b/multi_minisd.sh
new file mode 100755
index 000000000..8699b6cc4
--- /dev/null
+++ b/multi_minisd.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+echo deprecated
diff --git a/multiminisd.sh b/multiminisd.sh
new file mode 100755
index 000000000..8699b6cc4
--- /dev/null
+++ b/multiminisd.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+echo deprecated
diff --git a/pipeline_stable_diffusion.py b/pipeline_stable_diffusion.py
new file mode 100644
index 000000000..8e3199b44
--- /dev/null
+++ b/pipeline_stable_diffusion.py
@@ -0,0 +1,397 @@
+# Modification of the original file by O. Teytaud to facilitate genetic stable diffusion.
+
+import inspect
+import os
+import numpy as np
+import random
+import warnings
+from typing import List, Optional, Union
+
+import torch
+
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from ...models import AutoencoderKL, UNet2DConditionModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
+from .safety_checker import StableDiffusionSafetyChecker
+
+
+class StableDiffusionPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPFeatureExtractor,
+ ):
+ super().__init__()
+ scheduler = scheduler.set_format("pt")
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+
+ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
+ `attention_head_dim` must be a multiple of `slice_size`.
+ """
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = self.unet.config.attention_head_dim // 2
+ self.unet.set_attention_slice(slice_size)
+
+ def disable_attention_slicing(self):
+ r"""
+ Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+ back to computing attention in one step.
+ """
+ # set slice_size = `None` to disable `attention slicing`
+ self.enable_attention_slicing(None)
+
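+    # Usage sketch (assumes `pipe` is a constructed pipeline instance):
+    #     pipe.enable_attention_slicing()        # "auto": attention computed in two steps
+    #     pipe.enable_attention_slicing(1)       # finest slicing: least memory, slowest
+    #     pipe.disable_attention_slicing()       # back to a single full-attention pass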
+# def get_latent(self, image):
+# return self.vae.encode(image)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ height: Optional[int] = 512,
+ width: Optional[int] = 512,
+ num_inference_steps: Optional[int] = 50,
+ guidance_scale: Optional[float] = 7.5,
+ eta: Optional[float] = 0.0,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
+ deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+ (nsfw) content, according to the `safety_checker`.
+ """
+
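+        # Typical call (a sketch; assumes `pipe` is an instance of this pipeline and that
+        # the "forcedlatent", "good" and "bad" environment variables it reads below have
+        # been set, e.g. by geneticsd.py):
+        #     out = pipe("an astronaut riding a horse", guidance_scale=7.5, num_inference_steps=50)
+        #     out.images[0].save("astronaut.png")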
+ if "torch_device" in kwargs:
+ device = kwargs.pop("torch_device")
+ warnings.warn(
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
+ " Consider using `pipe.to(torch_device)` instead."
+ )
+
+ # Set device as before (to be removed in 0.3.0)
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ self.to(device)
+
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ # get prompt text embeddings
+ text_input = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
+
+        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance:
+ max_length = text_input.input_ids.shape[-1]
+ uncond_input = self.tokenizer(
+ [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+ )
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+ # get the initial random noise unless the user supplied it
+
+ # Unlike in other pipelines, latents need to be generated in the target device
+ # for 1-to-1 results reproducibility with the CompVis implementation.
+ # However this currently doesn't work in `mps`.
+ latents_device = "cpu" if self.device.type == "mps" else self.device
+ latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ latents_intermediate_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+ speedup = 1
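+        # "speedup" divides the number of scheduler steps later on; it is currently fixed to 1.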
+ if latents is None:
+ latents = torch.randn(
+ latents_intermediate_shape,
+ generator=generator,
+ device=latents_device,
+ )
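+            # The caller (e.g. the geneticsd.py GUI) can override this random draw by
+            # serialising a flat 4*64*64 latent into the "forcedlatent" environment
+            # variable; it is read here, reshaped to (1, 4, 64, 64) and then cleared.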
+ if len(os.environ["forcedlatent"]) > 10:
+ stri = os.environ["forcedlatent"]
+ print(f"we get a forcing for the latent z: {stri[:20]}.")
+ if len(eval(stri)) == 1:
+ stri = str(eval(stri)[0])
+ speedup = 1
+ latents = np.array(list(eval(stri))).flatten()
+ #latents = latents + np.exp(0.1 * np.random.randn()) * np.random.rand(len(latents))
+ #latents = np.sqrt(len(latents) / np.sum(latents ** 2)) * latents
+ #latents = np.sqrt(len(latents)) * latents / np.sqrt(np.sum(latents ** 2))
+ print(f"As an array, this is {latents[:10]}")
+ print(f"immediately after loading latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}")
+ latents = torch.from_numpy(latents.reshape((1,4,64,64))).float().to(latents_device)
+ os.environ["forcedlatent"] = ""
+ good = eval(os.environ["good"])
+ bad = eval(os.environ["bad"])
+ print(f"{len(good)} good and {len(bad)} bad")
+ i_believe_in_evolution = len(good) > 0 and len(bad) > 10
+ print(f"I believe in evolution = {i_believe_in_evolution}")
+ if i_believe_in_evolution:
+ from sklearn import tree
+ from sklearn.neural_network import MLPClassifier
+ #from sklearn.neighbors import NearestCentroid
+ from sklearn.linear_model import LogisticRegression
+ #z = (np.random.randn(4*64*64))
+ z = latents.cpu().numpy().flatten()
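+                # The "skl" environment variable selects the surrogate classifier:
+                # "tree" (decision tree, default), "logit" (logistic regression),
+                # anything else gives a small MLP.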
+ if os.environ.get("skl", "tree") == "tree":
+ clf = tree.DecisionTreeClassifier()#min_samples_split=0.1)
+ elif os.environ.get("skl", "tree") == "logit":
+ clf = LogisticRegression()
+ else:
+ clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1)
+ #clf = NearestCentroid()
+
+                X = good + bad
+                Y = [1] * len(good) + [0] * len(bad)
+                clf = clf.fit(X, Y)
+ epsilon = 0.0001 # for astronauts
+ epsilon = 1.0
+
+ def loss(x):
+ return clf.predict_proba([x])[0][0] # for astronauts
+ #return clf.predict_proba([(1-epsilon)*z+epsilon*x])[0][0] # for astronauts
+ #return clf.predict_proba([z+epsilon*x])[0][0]
+
+
+ budget = int(os.environ.get("budget", "300"))
+ if i_believe_in_evolution and budget > 20:
+ import nevergrad as ng
+ #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget)
+ #nevergrad_optimizer = ng.optimizers.RandomSearch(len(z), budget)
+ optim_class = ng.optimizers.registry[os.environ.get("ngoptim", "DiscreteLenglerOnePlusOne")]
+ #nevergrad_optimizer = ng.optimizers.DiscreteLenglerOnePlusOne(len(z), budget)
+ nevergrad_optimizer = optim_class(len(z), budget)
+ #nevergrad_optimizer = ng.optimizers.DiscreteOnePlusOne(len(z), budget)
+# for k in range(5):
+# z1 = np.array(random.choice(good))
+# z2 = np.array(random.choice(good))
+# z3 = np.array(random.choice(good))
+# z4 = np.array(random.choice(good))
+# z5 = np.array(random.choice(good))
+# #z = 0.99 * z1 + 0.01 * (z2+z3+z4+z5)/4.
+# z = 0.2 * (z1 + z2 + z3 + z4 + z5)
+# mu = int(os.environ.get("mu", "5"))
+# parents = [z1, z2, z3, z4, z5]
+# weights = [np.exp(np.random.randn() - i * float(os.environ.get("decay", "1."))) for i in range(5)]
+# z = weights[0] * z1
+# for u in range(mu):
+# if u > 0:
+# z += weights[u] * parents[u]
+# z = (1. / sum(weights[:mu])) * z
+# z = np.sqrt(len(z)) * z / np.linalg.norm(z)
+#
+# #for u in range(len(z)):
+# # z[u] = random.choice([z1[u],z2[u],z3[u],z4[u],z5[u]])
+# nevergrad_optimizer.suggest
+ if len(os.environ["forcedlatent"]) > 0:
+ print("we get a forcing for the latent z.")
+ z0 = eval(os.environ["forcedlatent"])
+ #nevergrad_optimizer.suggest(eval(os.environ["forcedlatent"]))
+ else:
+ z0 = z
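+                # Ask/tell loop: Nevergrad proposes a perturbation x, the candidate
+                # z0 + epsilon * x is renormalised to norm sqrt(dim), and the classifier's
+                # probability of "bad" is fed back as the loss to minimise.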
+ for i in range(budget):
+ x = nevergrad_optimizer.ask()
+ z = z0 + float(os.environ.get("epsilon", "0.001")) * x.value
+ z = np.sqrt(len(z)) * z / np.linalg.norm(z)
+ l = loss(z)
+ nevergrad_optimizer.tell(x, l)
+ if np.log2(i+1) == int(np.log2(i+1)):
+ print(f"iteration {i} --> {l}")
+ print("var/variable = ", sum(z**2)/len(z))
+ #z = (1.-epsilon) * z + epsilon * x / np.sqrt(np.sum(x ** 2))
+ if l < 0.0000001 and os.environ.get("earlystop", "False") in ["true", "True"]:
+ print(f"we find proba(bad)={l}")
+ break
+ x = nevergrad_optimizer.recommend().value
+ z = z0 + float(os.environ.get("epsilon", "0.001")) * x
+ z = np.sqrt(len(z)) * z / np.linalg.norm(z)
+ latents = torch.from_numpy(z.reshape(latents_intermediate_shape)).float() #.half()
+ else:
+ if latents.shape != latents_intermediate_shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_intermediate_shape}")
+ print(f"latent ==> {sum(latents.flatten()**2) / len(latents.flatten())}")
+ print(f"latent ==> {torch.max(latents)}")
+ print(f"latent ==> {torch.min(latents)}")
+ os.environ["latent_sd"] = str(list(latents.flatten().cpu().numpy()))
+ for i in [2, 3]:
+ latents = torch.repeat_interleave(latents, repeats=latents_shape[i] // latents_intermediate_shape[i], dim=i) #/ np.sqrt(np.sqrt(latents_shape[i] // latents_intermediate_shape[i]))
+ latents = latents.float().to(self.device)
+
+ # set timesteps
+ accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+ extra_set_kwargs = {}
+ if accepts_offset:
+ extra_set_kwargs["offset"] = 1
+
+ self.scheduler.set_timesteps(num_inference_steps // speedup, **extra_set_kwargs)
+
+        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = latents * self.scheduler.sigmas[0]
+
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ sigma = self.scheduler.sigmas[i]
+ # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
+ latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if isinstance(self.scheduler, LMSDiscreteScheduler):
+ latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
+ else:
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # scale and decode the image latents with vae
+ #os.environ["latent_sd"] = str(list(latents.flatten().cpu().detach().numpy()))
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+
+ image = (image / 2 + 0.5).clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
+
+ # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/view_history.sh b/view_history.sh
new file mode 100755
index 000000000..09f5c1618
--- /dev/null
+++ b/view_history.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
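+# Open the most recent images: the sentinel token is extracted from the newest SD_*.png
+# filename and up to 15 matching images are opened with the macOS `open` command.
+# The commented lines below are earlier montage-based variants of this view.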
+#montage `ls -ctr SD*imag*.png | head -n 15 | tail -n 14` -mode concatenate -tile 7x zuck1.png
+#montage `ls -ctr SD*imag*.png | head -n 29 | tail -n 14` -mode concatenate -tile 7x zuck2.png
+#montage `ls -ctr SD*imag*.png | tail -n 28` -mode concatenate -tile 7x history.png
+#montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort | tail -n 60 | sort ) -mode concatenate -tile 5x history.png
+#montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_0_11.png | sort ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_0_4.png | sort ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_1_?.png | sort -n ) $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`_1_??.png | sort -n ) -mode concatenate -tile 5x history.png
+#montage $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | sort ) -mode concatenate -tile 5x history.png
+#open history.png
+open $( ls *`ls -ctr SD*.png | sed 's/.*image_//g' | tail -n 1 | sed 's/_.*//g'`*.png | tail -n 15 | sort )
+#cp history.png zuck3.png