diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
index 538353fee44d..4d26443f63d6 100644
--- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
+++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py
@@ -198,6 +198,7 @@ def test_sdpa_can_dispatch_composite_models(self):
 @require_torch
 class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
     def setUp(self):
+        cleanup(torch_device, gc_collect=True)
         self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
 
     def tearDown(self):
@@ -206,7 +207,9 @@ def tearDown(self):
     @slow
     def test_small_model_integration_test_single(self):
         # Let' s make sure we test the preprocessing to replace what is used
-        model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
+        model = Qwen2AudioForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2-Audio-7B-Instruct", device_map=torch_device, dtype=torch.float16
+        )
 
         url = "https://huggingface.co/datasets/raushan-testing-hf/audio-test/resolve/main/glass-breaking-151256.mp3"
         messages = [
@@ -223,47 +226,35 @@ def test_small_model_integration_test_single(self):
 
         formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
 
-        inputs = self.processor(text=formatted_prompt, audios=[raw_audio], return_tensors="pt", padding=True)
+        inputs = self.processor(text=formatted_prompt, audio=[raw_audio], return_tensors="pt", padding=True).to(
+            torch_device
+        )
+
         torch.manual_seed(42)
         output = model.generate(**inputs, max_new_tokens=32)
 
         # fmt: off
-        EXPECTED_INPUT_IDS = torch.tensor([[
-            151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
-            *[151646] * 101,
-            151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
-        ]])
-        # fmt: on
-        self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
-
-        EXPECTED_DECODED_TEXT = (
-            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
-            + "<|AUDIO|>" * 101
-            + "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
+        EXPECTED_INPUT_IDS = torch.tensor(
+            [[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647, *[151646] * 101, 151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198]],
+            device=torch_device
         )
+        # fmt: on
+        torch.testing.assert_close(inputs["input_ids"], EXPECTED_INPUT_IDS)
 
+        # fmt: off
+        EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>" + "<|AUDIO|>" * 101 + "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
+        # fmt: on
         self.assertEqual(
             self.processor.decode(output[0], skip_special_tokens=False),
             EXPECTED_DECODED_TEXT,
         )
 
-        # test the error when incorrect number of audio tokens
-        # fmt: off
-        inputs["input_ids"] = torch.tensor([[
-            151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
-            *[151646] * 200,
-            151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
-        ]])
-        # fmt: on
-        with self.assertRaisesRegex(
-            ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
-        ):
-            model.generate(**inputs, max_new_tokens=32)
-
     @slow
     def test_small_model_integration_test_batch(self):
         # Let' s make sure we test the preprocessing to replace what is used
-        model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
+        model = Qwen2AudioForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2-Audio-7B-Instruct", device_map=torch_device, dtype=torch.float16
+        )
 
         conversation1 = [
             {
@@ -322,23 +313,27 @@ def test_small_model_integration_test_batch(self):
                 )[0]
             )
 
-        inputs = self.processor(text=text, audios=audios, return_tensors="pt", padding=True)
+        inputs = self.processor(text=text, audio=audios, return_tensors="pt", padding=True).to(torch_device)
 
+        torch.manual_seed(42)
         output = model.generate(**inputs, max_new_tokens=32)
 
         EXPECTED_DECODED_TEXT = [
             "system\nYou are a helpful assistant.\nuser\nAudio 1: \nWhat's that sound?\nassistant\nIt is the sound of glass shattering.\nuser\nAudio 2: \nWhat can you hear?\nassistant\ncough and throat clearing.",
             "system\nYou are a helpful assistant.\nuser\nAudio 1: \nWhat does the person say?\nassistant\nThe original content of this audio is: 'Mister Quiller is the apostle of the middle classes and we are glad to welcome his gospel.'",
         ]
+
         self.assertEqual(
             self.processor.batch_decode(output, skip_special_tokens=True),
             EXPECTED_DECODED_TEXT,
         )
 
     @slow
     def test_small_model_integration_test_multiturn(self):
         # Let' s make sure we test the preprocessing to replace what is used
-        model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
+        model = Qwen2AudioForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2-Audio-7B-Instruct", device_map=torch_device, dtype=torch.float16
+        )
 
         messages = [
             {"role": "system", "content": "You are a helpful assistant."},
@@ -379,12 +374,15 @@ def test_small_model_integration_test_multiturn(self):
                 )[0]
             )
 
-        inputs = self.processor(text=formatted_prompt, audios=audios, return_tensors="pt", padding=True)
+        inputs = self.processor(text=formatted_prompt, audio=audios, return_tensors="pt", padding=True).to(
+            torch_device
+        )
 
+        torch.manual_seed(42)
         output = model.generate(**inputs, max_new_tokens=32, top_k=1)
 
         EXPECTED_DECODED_TEXT = [
-            "system\nYou are a helpful assistant.\nuser\nAudio 1: \nWhat's that sound?\nassistant\nIt is the sound of glass shattering.\nuser\nAudio 2: \nHow about this one?\nassistant\nThroat clearing.",
+            "system\nYou are a helpful assistant.\nuser\nAudio 1: \nWhat's that sound?\nassistant\nIt is the sound of glass shattering.\nuser\nAudio 2: \nHow about this one?\nassistant\nThroat clearing."
         ]
         self.assertEqual(
             self.processor.batch_decode(output, skip_special_tokens=True),