Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 181 additions & 1 deletion tests/collections/multimodal/test_speechllm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def test_speechllm_dataset(tokenizer, cuts):
)

batch = dataset[cuts]
print(batch)

expected_keys = {
"sample_ids",
Expand Down Expand Up @@ -164,3 +163,184 @@ def test_speechllm_dataset(tokenizer, cuts):
0., 0., 0., 0., 0., 0., 0., 0., 0.]])
)
# fmt: on


@pytest.fixture
def llama_tokenizer(capsys, tmp_path_factory):
    """Train a small throwaway SentencePiece BPE tokenizer for prompt-template tests.

    The Llama-style control tokens ([INST], <<SYS>>, [EOG], ...) are registered as
    user-defined symbols so they tokenize atomically. Token IDs asserted in the
    tests below depend on this exact training text -- see the warning inside it.
    """
    training_corpus = """
a b c d e f g h i j k l m n o p q r s t u v x y z
A B C D E F G H I J K L M N O P Q R S T U V X Y Z
[EOG]
Example system message.
Example user message.
Example assistant message.
TEST
[INST]
[/INST]
<s>
</s>
<<SYS>>
<</SYS>>
User: Assistant:
user model
Instruct Output
\n\n
<start_of_turn> <end_of_turn>
<|
|>
<|en|> <|de|> <|fr|> <|es|> <|transcribe|> <|translate|> <|pnc|> <|nopnc|> <|startoftranscript|> <|endoftext|>
Feel free to add new tokens for your own tests!?
But know that if you do so, you may need to update the token IDs in the existing tests!
So, it might be a good idea to create a new tokenizer instead when adding new prompt formats.
"""
    model_dir = tmp_path_factory.mktemp("bpe_tokenizer")
    corpus_file = model_dir / "text.txt"
    corpus_file.write_text(training_corpus)
    # SentencePiece training writes progress to stdout/stderr; disable pytest's
    # capture while it runs so the output is not swallowed into the report.
    with capsys.disabled():
        create_spt_model(
            str(corpus_file),
            vocab_size=512,
            sample_size=-1,
            do_lower_case=False,
            output_dir=str(model_dir),
            bos=True,
            eos=True,
            user_defined_symbols=["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>", "[EOG]"],
        )
    return SentencePieceTokenizer(str(model_dir / "tokenizer.model"))


