Merged
7 changes: 2 additions & 5 deletions examples/community/mixture.py
@@ -215,11 +215,8 @@ def __call__(
             raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
         if isinstance(seed_tiles_mode, str):
             seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
-        if any(
-            mode not in (modes := [mode.value for mode in self.SeedTilesMode])
-            for row in seed_tiles_mode
-            for mode in row
-        ):
+        modes = [mode.value for mode in self.SeedTilesMode]
+        if any(mode not in modes for row in seed_tiles_mode for mode in row):
             raise ValueError(f"Seed tiles mode must be one of {modes}")
         if seed_reroll_regions is None:
             seed_reroll_regions = []
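For context, this refactor replaces a walrus-operator binding buried inside `any()` with an explicit two-step check, so `modes` is plainly in scope for the error message. A minimal sketch of the resulting pattern (the enum members and tile values here are illustrative stand-ins, not taken from the pipeline):

```python
from enum import Enum

class SeedTilesMode(Enum):
    FULL = "full"          # illustrative members; the real enum lives on the pipeline
    EXCLUSIVE = "exclusive"

seed_tiles_mode = [["full", "exclusive"], ["full"]]

# Two readable steps: build the list of valid values, then validate every tile.
modes = [mode.value for mode in SeedTilesMode]
if any(mode not in modes for row in seed_tiles_mode for mode in row):
    raise ValueError(f"Seed tiles mode must be one of {modes}")
```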
Binary file added frog.png
97 changes: 89 additions & 8 deletions src/diffusers/models/attention_processor.py
@@ -166,22 +166,28 @@ def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
         is_lora = hasattr(self, "processor") and isinstance(
-            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
+            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
         )
+        is_added_kv_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                AttnAddedKVProcessor,
+                AttnAddedKVProcessor2_0,
+                SlicedAttnAddedKVProcessor,
+                XFormersAttnAddedKVProcessor,
+                LoRAAttnAddedKVProcessor,
+            ),
+        )

         if use_memory_efficient_attention_xformers:
-            if self.added_kv_proj_dim is not None:
-                # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
-                # which uses this type of cross attention ONLY because the attention mask of format
-                # [0, ..., -10.000, ..., 0, ...,] is not supported
+            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                 raise NotImplementedError(
-                    "Memory efficient attention with `xformers` is currently not supported when"
-                    " `self.added_kv_proj_dim` is defined."
+                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
                 )
-            elif not is_xformers_available():
+            if not is_xformers_available():
                 raise ModuleNotFoundError(
                     (
                         "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
@@ -233,6 +239,15 @@ def set_use_memory_efficient_attention_xformers(
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_added_kv_processor:
+                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+                # which uses this type of cross attention ONLY because the attention mask of format
+                # [0, ..., -10.000, ..., 0, ...,] is not supported
+                # throw warning
+                logger.info(
+                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+                )
+                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
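Taken together, the two hunks above change what `enable_xformers_memory_efficient_attention()` does for added-KV attention blocks: instead of unconditionally raising, the method now only rejects the LoRA/custom-diffusion combinations and otherwise swaps in the new processor with a logged caveat about attention masks. A hedged usage sketch (the checkpoint id is a placeholder, not a real model):

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder checkpoint id; any pipeline whose UNet uses UnCLIP-style
# added-KV cross attention exercises the new code path.
pipe = DiffusionPipeline.from_pretrained("some-org/some-unclip-checkpoint", torch_dtype=torch.float16)
pipe.to("cuda")

# Previously this raised NotImplementedError when added_kv_proj_dim was set;
# with this PR it installs XFormersAttnAddedKVProcessor instead.
pipe.enable_xformers_memory_efficient_attention()
```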
@@ -889,6 +904,71 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         return hidden_states


+class XFormersAttnAddedKVProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
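The core of the new processor is how it assembles keys and values before the xformers call: the added-KV projections of the encoder states are prepended to the self-attention keys and values (unless `attn.only_cross_attention`), so each query attends over encoder tokens and spatial tokens in one fused operation. A self-contained sketch of just that step, with toy shapes and no diffusers dependency:

```python
import torch

# Toy shapes: batch * heads = 2, spatial tokens = 16, encoder tokens = 77, head dim = 8.
key_self = torch.randn(2, 16, 8)     # stands in for attn.to_k(hidden_states)
value_self = torch.randn(2, 16, 8)   # stands in for attn.to_v(hidden_states)
key_added = torch.randn(2, 77, 8)    # stands in for attn.add_k_proj(encoder_hidden_states)
value_added = torch.randn(2, 77, 8)  # stands in for attn.add_v_proj(encoder_hidden_states)

# Encoder projections come first, matching the torch.cat order in the processor above.
key = torch.cat([key_added, key_self], dim=1)        # shape (2, 93, 8)
value = torch.cat([value_added, value_self], dim=1)  # shape (2, 93, 8)
```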
@@ -1428,6 +1508,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None,
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
+    XFormersAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
1 change: 1 addition & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -261,6 +261,7 @@ def __init__(

         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
             logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
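The one-line `register_to_config` addition matters for serialization: `encoder_hid_dim_type` is defaulted inside `__init__`, and without writing the default back to the config it would be lost on a save/load round trip. A hedged sketch of the round trip this protects (constructor arguments are illustrative; the default constructor builds a full-size UNet):

```python
from diffusers import UNet2DConditionModel

# encoder_hid_dim is given without encoder_hid_dim_type, so __init__
# defaults the type to "text_proj" (dims here are illustrative).
unet = UNet2DConditionModel(encoder_hid_dim=32)
unet.save_pretrained("/tmp/unet")

# With the register_to_config call the defaulted value survives reload.
reloaded = UNet2DConditionModel.from_pretrained("/tmp/unet")
assert reloaded.config.encoder_hid_dim_type == "text_proj"
```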
(file header not captured; the hunk below applies the same `encoder_hid_dim_type` fix in another file)
@@ -364,6 +364,7 @@ def __init__(

         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
             logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if.py
@@ -28,6 +28,7 @@
     IFSuperResolutionPipeline,
 )
 from diffusers.models.attention_processor import AttnAddedKVProcessor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -42,8 +43,6 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()

@@ -81,6 +80,13 @@ def test_inference_batch_single_identical(self):
             expected_max_diff=1e-2,
         )

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+

 @slow
 @require_torch_gpu
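The test-side change repeats across all five IF test files below: the blanket `test_xformers_attention = False` opt-out is deleted, and an explicit override runs the shared xformers check only where it can execute, with a looser `expected_max_diff=1e-3`, presumably to absorb small numerical differences from the xformers kernels. A minimal sketch of the gating pattern (the stand-in definitions mimic what `diffusers.utils` provides):

```python
import unittest

torch_device = "cuda"  # stand-in for diffusers.utils.testing_utils.torch_device

def is_xformers_available() -> bool:
    # stand-in for diffusers.utils.import_utils.is_xformers_available
    try:
        import xformers  # noqa: F401
        return True
    except ImportError:
        return False

class ExampleFastTests(unittest.TestCase):
    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_attention_forwardGenerator_pass(self):
        pass  # the real tests call the mixin helper with expected_max_diff=1e-3
```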
10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if_img2img.py
@@ -20,6 +20,7 @@

 from diffusers import IFImg2ImgPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()

@@ -63,6 +62,13 @@ def get_dummy_inputs(self, device, seed=0):
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self):
         # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
@@ -20,6 +20,7 @@

 from diffusers import IFImg2ImgSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"})
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()

@@ -59,6 +58,13 @@ def get_dummy_inputs(self, device, seed=0):

         return inputs

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()

10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if_inpainting.py
@@ -20,6 +20,7 @@

 from diffusers import IFInpaintingPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()

@@ -62,6 +61,13 @@ def get_dummy_inputs(self, device, seed=0):

         return inputs

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()

(file header not captured; the hunks below cover the IF inpainting super-resolution tests)
@@ -20,6 +20,7 @@

 from diffusers import IFInpaintingSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"})
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()

@@ -64,6 +63,13 @@ def get_dummy_inputs(self, device, seed=0):

         return inputs

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()

10 changes: 8 additions & 2 deletions tests/pipelines/deepfloyd_if/test_if_superresolution.py
@@ -20,6 +20,7 @@

 from diffusers import IFSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()

@@ -57,6 +56,13 @@ def get_dummy_inputs(self, device, seed=0):

         return inputs

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
