diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index c531d5a519f2..240f514b3dcd 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -136,6 +136,7 @@ def load_ip_adapter(
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+        target_blocks = kwargs.pop("target_blocks", ["block"])
 
         if low_cpu_mem_usage and not is_accelerate_available():
             low_cpu_mem_usage = False
@@ -226,7 +227,7 @@ def load_ip_adapter(
 
         # load ip-adapter into unet
         unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage, target_blocks=target_blocks)
 
     def set_ip_adapter_scale(self, scale):
         """
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 8bbec26189b0..0817197c2ad7 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -800,7 +800,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
 
         return image_projection
 
-    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
+    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False, target_blocks=["block"]):
         from ..models.attention_processor import (
             AttnProcessor,
             AttnProcessor2_0,
@@ -864,11 +864,14 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
                         num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
 
                 with init_context():
+                    selected = any(block_name in name for block_name in target_blocks)
+
                     attn_procs[name] = attn_processor_class(
                         hidden_size=hidden_size,
                         cross_attention_dim=cross_attention_dim,
                         scale=1.0,
                         num_tokens=num_image_text_embeds,
+                        skip=not selected,
                     )
 
                 value_dict = {}
@@ -887,14 +890,16 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
 
         return attn_procs
 
-    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
+    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False, target_blocks=["block"]):
         if not isinstance(state_dicts, list):
             state_dicts = [state_dicts]
         # Set encoder_hid_proj after loading ip_adapter weights,
         # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None
 
-        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+        attn_procs = self._convert_ip_adapter_attn_to_diffusers(
+            state_dicts, low_cpu_mem_usage=low_cpu_mem_usage, target_blocks=target_blocks
+        )
         self.set_attn_processor(attn_procs)
 
         # convert IP-Adapter Image Projection layers to diffusers
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 30086654a2f1..11996851a481 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2108,7 +2108,7 @@ class IPAdapterAttnProcessor(nn.Module):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, skip=False):
         super().__init__()
 
         self.hidden_size = hidden_size
@@ -2117,6 +2117,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
         if not isinstance(num_tokens, (tuple, list)):
             num_tokens = [num_tokens]
         self.num_tokens = num_tokens
+        self.skip = skip
 
         if not isinstance(scale, list):
             scale = [scale] * len(num_tokens)
@@ -2208,29 +2209,30 @@ def __call__(
             ip_adapter_masks = [None] * len(self.scale)
 
         # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
-        ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = attn.head_to_batch_dim(ip_key)
-            ip_value = attn.head_to_batch_dim(ip_value)
-
-            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-            current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-            current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
-
-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
+        if not self.skip:
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = attn.head_to_batch_dim(ip_key)
+                ip_value = attn.head_to_batch_dim(ip_value)
+
+                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                if mask is not None:
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                    )
 
-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
 
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+                    current_ip_hidden_states = current_ip_hidden_states * mask_downsample
 
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -2263,7 +2265,7 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
             the weight scale of image prompt.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, skip=False):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -2283,6 +2285,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale
         if len(scale) != len(num_tokens):
             raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
         self.scale = scale
+        self.skip = skip
         self.to_k_ip = nn.ModuleList(
             [nn.Linear(cross_attention_dim, hidden_size, bias=False) for _ in range(len(num_tokens))]
         )
@@ -2382,36 +2385,37 @@ def __call__(
             ip_adapter_masks = [None] * len(self.scale)
 
         # for ip-adapter
-        for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
-            ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
-        ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            current_ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-
-            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
-            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+        if not self.skip:
+            for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
+                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
+            ):
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
 
-            if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
                 )
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                if mask is not None:
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
+                    )
 
-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
 
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+                    current_ip_hidden_states = current_ip_hidden_states * mask_downsample
 
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)