diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 22b6589973a0..e4a69641d580 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,5 +1,4 @@ from .utils import ( - is_accelerate_available, is_flax_available, is_inflect_available, is_onnx_available, @@ -17,13 +16,6 @@ from .utils import logging -# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py" -# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available -if is_torch_available() and not is_accelerate_available(): - error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501 - raise ImportError(error_msg) - - if is_torch_available(): from .modeling_utils import ModelMixin from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 9e05672bf163..1e91ccd56a15 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -21,15 +21,20 @@ import torch from torch import Tensor, device -import accelerate -from accelerate.utils import set_module_tensor_to_device -from accelerate.utils.versions import is_torch_version from huggingface_hub import hf_hub_download from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from requests import HTTPError from . import __version__ -from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging +from .utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_NAME, + is_accelerate_available, + is_torch_version, + logging, +) logger = logging.get_logger(__name__) @@ -41,6 +46,12 @@ _LOW_CPU_MEM_USAGE_DEFAULT = False +if is_accelerate_available(): + import accelerate + from accelerate.utils import set_module_tensor_to_device + from accelerate.utils.versions import is_torch_version + + def get_parameter_device(parameter: torch.nn.Module): try: return next(parameter.parameters()).device @@ -319,6 +330,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map = kwargs.pop("device_map", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warn( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + # Check if we can handle device_map and dispatching the weights if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 36c2d5b888ef..97e196e72338 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -25,7 +25,6 @@ import diffusers import PIL -from accelerate.utils.versions import is_torch_version from huggingface_hub import snapshot_download from packaging import version from PIL import Image @@ -43,6 +42,8 @@ WEIGHTS_NAME, BaseOutput, deprecate, + is_accelerate_available, + is_torch_version, is_transformers_available, logging, ) @@ -397,6 +398,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device_map = kwargs.pop("device_map", None) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warn( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + if device_map is not None and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 7395f4edfa26..3fa477e7dce8 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -31,6 +31,7 @@ is_scipy_available, is_tf_available, is_torch_available, + is_torch_version, is_transformers_available, is_unidecode_available, requires_backends, diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 833f2b6c5074..25aa82d6c5b2 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -272,21 +272,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) -class VQDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class DDIMScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py b/src/diffusers/utils/dummy_torch_and_accelerate_objects.py deleted file mode 100644 index 335e3ca24d2a..000000000000 --- a/src/diffusers/utils/dummy_torch_and_accelerate_objects.py +++ /dev/null @@ -1,452 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -# flake8: noqa - -from ..utils import DummyObject, requires_backends - - -class ModelMixin(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class AutoencoderKL(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class Transformer2DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet1DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet2DConditionModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class UNet2DModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class VQModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -def get_constant_schedule(*args, **kwargs): - requires_backends(get_constant_schedule, ["torch", "accelerate"]) - - -def get_constant_schedule_with_warmup(*args, **kwargs): - requires_backends(get_constant_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_cosine_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs): - requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_linear_schedule_with_warmup(*args, **kwargs): - requires_backends(get_linear_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_polynomial_decay_schedule_with_warmup(*args, **kwargs): - requires_backends(get_polynomial_decay_schedule_with_warmup, ["torch", "accelerate"]) - - -def get_scheduler(*args, **kwargs): - requires_backends(get_scheduler, ["torch", "accelerate"]) - - -class DiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DanceDiffusionPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDIMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDPMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class KarrasVePipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class LDMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class PNDMPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class RePaintPipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class ScoreSdeVePipeline(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDIMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class DDPMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EulerAncestralDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EulerDiscreteScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class IPNDMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class KarrasVeScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class PNDMScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class RePaintScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class SchedulerMixin(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class ScoreSdeVeScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class VQDiffusionScheduler(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - -class EMAModel(metaclass=DummyObject): - _backends = ["torch", "accelerate"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "accelerate"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "accelerate"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 4ea02dcc94da..005cbb6170f4 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -15,11 +15,14 @@ Import utilities: Utilities related to imports and our lazy inits. """ import importlib.util +import operator as op import os import sys from collections import OrderedDict +from typing import Union from packaging import version +from packaging.version import Version, parse from . import logging @@ -40,6 +43,8 @@ USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + _torch_version = "N/A" if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: _torch_available = importlib.util.find_spec("torch") is not None @@ -309,3 +314,36 @@ def __getattr__(cls, key): if key.startswith("_"): return super().__getattr__(cls, key) requires_backends(cls, cls._backends) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Args: + Compares a library version to some requirement using a given operation. + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib_metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 +def is_torch_version(operation: str, version: str): + """ + Args: + Compares the current PyTorch version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(parse(_torch_version), operation, version) diff --git a/tests/repo_utils/test_check_dummies.py b/tests/repo_utils/test_check_dummies.py index 0331b5e8c2cc..d8fa9ce10547 100644 --- a/tests/repo_utils/test_check_dummies.py +++ b/tests/repo_utils/test_check_dummies.py @@ -52,13 +52,13 @@ def test_find_backend(self): def test_read_init(self): objects = read_init() # We don't assert on the exact list of keys to allow for smooth grow of backend-specific objects - self.assertIn("torch_and_accelerate", objects) + self.assertIn("torch", objects) self.assertIn("torch_and_transformers", objects) self.assertIn("flax_and_transformers", objects) self.assertIn("torch_and_transformers_and_onnx", objects) # Likewise, we can't assert on the exact content of a key - self.assertIn("UNet2DModel", objects["torch_and_accelerate"]) + self.assertIn("UNet2DModel", objects["torch"]) self.assertIn("FlaxUNet2DConditionModel", objects["flax"]) self.assertIn("StableDiffusionPipeline", objects["torch_and_transformers"]) self.assertIn("FlaxStableDiffusionPipeline", objects["flax_and_transformers"])