From c31b5049460b8c92ca7f8cb7902faf623430cfe4 Mon Sep 17 00:00:00 2001 From: christopher5106 Date: Fri, 29 Nov 2024 11:38:06 +0100 Subject: [PATCH] use attention mask parameter in flux attention --- scripts/convert_flux_to_diffusers.py | 2 +- src/diffusers/models/attention_processor.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/scripts/convert_flux_to_diffusers.py b/scripts/convert_flux_to_diffusers.py index 33668fed8120..fccac70dd855 100644 --- a/scripts/convert_flux_to_diffusers.py +++ b/scripts/convert_flux_to_diffusers.py @@ -31,7 +31,7 @@ --vae """ -CTX = init_empty_weights if is_accelerate_available else nullcontext +CTX = init_empty_weights if is_accelerate_available() else nullcontext parser = argparse.ArgumentParser() parser.add_argument("--original_state_dict_repo_id", default=None, type=str) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index ffbf4a0056c6..8b2a8688619a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1907,7 +1907,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask, + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2012,9 +2014,12 @@ def __call__( keep_prob=1.0, sync=False, inner_precise=0, + attenMask=attention_mask, )[0] else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask, + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype)