From 7bb6e4ab99ce7d3a51f47098dad97b408a39ce41 Mon Sep 17 00:00:00 2001 From: Akshay Babbar <19975437+akshay-babbar@users.noreply.github.com> Date: Sat, 30 Aug 2025 21:51:49 +0530 Subject: [PATCH 1/6] fix: preserve boolean dtype for attention masks in ChromaPipeline - Convert attention masks to bool and prevent dtype corruption - Fix both positive and negative mask handling in _get_t5_prompt_embeds - Remove float conversion in _prepare_attention_mask method Fixes #12116 --- src/diffusers/pipelines/chroma/pipeline_chroma.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index a3dd1422b876..8c787bfc2b00 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -233,11 +233,12 @@ def _get_t5_prompt_embeds( ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask.clone() + attention_mask = attention_mask.bool() # fix here mine # Chroma requires the attention mask to include one padding token seq_lengths = attention_mask.sum(dim=1) mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1) - attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long() + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool() prompt_embeds = self.text_encoder( text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device) @@ -245,7 +246,7 @@ def _get_t5_prompt_embeds( dtype = self.text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - attention_mask = attention_mask.to(dtype=dtype, device=device) + attention_mask = attention_mask.to(device=device) _, seq_len, _ = prompt_embeds.shape @@ -580,10 +581,10 @@ def _prepare_attention_mask( # Extend the prompt attention mask to account for image tokens in the final sequence attention_mask = torch.cat( - [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)], + [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)], dim=1, ) - attention_mask = attention_mask.to(dtype) + # attention_mask = attention_mask.to(dtype) return attention_mask From 1cf90c414708d353eb89cbb2ba83834009cb1869 Mon Sep 17 00:00:00 2001 From: Akshay Babbar <19975437+akshay-babbar@users.noreply.github.com> Date: Sun, 31 Aug 2025 03:59:29 +0530 Subject: [PATCH 2/6] test: add ChromaPipeline attention mask dtype tests --- .../pipelines/chroma/test_pipeline_chroma.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py index 3edd58b75f82..16757053480e 100644 --- a/tests/pipelines/chroma/test_pipeline_chroma.py +++ b/tests/pipelines/chroma/test_pipeline_chroma.py @@ -158,3 +158,24 @@ def test_chroma_image_output_shape(self): image = pipe(**inputs).images[0] output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + + +class ChromaPipelineAttentionMaskTests(unittest.TestCase): + def setUp(self): + self.pipe = ChromaPipeline.from_pretrained( + "lodestones/Chroma1-Base", + torch_dtype=torch.float16, + ) + + def test_attention_mask_dtype_is_bool_short_prompt(self): + prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds("man") + self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}") + self.assertGreater(prompt_embeds.shape[0], 0) + self.assertGreater(prompt_embeds.shape[1], 0) + + def test_attention_mask_dtype_is_bool_long_prompt(self): + long_prompt = "a detailed portrait of a man standing in a garden with flowers and trees" + prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds(long_prompt) + self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}") + self.assertGreater(prompt_embeds.shape[0], 0) + self.assertGreater(prompt_embeds.shape[1], 0) From 3f9228bdb284237f1a0db44eb0332149cd5a6508 Mon Sep 17 00:00:00 2001 From: Akshay Babbar <19975437+akshay-babbar@users.noreply.github.com> Date: Sun, 31 Aug 2025 04:06:58 +0530 Subject: [PATCH 3/6] test: add slow ChromaPipeline attention mask tests --- tests/pipelines/chroma/test_pipeline_chroma.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py index 16757053480e..2b936d27a3d4 100644 --- a/tests/pipelines/chroma/test_pipeline_chroma.py +++ b/tests/pipelines/chroma/test_pipeline_chroma.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler +from diffusers.utils.testing_utils import slow from ...testing_utils import torch_device from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist @@ -167,12 +168,14 @@ def setUp(self): torch_dtype=torch.float16, ) + @slow def test_attention_mask_dtype_is_bool_short_prompt(self): prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds("man") self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}") self.assertGreater(prompt_embeds.shape[0], 0) self.assertGreater(prompt_embeds.shape[1], 0) + @slow def test_attention_mask_dtype_is_bool_long_prompt(self): long_prompt = "a detailed portrait of a man standing in a garden with flowers and trees" prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds(long_prompt) From 60d738510a22add8a2574ff2136a8a63521b5c83 Mon Sep 17 00:00:00 2001 From: Akshay Babbar <19975437+akshay-babbar@users.noreply.github.com> Date: Sun, 31 Aug 2025 04:18:14 +0530 Subject: [PATCH 4/6] chore: removed comments --- src/diffusers/pipelines/chroma/pipeline_chroma.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 8c787bfc2b00..75a4722cf442 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -233,7 +233,7 @@ def _get_t5_prompt_embeds( ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask.clone() - attention_mask = attention_mask.bool() # fix here mine + attention_mask = attention_mask.bool() # Chroma requires the attention mask to include one padding token seq_lengths = attention_mask.sum(dim=1) @@ -584,7 +584,6 @@ def _prepare_attention_mask( [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)], dim=1, ) - # attention_mask = attention_mask.to(dtype) return attention_mask From 26c33efabb28aeb03aabbce7859b6bd3c9196c33 Mon Sep 17 00:00:00 2001 From: Akshay Babbar <19975437+akshay-babbar@users.noreply.github.com> Date: Thu, 25 Sep 2025 16:54:31 +0530 Subject: [PATCH 5/6] refactor: removing redundant type conversion --- src/diffusers/pipelines/chroma/pipeline_chroma.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py index 75a4722cf442..f28eb1fc876c 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py @@ -233,7 +233,6 @@ def _get_t5_prompt_embeds( ) text_input_ids = text_inputs.input_ids attention_mask = text_inputs.attention_mask.clone() - attention_mask = attention_mask.bool() # Chroma requires the attention mask to include one padding token seq_lengths = attention_mask.sum(dim=1) From 7c16aa8eada455784dc36057755208495fc0a936 Mon Sep 17 00:00:00 2001 From: Akshay Babbar <19975437+akshay-babbar@users.noreply.github.com> Date: Fri, 26 Sep 2025 08:05:01 +0530 Subject: [PATCH 6/6] Remove dedicated dtype tests as per feedback --- .../pipelines/chroma/test_pipeline_chroma.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py index 2b936d27a3d4..3edd58b75f82 100644 --- a/tests/pipelines/chroma/test_pipeline_chroma.py +++ b/tests/pipelines/chroma/test_pipeline_chroma.py @@ -5,7 +5,6 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler -from diffusers.utils.testing_utils import slow from ...testing_utils import torch_device from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist @@ -159,26 +158,3 @@ def test_chroma_image_output_shape(self): image = pipe(**inputs).images[0] output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) - - -class ChromaPipelineAttentionMaskTests(unittest.TestCase): - def setUp(self): - self.pipe = ChromaPipeline.from_pretrained( - "lodestones/Chroma1-Base", - torch_dtype=torch.float16, - ) - - @slow - def test_attention_mask_dtype_is_bool_short_prompt(self): - prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds("man") - self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}") - self.assertGreater(prompt_embeds.shape[0], 0) - self.assertGreater(prompt_embeds.shape[1], 0) - - @slow - def test_attention_mask_dtype_is_bool_long_prompt(self): - long_prompt = "a detailed portrait of a man standing in a garden with flowers and trees" - prompt_embeds, attn_mask = self.pipe._get_t5_prompt_embeds(long_prompt) - self.assertEqual(attn_mask.dtype, torch.bool, f"Expected bool, got {attn_mask.dtype}") - self.assertGreater(prompt_embeds.shape[0], 0) - self.assertGreater(prompt_embeds.shape[1], 0)