From 95a478167cdc9a9e3e6c2c5afe97e061de44c7a5 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Thu, 23 Apr 2026 05:45:13 +0000 Subject: [PATCH 1/8] fix 2 failed test cases for blt model on XPU Signed-off-by: Liu, Kaixuan --- tests/models/blt/test_modeling_blt.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index a3f50157b38a..fe2ca9555e69 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -20,6 +20,7 @@ from transformers import AutoTokenizer, is_torch_available from transformers.testing_utils import ( + Expectations, cleanup, require_torch, require_torch_accelerator, @@ -343,7 +344,14 @@ def test_model_logits(self): def test_model_bf16(self): """Test Blt model with bfloat16 precision.""" NUM_TOKENS_TO_GENERATE = 200 - EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" + # fmt: off + EXPECTED_TEXT = Expectations( + { + (None, None): "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m", + ("xpu", None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s", + } + ) + # fmt: on prompt = "my name is" @@ -360,7 +368,7 @@ def test_model_bf16(self): ) output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(output_text, EXPECTED_TEXT) + self.assertEqual(output_text, EXPECTED_TEXT.get_expectation()) @slow @require_torch_bf16 @@ -473,7 +481,14 @@ def test_model_eager(self): def test_model_bf16_static_cache(self): """Test Blt model with bfloat16 precision and static cache.""" NUM_TOKENS_TO_GENERATE = 200 - EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m" + # fmt: off + EXPECTED_TEXT = Expectations( + { + (None, None): "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m", + ("xpu", None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s", + } + ) + # fmt: on prompt = "my name is" @@ -492,4 +507,4 @@ def test_model_bf16_static_cache(self): ) output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(output_text, EXPECTED_TEXT) + self.assertEqual(output_text, EXPECTED_TEXT.get_expectation()) From 01dd2ed918ef9529a1edc9b409a21ff6dd88f6d9 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Tue, 28 Apr 2026 17:30:04 +0200 Subject: [PATCH 2/8] update --- tests/models/blt/test_modeling_blt.py | 142 ++------------------------ 1 file changed, 11 insertions(+), 131 deletions(-) diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index fe2ca9555e69..a9bcb220a6a7 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -261,74 +261,14 @@ def test_model(self): @slow def test_model_logits(self): + # fmt: off EXPECTED_OUTPUT = torch.tensor( [ - [ - -10.4948, - -10.7065, - -6.1813, - -10.5545, - -10.3428, - -9.1493, - -8.4937, - -8.6382, - -9.2159, - -9.5907, - -9.3679, - -8.4184, - -9.0655, - -3.4436, - 2.9616, - -10.3157, - -6.3723, - -6.0133, - -9.7100, - -9.2128, - -8.8064, - -9.8179, - -9.7516, - -9.4681, - -9.7715, - -9.4897, - -9.0491, - -9.8098, - -9.4648, - -9.3294, - ], - [ - -13.3010, - -13.1910, - -5.7230, - -13.2895, - -13.4864, - -8.7140, - -7.0275, - -7.0182, - -10.1362, - -10.3762, - -9.9086, - -7.8049, - -8.8660, - -5.2711, - -3.5778, - -12.5346, - -9.1609, - -6.7925, - -10.3717, - -9.2650, - -10.6393, - -11.4807, - -11.2128, - -10.9615, - -10.5806, - -10.8873, - -11.0651, - -11.3471, - -10.5437, - -9.9688, - ], + [-10.5000, -10.6875, -6.2500, -10.5625, -10.3125, -9.1875, -8.5000, -8.5625, -9.1875, -9.6250, -9.3750, -8.5000, -9.1250, -3.3906, 2.9688, -10.3125, -6.4688, -6.0312, -9.7500, -9.1875, -8.8125, -9.8750, -9.8125, -9.5000, -9.8125, -9.5000, -9.0625, -9.8125, -9.5000, -9.3750], + [-13.2500, -13.1250, -5.6875, -13.1875, -13.3750, -8.6875, -6.9688, -6.9375, -10.0625, -10.3125, -9.8125, -7.7188, -8.8125, -5.2188, -3.5000, -12.4375, -9.0625, -6.6250, -10.3125, -9.1875, -10.6250, -11.4375, -11.1250, -10.8750, -10.5000, -10.8750, -11.0000, -11.3125, -10.5000, -9.8750], ] ).to(torch_device) + # fmt: on input_ids = [1, 42, 21, 12, 43, 23, 1, 4] @@ -337,7 +277,7 @@ def test_model_logits(self): with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] - torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-4, atol=1e-4) + torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-3, atol=1e-3) @slow @require_torch_bf16 @@ -347,7 +287,7 @@ def test_model_bf16(self): # fmt: off EXPECTED_TEXT = Expectations( { - (None, None): "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m", + (None, None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s", ("xpu", None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s", } ) @@ -375,74 +315,14 @@ def test_model_bf16(self): def test_model_logits_bf16(self): """Test Blt model logits with bfloat16 precision.""" + # fmt: off EXPECTED_OUTPUT = torch.tensor( [ - [ - -10.5000, - -10.6875, - -6.1875, - -10.5625, - -10.3125, - -9.1875, - -8.5000, - -8.6875, - -9.1875, - -9.5625, - -9.3750, - -8.5000, - -9.0625, - -3.4219, - 2.9531, - -10.3125, - -6.4062, - -6.0000, - -9.6875, - -9.1875, - -8.8125, - -9.8125, - -9.7500, - -9.4375, - -9.8125, - -9.5000, - -9.0000, - -9.8125, - -9.4375, - -9.3125, - ], - [ - -13.2500, - -13.1875, - -5.6875, - -13.3125, - -13.5000, - -8.7500, - -7.0625, - -7.0312, - -10.1250, - -10.3750, - -9.8750, - -7.8438, - -8.8750, - -5.2812, - -3.5625, - -12.5000, - -9.1875, - -6.8125, - -10.3750, - -9.3125, - -10.6250, - -11.5000, - -11.2500, - -11.0000, - -10.5625, - -10.8750, - -11.0625, - -11.3750, - -10.5625, - -10.0000, - ], + [-10.5000, -10.6875, -6.2500, -10.5625, -10.3125, -9.1875, -8.5000, -8.5625, -9.1875, -9.6250, -9.3750, -8.5000, -9.1250, -3.3906, 2.9688, -10.3125, -6.4688, -6.0312, -9.7500, -9.1875, -8.8125, -9.8750, -9.8125, -9.5000, -9.8125, -9.5000, -9.0625, -9.8125, -9.5000, -9.3750], + [-13.2500, -13.1250, -5.6875, -13.1875, -13.3750, -8.6875, -6.9688, -6.9375, -10.0625, -10.3125, -9.8125, -7.7188, -8.8125, -5.2188, -3.5000, -12.4375, -9.0625, -6.6250, -10.3125, -9.1875, -10.6250, -11.4375, -11.1250, -10.8750, -10.5000, -10.8750, -11.0000, -11.3125, -10.5000, -9.8750], ] ).to(torch_device) + # fmt: on input_ids = [1, 42, 21, 12, 43, 23, 1, 4] @@ -484,7 +364,7 @@ def test_model_bf16_static_cache(self): # fmt: off EXPECTED_TEXT = Expectations( { - (None, None): "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m", + (None, None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s", ("xpu", None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s", } ) From 1dc2daaa86f03f990178899eb5a38b5191b2a2c6 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Tue, 28 Apr 2026 17:50:46 +0200 Subject: [PATCH 3/8] update --- src/transformers/models/blt/modeling_blt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 778f7ba80cf6..32ffe165a019 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1228,7 +1228,7 @@ def forward( else: batch_size, sequence_length = input_ids.shape encoder_embeds = compute_hash_embeddings( - input_ids, + input_ids.to(self.local_encoder.embed_tokens.weight.device), self.local_encoder, self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, From 61790a4517308bde3c2dd3ead0d6596a4ef53074 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Tue, 28 Apr 2026 18:15:06 +0200 Subject: [PATCH 4/8] update --- src/transformers/models/blt/modeling_blt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 32ffe165a019..7292dfd1bed8 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -1241,7 +1241,7 @@ def forward( if input_ids is None: raise ValueError("input_ids is required for entropy-based patching") _, patch_lengths, _ = self.patcher( - input_ids, + input_ids.to(self.patcher.embed_tokens.weight.device), patch_size=self.config.patch_size, threshold=self.config.patching_threshold, max_patch_length=self.config.max_patch_length, From c1e1da5486602865a72fc4bcb7db3a69e1a5bc17 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Tue, 28 Apr 2026 18:17:38 +0200 Subject: [PATCH 5/8] update --- tests/models/blt/test_modeling_blt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index a9bcb220a6a7..a67e2fdfcfb4 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -277,7 +277,7 @@ def test_model_logits(self): with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] - torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-3, atol=1e-3) + torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30].to(torch_device), rtol=1e-3, atol=1e-3) @slow @require_torch_bf16 @@ -333,7 +333,7 @@ def test_model_logits_bf16(self): with torch.no_grad(): output = model(torch.tensor([input_ids]).to(torch_device))[0] - torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30], rtol=1e-3, atol=1e-3) + torch.testing.assert_close(EXPECTED_OUTPUT, output[0, :2, :30].to(torch_device), rtol=1e-3, atol=1e-3) @slow def test_model_eager(self): From 9a394f6b0083a706490d77c8ed71cf441e16da51 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Tue, 28 Apr 2026 18:22:33 +0200 Subject: [PATCH 6/8] update --- src/transformers/models/blt/modular_blt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 4dfff77263db..fbe4ca0aa1ce 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -955,7 +955,7 @@ def forward( else: batch_size, sequence_length = input_ids.shape encoder_embeds = compute_hash_embeddings( - input_ids, + input_ids.to(self.local_encoder.embed_tokens.weight.device), self.local_encoder, self.encoder_hash_tok_embedding, self.config.encoder_hash_byte_group_nb_functions, @@ -968,7 +968,7 @@ def forward( if input_ids is None: raise ValueError("input_ids is required for entropy-based patching") _, patch_lengths, _ = self.patcher( - input_ids, + input_ids.to(self.patcher.embed_tokens.weight.device), patch_size=self.config.patch_size, threshold=self.config.patching_threshold, max_patch_length=self.config.max_patch_length, From dc624a04988c3cdb30c66e35718821d26d2e1431 Mon Sep 17 00:00:00 2001 From: ydshieh Date: Tue, 28 Apr 2026 19:14:14 +0200 Subject: [PATCH 7/8] update --- src/transformers/models/blt/modeling_blt.py | 2 +- src/transformers/models/blt/modular_blt.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 7292dfd1bed8..b6501ab78f3f 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -670,7 +670,7 @@ def forward( attention_mask=encoder_attention_mask, **kwargs, ) - patch_embeds = patch_embeds + cross_attention_output + patch_embeds = patch_embeds + cross_attention_output.to(patch_embeds.device) encoder_cross_states = patch_embeds return hidden_states, encoder_cross_states diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index fbe4ca0aa1ce..8796e2a40e82 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -604,7 +604,7 @@ def forward( attention_mask=encoder_attention_mask, **kwargs, ) - patch_embeds = patch_embeds + cross_attention_output + patch_embeds = patch_embeds + cross_attention_output.to(patch_embeds) encoder_cross_states = patch_embeds return hidden_states, encoder_cross_states From 63682199f4ccaff2c72abf43518a3ab1b4b0aa58 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Wed, 29 Apr 2026 02:44:13 +0000 Subject: [PATCH 8/8] fix another failed cases Signed-off-by: Liu, Kaixuan --- src/transformers/models/blt/modular_blt.py | 2 +- tests/models/blt/test_modeling_blt.py | 61 +++++++++++++++++----- 2 files changed, 48 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 8796e2a40e82..eb0624495873 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -604,7 +604,7 @@ def forward( attention_mask=encoder_attention_mask, **kwargs, ) - patch_embeds = patch_embeds + cross_attention_output.to(patch_embeds) + patch_embeds = patch_embeds + cross_attention_output.to(patch_embeds.device) encoder_cross_states = patch_embeds return hidden_states, encoder_cross_states diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py index a67e2fdfcfb4..7c7a36b0a0df 100644 --- a/tests/models/blt/test_modeling_blt.py +++ b/tests/models/blt/test_modeling_blt.py @@ -185,6 +185,10 @@ class BltModelTest(CausalLMModelTest, unittest.TestCase): def test_generate_from_inputs_embeds(self, _, num_beams): pass + @pytest.mark.generate + def test_generate_with_quant_cache(self): + self.skipTest("BLT uses EncoderDecoderCache internally and does not support quantized cache") + @pytest.mark.generate @unittest.skip( "Blt requires real token IDs for its hash-based embedding computation, making inputs_embeds generation incompatible with identical outputs" @@ -262,12 +266,23 @@ def test_model(self): @slow def test_model_logits(self): # fmt: off - EXPECTED_OUTPUT = torch.tensor( - [ - [-10.5000, -10.6875, -6.2500, -10.5625, -10.3125, -9.1875, -8.5000, -8.5625, -9.1875, -9.6250, -9.3750, -8.5000, -9.1250, -3.3906, 2.9688, -10.3125, -6.4688, -6.0312, -9.7500, -9.1875, -8.8125, -9.8750, -9.8125, -9.5000, -9.8125, -9.5000, -9.0625, -9.8125, -9.5000, -9.3750], - [-13.2500, -13.1250, -5.6875, -13.1875, -13.3750, -8.6875, -6.9688, -6.9375, -10.0625, -10.3125, -9.8125, -7.7188, -8.8125, -5.2188, -3.5000, -12.4375, -9.0625, -6.6250, -10.3125, -9.1875, -10.6250, -11.4375, -11.1250, -10.8750, -10.5000, -10.8750, -11.0000, -11.3125, -10.5000, -9.8750], - ] - ).to(torch_device) + EXPECTED_OUTPUT = Expectations( + { + (None, None): torch.tensor( + [ + [-10.5000, -10.6875, -6.2500, -10.5625, -10.3125, -9.1875, -8.5000, -8.5625, -9.1875, -9.6250, -9.3750, -8.5000, -9.1250, -3.3906, 2.9688, -10.3125, -6.4688, -6.0312, -9.7500, -9.1875, -8.8125, -9.8750, -9.8125, -9.5000, -9.8125, -9.5000, -9.0625, -9.8125, -9.5000, -9.3750], + [-13.2500, -13.1250, -5.6875, -13.1875, -13.3750, -8.6875, -6.9688, -6.9375, -10.0625, -10.3125, -9.8125, -7.7188, -8.8125, -5.2188, -3.5000, -12.4375, -9.0625, -6.6250, -10.3125, -9.1875, -10.6250, -11.4375, -11.1250, -10.8750, -10.5000, -10.8750, -11.0000, -11.3125, -10.5000, -9.8750], + ] + ), + ("xpu", None): torch.tensor( + [ + [-10.4375, -10.6875, -6.1875, -10.5000, -10.3125, -9.1250, -8.4375, -8.6250, -9.1875, -9.5625, -9.3125, -8.4375, -9.0625, -3.4375, 2.9531, -10.2500, -6.4062, -6.0000, -9.6875, -9.1875, -8.8125, -9.8125, -9.7500, -9.4375, -9.7500, -9.4375, -9.0000, -9.8125, -9.4375, -9.3125], + [-13.3125, -13.2500, -5.5938, -13.3125, -13.5000, -8.7500, -7.0625, -7.0312, -10.1875, -10.3750, -9.9375, -7.8438, -8.8750, -5.3438, -3.5938, -12.5625, -9.2500, -6.8125, -10.3750, -9.3125, -10.6875, -11.5625, -11.3125, -11.0000, -10.6250, -10.9375, -11.0625, -11.3750, -10.5625, -10.0000], + ] + ), + } + ).get_expectation() + EXPECTED_OUTPUT = EXPECTED_OUTPUT.to(torch_device) # fmt: on input_ids = [1, 42, 21, 12, 43, 23, 1, 4] @@ -316,12 +331,23 @@ def test_model_logits_bf16(self): """Test Blt model logits with bfloat16 precision.""" # fmt: off - EXPECTED_OUTPUT = torch.tensor( - [ - [-10.5000, -10.6875, -6.2500, -10.5625, -10.3125, -9.1875, -8.5000, -8.5625, -9.1875, -9.6250, -9.3750, -8.5000, -9.1250, -3.3906, 2.9688, -10.3125, -6.4688, -6.0312, -9.7500, -9.1875, -8.8125, -9.8750, -9.8125, -9.5000, -9.8125, -9.5000, -9.0625, -9.8125, -9.5000, -9.3750], - [-13.2500, -13.1250, -5.6875, -13.1875, -13.3750, -8.6875, -6.9688, -6.9375, -10.0625, -10.3125, -9.8125, -7.7188, -8.8125, -5.2188, -3.5000, -12.4375, -9.0625, -6.6250, -10.3125, -9.1875, -10.6250, -11.4375, -11.1250, -10.8750, -10.5000, -10.8750, -11.0000, -11.3125, -10.5000, -9.8750], - ] - ).to(torch_device) + EXPECTED_OUTPUT = Expectations( + { + (None, None): torch.tensor( + [ + [-10.5000, -10.6875, -6.2500, -10.5625, -10.3125, -9.1875, -8.5000, -8.5625, -9.1875, -9.6250, -9.3750, -8.5000, -9.1250, -3.3906, 2.9688, -10.3125, -6.4688, -6.0312, -9.7500, -9.1875, -8.8125, -9.8750, -9.8125, -9.5000, -9.8125, -9.5000, -9.0625, -9.8125, -9.5000, -9.3750], + [-13.2500, -13.1250, -5.6875, -13.1875, -13.3750, -8.6875, -6.9688, -6.9375, -10.0625, -10.3125, -9.8125, -7.7188, -8.8125, -5.2188, -3.5000, -12.4375, -9.0625, -6.6250, -10.3125, -9.1875, -10.6250, -11.4375, -11.1250, -10.8750, -10.5000, -10.8750, -11.0000, -11.3125, -10.5000, -9.8750], + ] + ), + ("xpu", None): torch.tensor( + [ + [-10.4375, -10.6875, -6.1875, -10.5000, -10.3125, -9.1250, -8.4375, -8.6250, -9.1875, -9.5625, -9.3125, -8.4375, -9.0625, -3.4375, 2.9531, -10.2500, -6.4062, -6.0000, -9.6875, -9.1875, -8.8125, -9.8125, -9.7500, -9.4375, -9.7500, -9.4375, -9.0000, -9.8125, -9.4375, -9.3125], + [-13.3125, -13.2500, -5.5938, -13.3125, -13.5000, -8.7500, -7.0625, -7.0312, -10.1875, -10.3750, -9.9375, -7.8438, -8.8750, -5.3438, -3.5938, -12.5625, -9.2500, -6.8125, -10.3750, -9.3125, -10.6875, -11.5625, -11.3125, -11.0000, -10.6250, -10.9375, -11.0625, -11.3750, -10.5625, -10.0000], + ] + ), + } + ).get_expectation() + EXPECTED_OUTPUT = EXPECTED_OUTPUT.to(torch_device) # fmt: on input_ids = [1, 42, 21, 12, 43, 23, 1, 4] @@ -339,7 +365,14 @@ def test_model_logits_bf16(self): def test_model_eager(self): """Test Blt model with bfloat16 precision using eager attention implementation.""" NUM_TOKENS_TO_GENERATE = 200 - EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s" + # fmt: off + EXPECTED_TEXT = Expectations( + { + (None, None): "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s", + ("xpu", None): "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m", + } + ) + # fmt: on prompt = "my name is" @@ -354,7 +387,7 @@ def test_model_eager(self): ) output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) - self.assertEqual(output_text, EXPECTED_TEXT) + self.assertEqual(output_text, EXPECTED_TEXT.get_expectation()) @slow @require_torch_bf16