Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 31 additions & 33 deletions tests/models/qwen2_audio/test_modeling_qwen2_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_sdpa_can_dispatch_composite_models(self):
@require_torch
class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
cleanup(torch_device, gc_collect=True)
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

def tearDown(self):
Expand All @@ -206,7 +207,9 @@ def tearDown(self):
@slow
def test_small_model_integration_test_single(self):
# Let' s make sure we test the preprocessing to replace what is used
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
model = Qwen2AudioForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-Audio-7B-Instruct", device_map=torch_device, dtype=torch.float16
)

url = "https://huggingface.co/datasets/raushan-testing-hf/audio-test/resolve/main/glass-breaking-151256.mp3"
messages = [
Expand All @@ -223,47 +226,35 @@ def test_small_model_integration_test_single(self):

formatted_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

inputs = self.processor(text=formatted_prompt, audios=[raw_audio], return_tensors="pt", padding=True)
inputs = self.processor(text=formatted_prompt, audio=[raw_audio], return_tensors="pt", padding=True).to(
torch_device
)

torch.manual_seed(42)
output = model.generate(**inputs, max_new_tokens=32)

# fmt: off
EXPECTED_INPUT_IDS = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 101,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))

EXPECTED_DECODED_TEXT = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>"
+ "<|AUDIO|>" * 101
+ "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
EXPECTED_INPUT_IDS = torch.tensor(
[[151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647, *[151646] * 101 , 151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198]],
device=torch_device
)
# fmt: on
torch.testing.assert_close(inputs["input_ids"], EXPECTED_INPUT_IDS)

# fmt: off
EXPECTED_DECODED_TEXT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nAudio 1: <|audio_bos|>" + "<|AUDIO|>" * 101 + "<|audio_eos|>\nWhat's that sound?<|im_end|>\n<|im_start|>assistant\nIt is the sound of glass breaking.<|im_end|>"
# fmt: on
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=False),
EXPECTED_DECODED_TEXT,
)

# test the error when incorrect number of audio tokens
# fmt: off
inputs["input_ids"] = torch.tensor([[
151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 14755, 220, 16, 25, 220, 151647,
*[151646] * 200,
151648, 198, 3838, 594, 429, 5112, 30, 151645, 198, 151644, 77091, 198,
]])
# fmt: on
with self.assertRaisesRegex(
ValueError, "Audio features and audio tokens do not match: tokens: 200, features 101"
):
model.generate(**inputs, max_new_tokens=32)

@slow
def test_small_model_integration_test_batch(self):
# Let' s make sure we test the preprocessing to replace what is used
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
model = Qwen2AudioForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-Audio-7B-Instruct", device_map=torch_device, dtype=torch.float16
)

conversation1 = [
{
Expand Down Expand Up @@ -322,23 +313,27 @@ def test_small_model_integration_test_batch(self):
)[0]
)

inputs = self.processor(text=text, audios=audios, return_tensors="pt", padding=True)
inputs = self.processor(text=text, audio=audios, return_tensors="pt", padding=True).to(torch_device)

torch.manual_seed(42)
output = model.generate(**inputs, max_new_tokens=32)

EXPECTED_DECODED_TEXT = [
"system\nYou are a helpful assistant.\nuser\nAudio 1: \nWhat's that sound?\nassistant\nIt is the sound of glass shattering.\nuser\nAudio 2: \nWhat can you hear?\nassistant\ncough and throat clearing.",
"system\nYou are a helpful assistant.\nuser\nAudio 1: \nWhat does the person say?\nassistant\nThe original content of this audio is: 'Mister Quiller is the apostle of the middle classes and we are glad to welcome his gospel.'",
]

self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)

@slow
def test_small_model_integration_test_multiturn(self):
def test_small_model_integration_test_multiturn(self):
# Let' s make sure we test the preprocessing to replace what is used
model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")
model = Qwen2AudioForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-Audio-7B-Instruct", device_map=torch_device, dtype=torch.float16
)

messages = [
{"role": "system", "content": "You are a helpful assistant."},
Expand Down Expand Up @@ -379,12 +374,15 @@ def test_small_model_integration_test_multiturn(self):
)[0]
)

inputs = self.processor(text=formatted_prompt, audios=audios, return_tensors="pt", padding=True)
inputs = self.processor(text=formatted_prompt, audio=audios, return_tensors="pt", padding=True).to(
torch_device
)

torch.manual_seed(42)
output = model.generate(**inputs, max_new_tokens=32, top_k=1)

EXPECTED_DECODED_TEXT = [
"system\nYou are a helpful assistant.\nuser\nAudio 1: \nWhat's that sound?\nassistant\nIt is the sound of glass shattering.\nuser\nAudio 2: \nHow about this one?\nassistant\nThroat clearing.",
"system\nYou are a helpful assistant.\nuser\nAudio 1: \nWhat's that sound?\nassistant\nIt is the sound of glass shattering.\nuser\nAudio 2: \nHow about this one?\nassistant\nThroat clearing."
]
self.assertEqual(
self.processor.batch_decode(output, skip_special_tokens=True),
Expand Down