def test_speechllm_dataset_prompt_template(llama_tokenizer, cuts):
    """End-to-end check of LhotseAudioQuestionAnswerDataset with a Llama-2 style
    ``[INST]`` / ``<<SYS>>`` prompt template.

    Asserts the exact batch structure and the exact token IDs produced for
    tokens, contexts, answers, labels, loss_mask, and position_ids. The ID
    values are pinned to the deterministic ``llama_tokenizer`` fixture -- if
    its training text changes, every expected tensor here must be regenerated.
    """
    # NOTE: the debug `print(batch)` that used to follow `dataset[cuts]` was
    # removed for consistency with test_speechllm_dataset above.
    tokenizer = llama_tokenizer
    text_processor = TextProcessing(
        tokenizer=tokenizer,
        prompt_template='[INST]\n<<SYS>>\nPlease answer the following based on the previous speech feature.\n<</SYS>>\n\n{context}[/INST] {answer}',
        context_key="context",
        answer_key="answer",
        add_eos=True,
        add_sep=False,
        add_bos=False,
        separate_prompt_and_response_with_newline=False,
        end_string="[EOG]",
    )
    dataset = LhotseAudioQuestionAnswerDataset(
        text_processor=text_processor,
        default_context="do this task",
        tokens_to_generate=128,
        pad_to_max_length=False,
        max_seq_length=128,
    )

    batch = dataset[cuts]

    # The batch must expose exactly these keys -- no extras, none missing.
    expected_keys = {
        "sample_ids",
        "audio_signal",
        "audio_signal_length",
        "audio_ratio",
        "metadata",
        "tokens",
        "tokens_length",
        "labels",
        "loss_mask",
        "position_ids",
        "contexts",
        "context_lengths",
        "max_length",
        "answers",
    }
    missing_keys = expected_keys - set(batch)
    unexpected_keys = set(batch) - expected_keys
    assert not missing_keys and not unexpected_keys, f"{missing_keys=} {unexpected_keys=}"

    assert batch["sample_ids"] == ["ex0"]
    assert batch["metadata"] == [{'audio_filepath': 'ex0.wav'}]
    torch.testing.assert_close(batch["audio_ratio"], tensor([1.0]))
    torch.testing.assert_close(batch["max_length"], tensor([128]))

    assert torch.is_tensor(batch["audio_signal"])
    assert torch.is_floating_point(batch["audio_signal"])
    assert batch["audio_signal"].shape == (1, 80000)
    torch.testing.assert_close(batch["audio_signal_length"], tensor([80000], dtype=torch.int32))

    # fmt: off
    # Full sequence: prompt + context + answer + end string; trailing 2s are padding.
    expected = tensor([[  8,   3,   8,   5,   8, 105,  18,   9,  12,  17,   9,  41,  14,  17,
         22, 125,  43,   9, 117,  19,  18,  18,  79,  48,  15,  92,  12,  17,
          9,  42,   8,  19,  14,  43,   9,  85,  21,   9, 114,  45,  19,  86,
         17,  72,  20,   9,   9,  32,  46, 117,   9, 123,  69,   9,  25,   8,
          6,   8,  93,  14,   8,  74,  88,  12,  86,  18,  13,  85,  21,  19,
         27,  13, 116,  19,  14,  13,  78,  13,   4,  72,  19,  84,   9,   8,
         65, 120,  45,  19,  14,   8,   7,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2]])
    torch.testing.assert_close(batch["tokens"], expected)
    torch.testing.assert_close(batch["tokens_length"], tensor([91]))
    assert tokenizer.ids_to_text(expected[0, :91].tolist()) == "[INST] <<SYS>> Please answer the following based on the previous speech feature. <</SYS>> non default prompt context[/INST] some transcription [EOG]"

    # Context only: everything up to and including [/INST], no answer tokens.
    expected = tensor([[  8,   3,   8,   5,   8, 105,  18,   9,  12,  17,   9,  41,  14,  17,
         22, 125,  43,   9, 117,  19,  18,  18,  79,  48,  15,  92,  12,  17,
          9,  42,   8,  19,  14,  43,   9,  85,  21,   9, 114,  45,  19,  86,
         17,  72,  20,   9,   9,  32,  46, 117,   9, 123,  69,   9,  25,   8,
          6,   8,  93,  14,   8,  74,  88,  12,  86,  18,  13,  85,  21,  19,
         27,  13, 116,  19,  14,  13,  78,  13,   4,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2]])
    torch.testing.assert_close(batch["contexts"], expected)
    torch.testing.assert_close(batch["context_lengths"], tensor([79]))
    assert tokenizer.ids_to_text(expected[0, :79].tolist()) == "[INST] <<SYS>> Please answer the following based on the previous speech feature. <</SYS>> non default prompt context[/INST]"

    # Answer only: transcription + [EOG] end string.
    expected = tensor([[ 72,  19,  84,   9,   8,  65, 120,  45,  19,  14,   8,   7,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2]])
    torch.testing.assert_close(batch["answers"], expected)
    assert tokenizer.ids_to_text(expected[0, :12].tolist()) == "some transcription [EOG]"

    # Labels: tokens shifted left by one position.
    expected = tensor([[  3,   8,   5,   8, 105,  18,   9,  12,  17,   9,  41,  14,  17,  22,
        125,  43,   9, 117,  19,  18,  18,  79,  48,  15,  92,  12,  17,   9,
         42,   8,  19,  14,  43,   9,  85,  21,   9, 114,  45,  19,  86,  17,
         72,  20,   9,   9,  32,  46, 117,   9, 123,  69,   9,  25,   8,   6,
          8,  93,  14,   8,  74,  88,  12,  86,  18,  13,  85,  21,  19,  27,
         13, 116,  19,  14,  13,  78,  13,   4,  72,  19,  84,   9,   8,  65,
        120,  45,  19,  14,   8,   7,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,
          2]])
    torch.testing.assert_close(batch["labels"], expected)
    assert tokenizer.ids_to_text(expected[0, :90].tolist()) == "[INST] <<SYS>> Please answer the following based on the previous speech feature. <</SYS>> non default prompt context[/INST] some transcription [EOG]"

    torch.testing.assert_close(
        batch["position_ids"],
        tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127]])
    )

    # Loss is computed only over the answer span (positions 78-90 of the labels).
    torch.testing.assert_close(
        batch["loss_mask"],
        tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]])
    )
    # fmt: on