From 6ea2abccd8fe438d139c5b25976bc1aad6200546 Mon Sep 17 00:00:00 2001
From: yiyixuxu
Date: Thu, 5 Dec 2024 00:51:50 +0100
Subject: [PATCH] pass attn mask arg for flux

---
 src/diffusers/models/attention_processor.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 7351801368dd..13d910db6135 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1908,7 +1908,9 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
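
For context, a minimal standalone sketch of what forwarding attention_mask to scaled_dot_product_attention enables. This is not part of the patch: the tensor shapes, sizes, and mask contents below are illustrative assumptions, chosen only to show the masking behavior the patched Flux processor now supports.

import torch
import torch.nn.functional as F

# Arbitrary illustrative sizes, following SDPA's (batch, heads, seq, head_dim) layout.
batch, heads, seq_len, head_dim = 2, 4, 16, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# A boolean mask: True = attend, False = ignore. Here we mask out the last
# 4 key positions (e.g. padded tokens); the shape broadcasts across heads.
attention_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.bool)
attention_mask[..., -4:] = False

# With this patch, the Flux attention processor forwards its attention_mask
# argument to SDPA like this, instead of dropping it.
out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
print(out.shape)  # torch.Size([2, 4, 16, 64])

Before the change, the mask argument was silently ignored by this processor, so masked (e.g. padded) tokens still contributed to the attention output; after it, SDPA excludes the False positions when computing attention weights.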