From 1141aaff28826056a4d54bdb0af75ca8a41727c2 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 25 Mar 2024 14:26:39 +0100 Subject: [PATCH 01/10] Support multiimage masking --- src/diffusers/image_processor.py | 8 ++- src/diffusers/models/attention_processor.py | 69 +++++++++++++-------- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index daeb8fd6fa6d..7a058075e46f 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -948,14 +948,16 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value The downsampled mask tensor. """ - o_h = mask.shape[1] - o_w = mask.shape[2] + if mask.ndim == 3: + mask = mask.unsqueeze(0) + o_h = mask.shape[2] + o_w = mask.shape[3] ratio = o_w / o_h mask_h = int(math.sqrt(num_queries / ratio)) mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0) mask_w = num_queries // mask_h - mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) + mask_downsample = F.interpolate(mask, size=(mask_h, mask_w), mode="bicubic").squeeze(0) # Repeat batch_size times if mask_downsample.shape[0] < batch_size: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 0c6dfe068d5c..cc14975e0c98 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2372,11 +2372,11 @@ def __call__( hidden_states = hidden_states.to(query.dtype) if ip_adapter_masks is not None: - if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: - raise ValueError( - " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." - " Please use `IPAdapterMaskProcessor` to preprocess your mask" - ) + # if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: + # raise ValueError( + # " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
+ # " Please use `IPAdapterMaskProcessor` to preprocess your mask" + # ) if len(ip_adapter_masks) != len(self.scale): raise ValueError( f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" @@ -2388,33 +2388,52 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) + if mask is not None: + if not isinstance(scale, list): + scale = [scale] + for i in range(mask.shape[1]): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + _current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], batch_size, _current_ip_hidden_states.shape[1], _current_ip_hidden_states.shape[2] + ) - current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) - current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) - if mask is not None: - mask_downsample = IPAdapterMaskProcessor.downsample( - mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] - ) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + current_ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) - current_ip_hidden_states = current_ip_hidden_states * mask_downsample + current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) - hidden_states = hidden_states + scale * current_ip_hidden_states + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) From 
854b843d7b5be2e2c7b475884540c1fe56d1b821 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 27 Mar 2024 21:43:49 +0100 Subject: [PATCH 02/10] Restore ipadaptermaskprocessor --- src/diffusers/image_processor.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 7a058075e46f..daeb8fd6fa6d 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -948,16 +948,14 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value The downsampled mask tensor. """ - if mask.ndim == 3: - mask = mask.unsqueeze(0) - o_h = mask.shape[2] - o_w = mask.shape[3] + o_h = mask.shape[1] + o_w = mask.shape[2] ratio = o_w / o_h mask_h = int(math.sqrt(num_queries / ratio)) mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0) mask_w = num_queries // mask_h - mask_downsample = F.interpolate(mask, size=(mask_h, mask_w), mode="bicubic").squeeze(0) + mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0) # Repeat batch_size times if mask_downsample.shape[0] < batch_size: From d42d358ae89bd70ce190fbbe79a1e72c678be569 Mon Sep 17 00:00:00 2001 From: fabiorigano Date: Sun, 31 Mar 2024 16:16:48 +0200 Subject: [PATCH 03/10] Update conditions --- src/diffusers/models/attention_processor.py | 103 +++++++++++++++----- 1 file changed, 76 insertions(+), 27 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index cc14975e0c98..d843a2799d12 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2198,14 +2198,10 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) if ip_adapter_masks is not None: - if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: - raise ValueError( - " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." - " Please use `IPAdapterMaskProcessor` to preprocess your mask" - ) if len(ip_adapter_masks) != len(self.scale): raise ValueError( - f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" + f"Lenght of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"number of IP-Adapters ({len(self.scale)})" ) else: ip_adapter_masks = [None] * len(self.scale) @@ -2214,26 +2210,61 @@ def __call__( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): - ip_key = to_k_ip(current_ip_hidden_states) - ip_value = to_v_ip(current_ip_hidden_states) + if mask is not None: + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images, height, width]." 
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) + if mask.shape[1] != current_ip_hidden_states.shape[1]: + raise ValueError( + f"Number of masks in ip_adapter_masks ({mask.shape[1]}) must match " + f"number of input images ({current_ip_hidden_states.shape[1]})" + ) - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) - current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + if not isinstance(scale, list): + scale = [scale] - if mask is not None: - mask_downsample = IPAdapterMaskProcessor.downsample( - mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] - ) + if mask.shape[1] != len(scale): + raise ValueError( + f"Number of masks in ip_adapter_masks ({mask.shape[1]}) must match number of scales ({len(scale)})" + ) + + for i in range(mask.shape[1]): + ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) + ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) + + mask_downsample = IPAdapterMaskProcessor.downsample( + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], + ) - mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) + mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) - current_ip_hidden_states = current_ip_hidden_states * mask_downsample + hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample) + else: + ip_key = to_k_ip(current_ip_hidden_states) + ip_value = to_v_ip(current_ip_hidden_states) - hidden_states = hidden_states + scale * current_ip_hidden_states + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states) + + hidden_states = hidden_states + scale * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) @@ -2372,14 +2403,10 @@ def __call__( hidden_states = hidden_states.to(query.dtype) if ip_adapter_masks is not None: - # if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: - # raise ValueError( - # " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." 
- # " Please use `IPAdapterMaskProcessor` to preprocess your mask" - # ) if len(ip_adapter_masks) != len(self.scale): raise ValueError( - f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" + f"Lenght of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"number of IP-Adapters ({len(self.scale)})" ) else: ip_adapter_masks = [None] * len(self.scale) @@ -2389,8 +2416,27 @@ def __call__( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): if mask is not None: + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + + if mask.shape[1] != current_ip_hidden_states.shape[1]: + raise ValueError( + f"Number of masks in ip_adapter_masks ({mask.shape[1]}) must match " + f"number of input images ({current_ip_hidden_states.shape[1]})" + ) + if not isinstance(scale, list): scale = [scale] + + if mask.shape[1] != len(scale): + raise ValueError( + f"Number of masks in ip_adapter_masks ({mask.shape[1]}) must match number of scales ({len(scale)})" + ) + for i in range(mask.shape[1]): ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) @@ -2410,7 +2456,10 @@ def __call__( _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) mask_downsample = IPAdapterMaskProcessor.downsample( - mask[:, i, :, :], batch_size, _current_ip_hidden_states.shape[1], _current_ip_hidden_states.shape[2] + mask[:, i, :, :], + batch_size, + _current_ip_hidden_states.shape[1], + _current_ip_hidden_states.shape[2], ) mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) From c63e6affd989ed9734f61c9e4ba64e2d2763a746 Mon Sep 17 00:00:00 2001 From: fabiorigano Date: Sun, 31 Mar 2024 18:38:23 +0200 Subject: [PATCH 04/10] Add test --- .../test_ip_adapter_stable_diffusion.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index bed3ca82b8d2..74642e0ee119 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -544,3 +544,33 @@ def test_ip_adapter_multiple_masks(self): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 + + def test_ip_adapter_multiple_masks_one_adapter(self): + image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") + pipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + image_encoder=image_encoder, + torch_dtype=self.dtype, + variant="fp16", + ) + pipeline.load_ip_adapter( + "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] + ) + pipeline.set_ip_adapter_scale([[0.7, 0.7]]) + + inputs = self.get_dummy_inputs(for_masks=True) + masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"] + processor = IPAdapterMaskProcessor() + masks = processor.preprocess(masks) + masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3]) + inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks] + ip_images = inputs["ip_adapter_image"] + inputs["ip_adapter_image"] = [[image[0] for image in ip_images]] 
+ images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + expected_slice = np.array( + [0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424] + ) + + max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) + assert max_diff < 5e-4 From b6fdc96e247219882c3d32746b85f433a98bf182 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 1 Apr 2024 19:40:44 +0200 Subject: [PATCH 05/10] Update src/diffusers/models/attention_processor.py Co-authored-by: Sayak Paul --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d843a2799d12..543c658b9f7b 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2200,7 +2200,7 @@ def __call__( if ip_adapter_masks is not None: if len(ip_adapter_masks) != len(self.scale): raise ValueError( - f"Lenght of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " f"number of IP-Adapters ({len(self.scale)})" ) else: From df31fe6ae5792b8278ecff77992dcafc82005a4b Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 5 Apr 2024 08:55:06 +0200 Subject: [PATCH 06/10] Apply suggestions from code review Co-authored-by: YiYi Xu --- src/diffusers/models/attention_processor.py | 32 ++++++++++----------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 543c658b9f7b..a6fd83825b6c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2198,6 +2198,19 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) if ip_adapter_masks is not None: + if not isinstance(ip_adapter_mask, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_mask = list(ip_adapter_mask.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states) == len(self.to_k_ip) == len(self.to_v_ip)): + raise ValueError(...) + else: + for mask, scale, ip_state in zip(ip_adapter_masks, self.scale, ip_hidden_states): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError("each ip_adapter_mask should be a tensor with shape [1, num_images, height, width] ..." ) + if mask.shape[1] != current_ip_hidden_states.shape[1]: + raise ValueError(...) + if isinstance(scale, list) and not len(scale) == len(mask.shape[1]: + raise ValueError(...) if len(ip_adapter_masks) != len(self.scale): raise ValueError( f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " @@ -2211,28 +2224,13 @@ def __call__( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): if mask is not None: - if not isinstance(mask, torch.Tensor) or mask.ndim != 4: - raise ValueError( - "Each element of the ip_adapter_masks array should be a tensor with shape " - "[1, num_images, height, width]." 
- " Please use `IPAdapterMaskProcessor` to preprocess your mask" - ) - - if mask.shape[1] != current_ip_hidden_states.shape[1]: - raise ValueError( - f"Number of masks in ip_adapter_masks ({mask.shape[1]}) must match " - f"number of input images ({current_ip_hidden_states.shape[1]})" - ) if not isinstance(scale, list): scale = [scale] - if mask.shape[1] != len(scale): - raise ValueError( - f"Number of masks in ip_adapter_masks ({mask.shape[1]}) must match number of scales ({len(scale)})" - ) - for i in range(mask.shape[1]): + current_num_images = mask.shape[1] + for i in range(current_num_images): ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) From d0d930836992e5b9c4c0d698534ce3d5a8de9fdd Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 5 Apr 2024 09:36:20 +0200 Subject: [PATCH 07/10] Fix conditions --- src/diffusers/models/attention_processor.py | 92 ++++++++++++--------- 1 file changed, 53 insertions(+), 39 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index a6fd83825b6c..ca783511c3a7 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect from importlib import import_module -from typing import Callable, Optional, Union +from typing import Callable, List, Optional, Union import torch import torch.nn.functional as F @@ -2198,24 +2198,33 @@ def __call__( hidden_states = attn.batch_to_head_dim(hidden_states) if ip_adapter_masks is not None: - if not isinstance(ip_adapter_mask, List): + if not isinstance(ip_adapter_masks, List): # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] - ip_adapter_mask = list(ip_adapter_mask.unsqueeze(1)) - if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states) == len(self.to_k_ip) == len(self.to_v_ip)): - raise ValueError(...) - else: - for mask, scale, ip_state in zip(ip_adapter_masks, self.scale, ip_hidden_states): - if not isinstance(mask, torch.Tensor) or mask.ndim != 4: - raise ValueError("each ip_adapter_mask should be a tensor with shape [1, num_images, height, width] ..." ) - if mask.shape[1] != current_ip_hidden_states.shape[1]: - raise ValueError(...) - if isinstance(scale, list) and not len(scale) == len(mask.shape[1]: - raise ValueError(...) - if len(ip_adapter_masks) != len(self.scale): + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): raise ValueError( f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " - f"number of IP-Adapters ({len(self.scale)})" - ) + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." 
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) else: ip_adapter_masks = [None] * len(self.scale) @@ -2228,7 +2237,6 @@ def __call__( if not isinstance(scale, list): scale = [scale] - current_num_images = mask.shape[1] for i in range(current_num_images): ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) @@ -2401,11 +2409,33 @@ def __call__( hidden_states = hidden_states.to(query.dtype) if ip_adapter_masks is not None: - if len(ip_adapter_masks) != len(self.scale): + if not isinstance(ip_adapter_masks, List): + # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width] + ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1)) + if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)): raise ValueError( - f"Lenght of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " - f"number of IP-Adapters ({len(self.scale)})" - ) + f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " + f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " + f"({len(ip_hidden_states)})" + ) + else: + for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): + if not isinstance(mask, torch.Tensor) or mask.ndim != 4: + raise ValueError( + "Each element of the ip_adapter_masks array should be a tensor with shape " + "[1, num_images_for_ip_adapter, height, width]." + " Please use `IPAdapterMaskProcessor` to preprocess your mask" + ) + if mask.shape[1] != ip_state.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of ip images ({ip_state.shape[1]}) at index {index}" + ) + if isinstance(scale, list) and not len(scale) == mask.shape[1]: + raise ValueError( + f"Number of masks ({mask.shape[1]}) does not match " + f"number of scales ({len(scale)}) at index {index}" + ) else: ip_adapter_masks = [None] * len(self.scale) @@ -2414,28 +2444,12 @@ def __call__( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): if mask is not None: - if not isinstance(mask, torch.Tensor) or mask.ndim != 4: - raise ValueError( - "Each element of the ip_adapter_masks array should be a tensor with shape " - "[1, num_images, height, width]." 
- " Please use `IPAdapterMaskProcessor` to preprocess your mask" - ) - - if mask.shape[1] != current_ip_hidden_states.shape[1]: - raise ValueError( - f"Number of masks in ip_adapter_masks ({mask.shape[1]}) must match " - f"number of input images ({current_ip_hidden_states.shape[1]})" - ) if not isinstance(scale, list): scale = [scale] - if mask.shape[1] != len(scale): - raise ValueError( - f"Number of masks in ip_adapter_masks ({mask.shape[1]}) must match number of scales ({len(scale)})" - ) - - for i in range(mask.shape[1]): + current_num_images = mask.shape[1] + for i in range(current_num_images): ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) From 7d3a79448859ddb2bd4a414e4a790e07eaf22946 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 6 Apr 2024 10:56:22 +0200 Subject: [PATCH 08/10] Add fast test --- tests/pipelines/test_pipelines_common.py | 54 ++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index d7f0c6baa339..bf876243717a 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -226,6 +226,11 @@ def test_pipeline_signature(self): def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): return torch.randn((2, 1, cross_attention_dim), device=torch_device) + def _get_dummy_masks(self, input_size: int = 64): + _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device) + _masks[0, :, :, :int(input_size / 2)] = 1 + return _masks + def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): parameters = inspect.signature(self.pipeline_class.__call__).parameters if "image" in parameters.keys() and "strength" in parameters.keys(): @@ -353,6 +358,55 @@ def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4): assert out_cfg.shape == out_no_cfg.shape + def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components).to(torch_device) + pipe.set_progress_bar_config(disable=None) + cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) + sample_size = pipe.unet.config.get("sample_size", 32) + block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512]) + input_size = sample_size * (2 ** (len(block_out_channels) - 1)) + + # forward pass without ip adapter + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + output_without_adapter = pipe(**inputs)[0] + output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten() + + adapter_state_dict = create_ip_adapter_state_dict(pipe.unet) + pipe.unet._load_ip_adapter_weights(adapter_state_dict) + + # forward pass with single ip adapter and masks, but scale=0 which should have no effect + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["cross_attention_kwargs"] = { + "ip_adapter_masks": [self._get_dummy_masks(input_size)] + } + pipe.set_ip_adapter_scale(0.0) + output_without_adapter_scale = pipe(**inputs)[0] + output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() + + # forward pass with single ip adapter and masks, but with scale of adapter weights + inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) + 
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] + inputs["cross_attention_kwargs"] = { + "ip_adapter_masks": [self._get_dummy_masks(input_size)] + } + pipe.set_ip_adapter_scale(42.0) + output_with_adapter_scale = pipe(**inputs)[0] + output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() + + max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max() + max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max() + + self.assertLess( + max_diff_without_adapter_scale, + expected_max_diff, + "Output without ip-adapter must be same as normal inference", + ) + self.assertGreater( + max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" + ) class PipelineLatentTesterMixin: """ From 654c15e38dc06baba49b2e6c594611cd4b099b4c Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 6 Apr 2024 11:10:16 +0200 Subject: [PATCH 09/10] Fix style --- src/diffusers/models/attention_processor.py | 6 ++---- tests/pipelines/test_pipelines_common.py | 14 +++++--------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ca783511c3a7..a5af4b7ce736 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -2206,7 +2206,7 @@ def __call__( f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " f"({len(ip_hidden_states)})" - ) + ) else: for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): if not isinstance(mask, torch.Tensor) or mask.ndim != 4: @@ -2233,7 +2233,6 @@ def __call__( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): if mask is not None: - if not isinstance(scale, list): scale = [scale] @@ -2417,7 +2416,7 @@ def __call__( f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match " f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states " f"({len(ip_hidden_states)})" - ) + ) else: for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)): if not isinstance(mask, torch.Tensor) or mask.ndim != 4: @@ -2444,7 +2443,6 @@ def __call__( ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ): if mask is not None: - if not isinstance(scale, list): scale = [scale] diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index bf876243717a..a2a9b6e898fe 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -228,7 +228,7 @@ def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): def _get_dummy_masks(self, input_size: int = 64): _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device) - _masks[0, :, :, :int(input_size / 2)] = 1 + _masks[0, :, :, : int(input_size / 2)] = 1 return _masks def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): @@ -359,14 +359,13 @@ def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4): assert out_cfg.shape == out_no_cfg.shape def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): - components = self.get_dummy_components() pipe = self.pipeline_class(**components).to(torch_device) pipe.set_progress_bar_config(disable=None) 
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32) sample_size = pipe.unet.config.get("sample_size", 32) block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512]) - input_size = sample_size * (2 ** (len(block_out_channels) - 1)) + input_size = sample_size * (2 ** (len(block_out_channels) - 1)) # forward pass without ip adapter inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) @@ -379,9 +378,7 @@ def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): # forward pass with single ip adapter and masks, but scale=0 which should have no effect inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - inputs["cross_attention_kwargs"] = { - "ip_adapter_masks": [self._get_dummy_masks(input_size)] - } + inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]} pipe.set_ip_adapter_scale(0.0) output_without_adapter_scale = pipe(**inputs)[0] output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten() @@ -389,9 +386,7 @@ def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): # forward pass with single ip adapter and masks, but with scale of adapter weights inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device)) inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] - inputs["cross_attention_kwargs"] = { - "ip_adapter_masks": [self._get_dummy_masks(input_size)] - } + inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]} pipe.set_ip_adapter_scale(42.0) output_with_adapter_scale = pipe(**inputs)[0] output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten() @@ -408,6 +403,7 @@ def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" ) + class PipelineLatentTesterMixin: """ This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. 
From cf1ddce9358f3a4cafd038abbd0eb65ac7e37694 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 8 Apr 2024 22:14:19 +0200 Subject: [PATCH 10/10] Fix tests --- tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py | 2 +- tests/pipelines/test_pipelines_common.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 74642e0ee119..a4217e73a9a2 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -551,8 +551,8 @@ def test_ip_adapter_multiple_masks_one_adapter(self): "stabilityai/stable-diffusion-xl-base-1.0", image_encoder=image_encoder, torch_dtype=self.dtype, - variant="fp16", ) + pipeline.enable_model_cpu_offload() pipeline.load_ip_adapter( "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"] ) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index a2a9b6e898fe..90b923bf150c 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -400,7 +400,7 @@ def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4): "Output without ip-adapter must be same as normal inference", ) self.assertGreater( - max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference" + max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference" )
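---

Notes on the reworked ip_adapter_masks API (usage sketches, not part of the patch series).

Taken together, the patches change the mask format from a single tensor of shape [num_ip_adapter, 1, height, width] to a list with one entry per IP-Adapter, each entry shaped [1, num_images, height, width], with an optional per-image scale list. Below is a minimal end-to-end sketch assembled from test_ip_adapter_multiple_masks_one_adapter; the prompt and the inputs face_image1/face_image2 and mask1/mask2 are placeholder PIL images, not taken from the patches:

import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.image_processor import IPAdapterMaskProcessor
from transformers import CLIPVisionModelWithProjection

# Image encoder required by the "plus" face adapter, as in the integration test.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipeline.enable_model_cpu_offload()
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"],
)
# One adapter, two reference images -> a nested list with one scale per image.
pipeline.set_ip_adapter_scale([[0.7, 0.7]])

# mask1/mask2 are placeholder PIL masks for the two image regions.
masks = IPAdapterMaskProcessor().preprocess([mask1, mask2])
# One list entry per adapter, each of shape [1, num_images, height, width].
masks = [masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])]

images = pipeline(
    prompt="two people",
    ip_adapter_image=[[face_image1, face_image2]],  # images grouped per adapter
    cross_attention_kwargs={"ip_adapter_masks": masks},
    num_inference_steps=20,
).images

The nested lists mirror each other: the outer index selects the adapter and the inner index the reference image, consistently across ip_adapter_image, set_ip_adapter_scale, and ip_adapter_masks.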
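Inside both attention processors, PATCH 01/03 replace the single key/value projection per adapter with a loop over that adapter's images. A condensed sketch of the SDPA branch follows; the helper signature is invented for illustration, the body mirrors the diff, and a scalar scale is broadcast across images here for brevity (the processors instead wrap it as a one-element list):

import torch
import torch.nn.functional as F
from diffusers.image_processor import IPAdapterMaskProcessor

def masked_ip_attention(query, ip_states, mask, scale, to_k_ip, to_v_ip,
                        hidden_states, batch_size, heads, head_dim):
    # query: [batch, heads, num_queries, head_dim]
    # ip_states: [batch, num_images, seq_len, dim], mask: [1, num_images, H, W]
    scales = scale if isinstance(scale, list) else [scale] * mask.shape[1]
    for i in range(mask.shape[1]):
        # Per-image key/value projections through the adapter's shared layers.
        ip_key = to_k_ip(ip_states[:, i, :, :]).view(batch_size, -1, heads, head_dim).transpose(1, 2)
        ip_value = to_v_ip(ip_states[:, i, :, :]).view(batch_size, -1, heads, head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
        out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim).to(query.dtype)
        # Resize the i-th spatial mask to the query sequence and gate the result.
        mask_down = IPAdapterMaskProcessor.downsample(mask[:, i, :, :], batch_size, out.shape[1], out.shape[2])
        hidden_states = hidden_states + scales[i] * (out * mask_down.to(dtype=query.dtype, device=query.device))
    return hidden_states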
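The loop relies on IPAdapterMaskProcessor.downsample (restored to its original 3D-input signature in PATCH 02) to map each spatial mask onto the flattened latent query sequence: it picks a (mask_h, mask_w) grid that roughly preserves the mask's aspect ratio while covering num_queries tokens, resizes bicubically, and flattens. A standalone sketch of that arithmetic, with zero padding standing in for the helper's pad/truncate-and-repeat step:

import math
import torch
import torch.nn.functional as F

def downsample_sketch(mask: torch.Tensor, num_queries: int) -> torch.Tensor:
    # mask: [batch, height, width], as produced by IPAdapterMaskProcessor.preprocess
    o_h, o_w = mask.shape[1], mask.shape[2]
    ratio = o_w / o_h
    mask_h = int(math.sqrt(num_queries / ratio))
    mask_h = mask_h + int((num_queries % mask_h) != 0)  # round up on non-even division
    mask_w = num_queries // mask_h

    down = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
    down = down.view(down.shape[0], -1)  # flatten to [batch, mask_h * mask_w]
    if down.shape[1] < num_queries:  # the grid may undershoot num_queries
        down = F.pad(down, (0, num_queries - down.shape[1]), value=0.0)
    return down[:, :num_queries]

# A 1024x768 mask mapped onto 4096 latent tokens ends up on a 74x55 grid,
# close to the original 4:3 aspect ratio.
print(downsample_sketch(torch.ones(1, 1024, 768), num_queries=4096).shape)  # torch.Size([1, 4096])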
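PATCH 07 keeps older call sites working: a plain tensor in the previous [num_ip_adapter, 1, height, width] layout is converted to the new list format before the per-adapter checks run. The normalization and validation, condensed into a standalone helper (name and error messages simplified relative to the diff):

from typing import List, Union
import torch

def normalize_ip_adapter_masks(
    ip_adapter_masks: Union[torch.Tensor, List[torch.Tensor]],
    scales: list,
    ip_hidden_states: List[torch.Tensor],
) -> List[torch.Tensor]:
    if not isinstance(ip_adapter_masks, list):
        # Backward compatibility: [num_ip_adapter, 1, H, W] becomes a list of
        # num_ip_adapter tensors shaped [1, 1, H, W].
        ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
    if not (len(ip_adapter_masks) == len(scales) == len(ip_hidden_states)):
        raise ValueError("one mask entry is required per IP-Adapter")
    for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, scales, ip_hidden_states)):
        if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
            raise ValueError("each mask must be a tensor of shape [1, num_images, height, width]")
        if mask.shape[1] != ip_state.shape[1]:
            raise ValueError(f"number of masks does not match number of ip images at index {index}")
        if isinstance(scale, list) and len(scale) != mask.shape[1]:
            raise ValueError(f"number of masks does not match number of scales at index {index}")
    return ip_adapter_masks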