diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 30086654a2f1..237e8236caf4 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import inspect
 from importlib import import_module
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -2195,15 +2195,33 @@ def __call__(
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
                 raise ValueError(
-                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
-                raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
                 )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) else: ip_adapter_masks = [None] * len(self.scale) @@ -2211,26 +2229,44 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) + if mask is not None: + if not isinstance(scale, list): + scale = [scale] + + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - if mask is not None: - mask_downsample = IPAdapterMaskProcessor.downsample( - mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] - ) + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) - current_ip_hidden_states = current_ip_hidden_states * mask_downsample + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -2369,15 +2405,33 @@ def __call__( hidden_states = hidden_states.to(query.dtype) if ip_adapter_masks is not None: - if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: - raise ValueError( - " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
- " Please use `IPAdapterMaskProcessor` to preprocess your mask" - ) - if len(ip_adapter_masks) != len(self.scale): + if not isinstance(ip_adapter_masks, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): raise ValueError( - f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" ) + else: + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) else: ip_adapter_masks = [None] * len(self.scale) @@ -2385,33 +2439,57 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) + if mask is not None: + if not isinstance(scale, list): + scale = [scale] - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + current_num_images = mask.shape[1] + for i in range(current_num_images): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + _current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) - if mask is not None: - mask_downsample = IPAdapterMaskProcessor.downsample( - mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] - ) + _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) - mask_downsample = 
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask[:, i, :, :],
+                        batch_size,
+                        _current_ip_hidden_states.shape[1],
+                        _current_ip_hidden_states.shape[2],
+                    )
 
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+            else:
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
 
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index bed3ca82b8d2..a4217e73a9a2 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -544,3 +544,33 @@ def test_ip_adapter_multiple_masks(self):
 
         max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
         assert max_diff < 5e-4
+
+    def test_ip_adapter_multiple_masks_one_adapter(self):
+        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+        pipeline = StableDiffusionXLPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            image_encoder=image_encoder,
+            torch_dtype=self.dtype,
+        )
+        pipeline.enable_model_cpu_offload()
+        pipeline.load_ip_adapter(
+            "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
+        )
+        pipeline.set_ip_adapter_scale([[0.7, 0.7]])
+
+        inputs = self.get_dummy_inputs(for_masks=True)
+        masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"]
+        processor = IPAdapterMaskProcessor()
+        masks = processor.preprocess(masks)
+        masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])
+        inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks]
+        ip_images = inputs["ip_adapter_image"]
+        inputs["ip_adapter_image"] = [[image[0] for image in ip_images]]
+        images = pipeline(**inputs).images
+        image_slice = images[0, :3, :3, -1].flatten()
+        expected_slice = np.array(
+            [0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424]
+        )
+
+        max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+        assert max_diff < 5e-4
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 73a171797036..acff5f2cdf8f 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -238,6 +238,11 @@ def test_pipeline_signature(self):
 
     def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
         return torch.randn((2, 1, cross_attention_dim), device=torch_device)
 
+    def _get_dummy_masks(self, input_size: int = 64):
+        _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
+        _masks[0, :, :, : int(input_size / 2)] = 1
+        return _masks
+
     def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
         parameters = inspect.signature(self.pipeline_class.__call__).parameters
         if "image" in parameters.keys() and "strength" in parameters.keys():
@@ -365,6 +370,51 @@ def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4):
 
         assert out_cfg.shape == out_no_cfg.shape
+
+    def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+        sample_size = pipe.unet.config.get("sample_size", 32)
+        block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512])
+        input_size = sample_size * (2 ** (len(block_out_channels) - 1))
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        output_without_adapter = pipe(**inputs)[0]
+        output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
+
+        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+        # forward pass with single ip adapter and masks, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
+        pipe.set_ip_adapter_scale(0.0)
+        output_without_adapter_scale = pipe(**inputs)[0]
+        output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        # forward pass with single ip adapter and masks, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
+        pipe.set_ip_adapter_scale(42.0)
+        output_with_adapter_scale = pipe(**inputs)[0]
+        output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+        self.assertLess(
+            max_diff_without_adapter_scale,
+            expected_max_diff,
+            "Output without ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
+        )
 
 
 class PipelineLatentTesterMixin:
     """
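
For reference, a minimal usage sketch of the per-adapter mask list introduced by this patch, mirroring test_ip_adapter_multiple_masks_one_adapter above: a single IP-Adapter receives two reference images, the two binary masks are stacked into one [1, num_images_for_ip_adapter, height, width] tensor, and one scale is supplied per image. The checkpoint IDs and weight name are taken from the test; the prompt, mask, and image files are placeholders, so treat the snippet as illustrative rather than canonical.

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

# One SDXL pipeline with a single IP-Adapter (IDs as in the test above).
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus-face_sdxl_vit-h.safetensors"
)
# Outer list: one entry per adapter; inner list: one scale per reference image.
pipeline.set_ip_adapter_scale([[0.7, 0.7]])

# Placeholder inputs: two binary region masks and two reference face images.
mask1, mask2 = load_image("mask_left.png"), load_image("mask_right.png")
face1, face2 = load_image("face_1.png"), load_image("face_2.png")

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2], height=1024, width=1024)  # [num_images, 1, H, W]
# Stack this adapter's masks along dim 1 -> [1, num_images_for_ip_adapter, H, W].
masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])

image = pipeline(
    prompt="two people standing in a park",
    ip_adapter_image=[[face1, face2]],  # outer list: adapters; inner list: images per adapter
    cross_attention_kwargs={"ip_adapter_masks": [masks]},  # one mask tensor per adapter
    num_inference_steps=30,
).images[0]

Passing a plain float scale together with a [num_ip_adapter, 1, height, width] mask tensor keeps the previous single-image-per-adapter behaviour, which is what the backward-compatibility branch (list(ip_adapter_masks.unsqueeze(1))) in the processors above handles.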