diff --git a/examples/community/mixture.py b/examples/community/mixture.py
index 60d0ee2d09d3..845ad76b6a2e 100644
--- a/examples/community/mixture.py
+++ b/examples/community/mixture.py
@@ -215,11 +215,8 @@ def __call__(
             raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
         if isinstance(seed_tiles_mode, str):
             seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
-        if any(
-            mode not in (modes := [mode.value for mode in self.SeedTilesMode])
-            for row in seed_tiles_mode
-            for mode in row
-        ):
+        modes = [mode.value for mode in self.SeedTilesMode]
+        if any(mode not in modes for row in seed_tiles_mode for mode in row):
             raise ValueError(f"Seed tiles mode must be one of {modes}")
         if seed_reroll_regions is None:
             seed_reroll_regions = []
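The hunk above hoists the valid-modes list out of the generator expression: the walrus form rebuilt `[mode.value for mode in self.SeedTilesMode]` on every iteration and bound `modes` only as a side effect of evaluating the condition. A minimal sketch of the refactored pattern, using a hypothetical stand-in for the pipeline's `SeedTilesMode` enum:

```python
from enum import Enum


class SeedTilesMode(Enum):
    # Hypothetical values for illustration; the real enum lives in mixture.py.
    FULL = "full"
    EXCLUSIVE = "exclusive"


seed_tiles_mode = [["full", "exclusive"], ["full"]]

# Build the list of valid modes once, then validate every entry against it.
modes = [mode.value for mode in SeedTilesMode]
if any(mode not in modes for row in seed_tiles_mode for mode in row):
    raise ValueError(f"Seed tiles mode must be one of {modes}")
```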
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 1bfaa0258155..e39bdc0429c1 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -166,22 +166,28 @@ def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ):
         is_lora = hasattr(self, "processor") and isinstance(
-            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
+            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor)
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
         )
+        is_added_kv_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                AttnAddedKVProcessor,
+                AttnAddedKVProcessor2_0,
+                SlicedAttnAddedKVProcessor,
+                XFormersAttnAddedKVProcessor,
+                LoRAAttnAddedKVProcessor,
+            ),
+        )

         if use_memory_efficient_attention_xformers:
-            if self.added_kv_proj_dim is not None:
-                # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
-                # which uses this type of cross attention ONLY because the attention mask of format
-                # [0, ..., -10.000, ..., 0, ...,] is not supported
+            if is_added_kv_processor and (is_lora or is_custom_diffusion):
                 raise NotImplementedError(
-                    "Memory efficient attention with `xformers` is currently not supported when"
-                    " `self.added_kv_proj_dim` is defined."
+                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
                 )
-            elif not is_xformers_available():
+            if not is_xformers_available():
                 raise ModuleNotFoundError(
                     (
                         "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
@@ -233,6 +239,15 @@ def set_use_memory_efficient_attention_xformers(
                 processor.load_state_dict(self.processor.state_dict())
                 if hasattr(self.processor, "to_k_custom_diffusion"):
                     processor.to(self.processor.to_k_custom_diffusion.weight.device)
+            elif is_added_kv_processor:
+                # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
+                # which uses this type of cross attention ONLY because the attention mask of format
+                # [0, ..., -10.000, ..., 0, ...,] is not supported
+                # throw warning
+                logger.info(
+                    "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
+                )
+                processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -889,6 +904,71 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         return hidden_states


+class XFormersAttnAddedKVProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        residual = hidden_states
+        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
+        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
+
+        if not attn.only_cross_attention:
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+            key = attn.head_to_batch_dim(key)
+            value = attn.head_to_batch_dim(value)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+        else:
+            key = encoder_hidden_states_key_proj
+            value = encoder_hidden_states_value_proj
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
+        hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class XFormersAttnProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
@@ -1428,6 +1508,7 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
+    XFormersAttnAddedKVProcessor,
     LoRAAttnProcessor,
     LoRAXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
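With `XFormersAttnAddedKVProcessor` in place, pipelines built on added-KV cross attention (UnCLIP, DeepFloyd IF) can opt into xFormers instead of hitting the old `NotImplementedError`; the guard above still rejects the LoRA and Custom Diffusion added-KV variants, for which no xFormers processors exist yet. A rough usage sketch, assuming a CUDA machine with `xformers` installed and access to the gated DeepFloyd IF weights:

```python
import torch
from diffusers import DiffusionPipeline

# IF's UNet uses added-KV cross attention; enabling xformers now swaps in
# XFormersAttnAddedKVProcessor instead of raising NotImplementedError.
pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16
)
pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a frog wearing a crown", num_inference_steps=50).images[0]
```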
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 484f9323c69f..106346070d94 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -261,6 +261,7 @@ def __init__(

         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
             logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index af647fe810aa..a0dbdaa75230 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -364,6 +364,7 @@ def __init__(

         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
+            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
             logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
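The `register_to_config` calls added to both UNet `__init__` methods matter for serialization: previously the defaulted `encoder_hid_dim_type` existed only as a local variable, so it never reached the saved config and was lost on reload. A small round-trip sketch (assuming a recent diffusers version; the tiny dims and block types are arbitrary):

```python
import tempfile

from diffusers import UNet2DConditionModel

# Pass encoder_hid_dim without an explicit encoder_hid_dim_type so the
# constructor falls back to "text_proj".
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    encoder_hid_dim=32,
)
assert unet.config.encoder_hid_dim_type == "text_proj"

# With the fix, the defaulted value survives a save/load round trip, so the
# reloaded model rebuilds its encoder_hid_proj layer.
with tempfile.TemporaryDirectory() as tmp:
    unet.save_pretrained(tmp)
    reloaded = UNet2DConditionModel.from_pretrained(tmp)
assert reloaded.config.encoder_hid_dim_type == "text_proj"
```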
diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py
index f4cb52d25a8d..2e7383067eec 100644
--- a/tests/pipelines/deepfloyd_if/test_if.py
+++ b/tests/pipelines/deepfloyd_if/test_if.py
@@ -28,6 +28,7 @@
     IFSuperResolutionPipeline,
 )
 from diffusers.models.attention_processor import AttnAddedKVProcessor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -42,8 +43,6 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()

@@ -81,6 +80,13 @@ def test_inference_batch_single_identical(self):
             expected_max_diff=1e-2,
         )

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+

 @slow
 @require_torch_gpu
diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py
index c85063af9e30..ec4598906a6f 100644
--- a/tests/pipelines/deepfloyd_if/test_if_img2img.py
+++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py
@@ -20,6 +20,7 @@

 from diffusers import IFImg2ImgPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()

@@ -63,6 +62,13 @@ def get_dummy_inputs(self, device, seed=0):
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
     def test_save_load_float16(self):
         # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder
diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
index e7c8d58a3e0c..500557108aed 100644
--- a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py
@@ -20,6 +20,7 @@

 from diffusers import IFImg2ImgSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"})
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()

@@ -59,6 +58,13 @@ def get_dummy_inputs(self, device, seed=0):

         return inputs

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()

diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting.py b/tests/pipelines/deepfloyd_if/test_if_inpainting.py
index 6837ad36baf5..1317fcb64e81 100644
--- a/tests/pipelines/deepfloyd_if/test_if_inpainting.py
+++ b/tests/pipelines/deepfloyd_if/test_if_inpainting.py
@@ -20,6 +20,7 @@

 from diffusers import IFInpaintingPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_dummy_components()

@@ -62,6 +61,13 @@ def get_dummy_inputs(self, device, seed=0):

         return inputs

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
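For context on what these new tests exercise: the shared `_test_xformers_attention_forwardGenerator_pass` helper comes from the pipeline tester mixin and is not part of this diff. A simplified restatement of its core check, with a hypothetical `get_inputs` factory so the generator seed is identical across both runs:

```python
import numpy as np


def check_xformers_parity(pipe, get_inputs, expected_max_diff=1e-3):
    # get_inputs() must return fresh inputs (including a freshly seeded
    # torch.Generator) each time, so both runs start from the same noise.
    out_plain = pipe(**get_inputs()).images
    pipe.enable_xformers_memory_efficient_attention()
    out_xformers = pipe(**get_inputs()).images
    # Outputs with and without memory-efficient attention should agree to
    # within a small numerical tolerance.
    max_diff = np.abs(np.asarray(out_plain) - np.asarray(out_xformers)).max()
    assert max_diff < expected_max_diff
```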
diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
index fc130091b5e5..961a22675f33 100644
--- a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py
@@ -20,6 +20,7 @@

 from diffusers import IFInpaintingSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import (
@@ -37,8 +38,6 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"})
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()

@@ -64,6 +63,13 @@ def get_dummy_inputs(self, device, seed=0):

         return inputs

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()

diff --git a/tests/pipelines/deepfloyd_if/test_if_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_superresolution.py
index 9e418ca6aff5..52fb38308892 100644
--- a/tests/pipelines/deepfloyd_if/test_if_superresolution.py
+++ b/tests/pipelines/deepfloyd_if/test_if_superresolution.py
@@ -20,6 +20,7 @@

 from diffusers import IFSuperResolutionPipeline
 from diffusers.utils import floats_tensor
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import skip_mps, torch_device

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -34,8 +35,6 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}

-    test_xformers_attention = False
-
     def get_dummy_components(self):
         return self._get_superresolution_dummy_components()

@@ -57,6 +56,13 @@ def get_dummy_inputs(self, device, seed=0):

         return inputs

+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
+
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
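The same CUDA-plus-xformers `skipIf` guard is now duplicated across all six IF test classes. If the pattern keeps spreading, it could be factored into a shared decorator; a sketch under the hypothetical name `require_xformers_gpu`, not part of this diff:

```python
import unittest

from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import torch_device


def require_xformers_gpu(test_case):
    # Hypothetical helper mirroring the guard repeated in the IF tests above.
    return unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )(test_case)
```

Each `test_xformers_attention_forwardGenerator_pass` would then need only a single `@require_xformers_gpu` line.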