From 0259f1c8ed1918626fd17f860c9f20505e21bc16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 30 Sep 2024 13:24:47 -0400 Subject: [PATCH 1/4] Draft implementation of NeMo Multimodal Conversation format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/cutset.py | 19 +++ .../common/data/lhotse/text_adapters.py | 127 ++++++++++++++++++ 2 files changed, 146 insertions(+) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index d1f8c5ba03ef..489ce1193f0e 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -29,6 +29,7 @@ from nemo.collections.common.data.lhotse.text_adapters import ( LhotseTextAdapter, LhotseTextPairAdapter, + NeMoMultimodalConversationJsonlAdapter, NeMoSFTJsonlAdapter, ) from nemo.collections.common.parts.preprocessing.manifest import get_full_path @@ -173,6 +174,9 @@ def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> [CutSet, bool]: elif grp_cfg.type == "nemo_sft_jsonl": is_tarred = True cuts = read_nemo_sft_jsonl(grp_cfg) + elif grp_cfg.type == "multimodal_conversation": + is_tarred = True + cuts = read_multimodal_conversation_jsonl(grp_cfg) elif grp_cfg.type == "group": cuts, is_tarred = parse_and_combine_datasets( grp_cfg.input_cfg, @@ -223,6 +227,21 @@ def read_nemo_sft_jsonl(config: DictConfig) -> CutSet: ).repeat() +def read_multimodal_conversation_jsonl(config: DictConfig) -> CutSet: + cuts = CutSet( + NeMoMultimodalConversationJsonlAdapter( + manifest_filepath=config.manifest_filepath, + tarred_audio_filepaths=config.tarred_audio_filepaths, + audio_locator_tag=config.audio_locator_tag, + shuffle_shards=config.shuffle, + shard_seed=config.shard_seed, + ) + ) + if not config.get("force_finite", False): + cuts = cuts.repeat() + return cuts + + def attach_tags(cut, tags: dict): for key, val in tags.items(): setattr(cut, key, val) diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 3d1138d427f2..cc3e27eb3338 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -20,11 +20,14 @@ import numpy as np import torch +from lhotse import Recording +from lhotse.cut import Cut from lhotse.dataset.dataloading import resolve_seed from lhotse.serialization import load_jsonl from lhotse.utils import Pathlike from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths +from nemo.collections.common.prompts import PromptFormatter from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer, TokenizerWrapper from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging @@ -306,6 +309,130 @@ def __iter__(self) -> Iterator[NeMoSFTExample]: yield NeMoSFTExample(data, language=self.language) +@dataclass +class TextTurn: + value: str + role: str + + +@dataclass +class AudioTurn: + cut: Cut + role: str + audio_locator_tag: str + + +@dataclass +class NeMoMultimodalConversation: + id: str + turns: list[TextTurn | AudioTurn] + input_ids: np.ndarray | None = None + context_ids: np.ndarray | None = None + answer_ids: np.ndarray | None = None + mask: np.ndarray | None = None + + def tokenize( + self, + tokenizer: TokenizerWrapper | TokenizerSpec, + prompt: PromptFormatter = None, + ) -> 
"NeMoMultimodalConversation": + """ + Create a tokenized variant of this example given a tokenizer (i.e. fill the optional fields). + Supports BPE tokenizers and aggregate tokenizers. + + The tokenization is compatible with Megatron's :class:`GPTSFTChatDataset`. + """ + if isinstance(tokenizer, TokenizerWrapper): + tokenizer = tokenizer._tokenizer + if isinstance(tokenizer, AggregateTokenizer): + raise NotImplementedError("NeMoMultimodalConversation does not support AggregateTokenizer yet.") + if prompt is None: + prompt = PromptFormatter.resolve("plain")(tokenizer) + + ans = prompt.encode_dialog( + [ + { + "role": turn.role, + "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_token}, + } + for turn in self.turns + ] + ) + self.input_ids = ans["input_ids"] + self.context_ids = ans["context_ids"] + self.answer_ids = ans["answer_ids"] + self.mask = ans["mask"] + + return self + + +@dataclass +class NeMoMultimodalConversationJsonlAdapter: + """ + ``NeMoMultimodalConversationJsonlAdapter`` is used to read a NeMo multimodal conversation JSONL + and yield objects of type ``NeMoMultimodalConversation`` that can be sampled with Lhotse. + + We expect the following schema (contained in a single line per example):: + + { + "id": str, + "conversations": [ + { + "value": str, # text message or path to audio + "from": "User" | "Assistant", + "type": "text" | "audio", + "duration": float, # only for audio + }, + ... + ], + } + """ + + manifest_filepath: str | list[str] + audio_locator_tag: str + tarred_audio_filepaths: str | list[str] = None + shuffle_shards: bool = False + shard_seed: Union[int, Literal["trng", "randomized"]] = "trng" + + def __post_init__(self): + self.manifest_filepath = expand_sharded_filepaths(self.manifest_filepath) + if self.tarred_audio_filepaths is not None: + raise NotImplementedError( + "Tarred manifests are currently not supported yet for NeMoMultimodalConversation." + ) + self.tarred_audio_filepaths = expand_sharded_filepaths(self.tarred_audio_filepaths) + + def __iter__(self) -> Iterator[NeMoMultimodalConversation]: + paths = self.manifest_filepath + if self.shuffle_shards: + seed = resolve_seed(self.shard_seed) + random.Random(seed).shuffle(paths) + for path in paths: + for data in load_jsonl(path): + yield NeMoMultimodalConversation( + id=data["id"], + turns=[ + ( + TextTurn( + value=turn["value"], + role=turn[ + "from" + ].lower(), # prompt formatter role's are typically lowercase: user/assistant + ) + if turn["type"] == "text" + else AudioTurn( + cut=Recording.from_file(turn["value"]).to_cut(), + role=turn[ + "from" + ].lower(), # prompt formatter role's are typically lowercase: user/assistant + audio_locator_tag=self.audio_locator_tag, + ) + ) + for turn in data["conversations"] + ], + ) + + """ The code below is copied from nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py with minimal modifications in order to avoid importing the NLP collection. 
From 1833c445d15dd5884723dab1149b4841e9b4b22f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Tue, 1 Oct 2024 10:31:29 -0400
Subject: [PATCH 2/4] Fully working data parsing and iteration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko

---
 nemo/collections/common/data/lhotse/cutset.py |  14 ++-
 .../common/data/lhotse/text_adapters.py       |   3 +-
 .../test_lhotse_multimodal_dataloading.py     | 118 ++++++++++++++++++
 3 files changed, 132 insertions(+), 3 deletions(-)
 create mode 100644 tests/collections/common/test_lhotse_multimodal_dataloading.py

diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py
index 489ce1193f0e..24c0ffaf59b7 100644
--- a/nemo/collections/common/data/lhotse/cutset.py
+++ b/nemo/collections/common/data/lhotse/cutset.py
@@ -63,7 +63,17 @@ def read_cutset_from_config(config: DictConfig) -> Tuple[CutSet, bool]:
 
 
 KNOWN_DATASET_CONFIG_TYPES = frozenset(
-    ("nemo", "nemo_tarred", "lhotse", "lhotse_shar", "txt", "txt_pair", "nemo_sft_jsonl", "group")
+    (
+        "nemo",
+        "nemo_tarred",
+        "lhotse",
+        "lhotse_shar",
+        "txt",
+        "txt_pair",
+        "nemo_sft_jsonl",
+        "multimodal_conversation",
+        "group",
+    )
 )
 
 
@@ -231,7 +241,7 @@ def read_multimodal_conversation_jsonl(config: DictConfig) -> CutSet:
     cuts = CutSet(
         NeMoMultimodalConversationJsonlAdapter(
             manifest_filepath=config.manifest_filepath,
-            tarred_audio_filepaths=config.tarred_audio_filepaths,
+            tarred_audio_filepaths=config.get("tarred_audio_filepaths"),
             audio_locator_tag=config.audio_locator_tag,
             shuffle_shards=config.shuffle,
             shard_seed=config.shard_seed,
diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py
index cc3e27eb3338..ce697b45c282 100644
--- a/nemo/collections/common/data/lhotse/text_adapters.py
+++ b/nemo/collections/common/data/lhotse/text_adapters.py
@@ -27,6 +27,7 @@
 from lhotse.utils import Pathlike
 
 from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths
+from nemo.collections.common.parts.preprocessing.manifest import get_full_path
 from nemo.collections.common.prompts import PromptFormatter
 from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer, TokenizerWrapper
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -421,7 +422,7 @@ def __iter__(self) -> Iterator[NeMoMultimodalConversation]:
                         )
                         if turn["type"] == "text"
                         else AudioTurn(
-                            cut=Recording.from_file(turn["value"]).to_cut(),
+                            cut=Recording.from_file(get_full_path(turn["value"], path)).to_cut(),
                             role=turn[
                                 "from"
                             ].lower(),  # prompt formatter roles are typically lowercase: user/assistant
diff --git a/tests/collections/common/test_lhotse_multimodal_dataloading.py b/tests/collections/common/test_lhotse_multimodal_dataloading.py
new file mode 100644
index 000000000000..71bea44d09a6
--- /dev/null
+++ b/tests/collections/common/test_lhotse_multimodal_dataloading.py
@@ -0,0 +1,118 @@
+import json
+from itertools import islice
+
+import lhotse
+import pytest
+import torch
+from lhotse.testing.dummies import dummy_cut, dummy_recording
+from omegaconf import OmegaConf
+
+from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
+from nemo.collections.common.data.lhotse.text_adapters import AudioTurn, NeMoMultimodalConversation, TextTurn
+
+
+class Identity(torch.utils.data.Dataset):
+    def __getitem__(self, cuts: lhotse.CutSet) -> lhotse.CutSet:
+        return 
cuts + + +@pytest.fixture(scope="session") +def multimodal_conversations_path(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("text_data") + en_path = tmp_path / "manifest.json" + data = [ + { + "id": "convo_1", + "conversations": [ + { + "value": "Can you help summarize the following?", + "from": "User", + "type": "text", + }, + { + "value": "123.wav", + "from": "User", + "type": "audio", + "duration": 5.73, + }, + { + "value": "I'm glad to assist you with your request. Here's a summary:", + "from": "Assistant", + "type": "text", + }, + { + "value": "123_answer.wav", + "from": "Assistant", + "type": "audio", + "duration": 7.11, + }, + { + "value": "Can you further shorten it?", + "from": "User", + "type": "text", + }, + ], + } + ] + lhotse.serialization.save_to_jsonl(data, en_path) + dummy_recording(0, 5.73, with_data=True).to_cut().save_audio(tmp_path / "123.wav") + dummy_recording(0, 7.11, with_data=True).to_cut().save_audio(tmp_path / "123_answer.wav") + return en_path + + +def test_multimodal_conversation_input(multimodal_conversations_path): + + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "multimodal_conversation", + "manifest_filepath": multimodal_conversations_path, + "audio_locator_tag": "[audio]", + }, + ], + "force_finite": True, + "shuffle": True, + "num_workers": 0, + "batch_size": 1, + "seed": 0, + "shard_seed": 0, + } + ) + + # Note: this test does not need to pass a tokenizer because we use static batch sizes + dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity()) + batches = [batch for batch in dl] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, lhotse.CutSet) + assert len(b) == 1 + ex = b[0] + assert isinstance(ex, NeMoMultimodalConversation) + assert ex.id == "convo_1" + assert len(ex.turns) == 5 + t = ex.turns[0] + assert isinstance(t, TextTurn) + assert t.role == "user" + assert t.value == "Can you help summarize the following?" + t = ex.turns[1] + assert isinstance(t, AudioTurn) + assert t.role == "user" + assert t.audio_locator_tag == "[audio]" + assert t.cut.duration == 5.73 + assert t.cut.load_audio().shape == (1, 91680) + t = ex.turns[2] + assert isinstance(t, TextTurn) + assert t.role == "assistant" + assert t.value == "I'm glad to assist you with your request. Here's a summary:" + t = ex.turns[3] + assert isinstance(t, AudioTurn) + assert t.role == "assistant" + assert t.audio_locator_tag == "[audio]" + assert t.cut.duration == 7.11 + assert t.cut.load_audio().shape == (1, 113760) + t = ex.turns[4] + assert isinstance(t, TextTurn) + assert t.role == "user" + assert t.value == "Can you further shorten it?" 
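
The one-line change to the adapter above is easy to miss but important: audio paths now go through get_full_path, so a relative path inside a manifest resolves against the manifest's own directory rather than the process's working directory. A small sketch of the behaviour being relied upon; the paths are illustrative, and the absolute-path case is an assumption based on how this helper is used elsewhere in NeMo:

    from nemo.collections.common.parts.preprocessing.manifest import get_full_path

    # A relative audio path resolves against the directory that contains the manifest,
    # which keeps a manifest portable when it ships together with its audio files.
    print(get_full_path("123.wav", "/data/convos/manifest.jsonl"))
    # expected: /data/convos/123.wav

    # An already-absolute path should come back unchanged.
    print(get_full_path("/mnt/audio/123.wav", "/data/convos/manifest.jsonl"))
    # expected: /mnt/audio/123.wav

This is also what lets the test fixture above reference bare file names like "123.wav" next to the manifest.
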
From 6ab340589f543ff91345551d1c26a76ff2a368ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 1 Oct 2024 10:50:53 -0400 Subject: [PATCH 3/4] Fully working dataloading with tokenization + prompting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 9 +- .../common/data/lhotse/text_adapters.py | 4 +- .../test_lhotse_multimodal_dataloading.py | 98 ++++++++++++++++++- 3 files changed, 108 insertions(+), 3 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 7c42767fd7b3..14379c328ab5 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -41,7 +41,12 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset, read_cutset_from_config -from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample, TextExample +from nemo.collections.common.data.lhotse.text_adapters import ( + NeMoMultimodalConversation, + NeMoSFTExample, + SourceTargetTextExample, + TextExample, +) from nemo.collections.common.prompts.fn import get_prompt_format_fn from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper from nemo.utils import logging @@ -736,6 +741,8 @@ def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str) -> Exa example.tokenized_prompted_transcript = tokenized_prompted_transcript example.tokenized_prompt = tokenized_prompt example.tokenized_transcript = tokenized_transcript + elif isinstance(example, NeMoMultimodalConversation): + example = example.tokenize(tokenizer, prompt_format) else: raise RuntimeError(f"Currently we only support tokenization + prompting during sampling for audio modality.") return example diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index ce697b45c282..74cc4217c17a 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -349,12 +349,14 @@ def tokenize( raise NotImplementedError("NeMoMultimodalConversation does not support AggregateTokenizer yet.") if prompt is None: prompt = PromptFormatter.resolve("plain")(tokenizer) + elif isinstance(prompt, str): + prompt = PromptFormatter.resolve(prompt)(tokenizer) ans = prompt.encode_dialog( [ { "role": turn.role, - "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_token}, + "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag}, } for turn in self.turns ] diff --git a/tests/collections/common/test_lhotse_multimodal_dataloading.py b/tests/collections/common/test_lhotse_multimodal_dataloading.py index 71bea44d09a6..2e0c1f0d5ce3 100644 --- a/tests/collections/common/test_lhotse_multimodal_dataloading.py +++ b/tests/collections/common/test_lhotse_multimodal_dataloading.py @@ -9,6 +9,8 @@ from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.data.lhotse.text_adapters import AudioTurn, NeMoMultimodalConversation, TextTurn +from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model class 
Identity(torch.utils.data.Dataset):
@@ -51,6 +53,11 @@ def multimodal_conversations_path(tmp_path_factory):
                     "from": "User",
                     "type": "text",
                 },
+                {
+                    "value": "Of course!",
+                    "from": "Assistant",
+                    "type": "text",
+                },
             ],
         }
     ]
@@ -91,7 +98,7 @@ def test_multimodal_conversation_input(multimodal_conversations_path):
     ex = b[0]
     assert isinstance(ex, NeMoMultimodalConversation)
     assert ex.id == "convo_1"
-    assert len(ex.turns) == 5
+    assert len(ex.turns) == 6
     t = ex.turns[0]
     assert isinstance(t, TextTurn)
     assert t.role == "user"
@@ -116,3 +123,92 @@ def test_multimodal_conversation_input(multimodal_conversations_path):
     assert isinstance(t, TextTurn)
     assert t.role == "user"
     assert t.value == "Can you further shorten it?"
+    t = ex.turns[5]
+    assert isinstance(t, TextTurn)
+    assert t.role == "assistant"
+    assert t.value == "Of course!"
+    for key in ("input_ids", "context_ids", "answer_ids", "mask"):
+        assert getattr(ex, key) is None  # not tokenized/prompted
+
+
+@pytest.fixture
+def tokenizer(tmp_path_factory, multimodal_conversations_path):
+    tmpdir = tmp_path_factory.mktemp("multi_convo_tokenizer")
+    text_path = tmpdir / "text.txt"
+    text_path.write_text(
+        "\n".join(
+            turn["value"]
+            for item in lhotse.serialization.load_jsonl(multimodal_conversations_path)
+            for turn in item["conversations"]
+        )
+    )
+    create_spt_model(
+        text_path,
+        vocab_size=128,
+        sample_size=-1,
+        do_lower_case=False,
+        output_dir=str(tmpdir),
+        bos=True,
+        eos=True,
+        user_defined_symbols=["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>", "[audio]"],
+    )
+    return SentencePieceTokenizer(str(tmpdir / "tokenizer.model"))
+
+
+def test_multimodal_conversation_input_with_prompt(multimodal_conversations_path, tokenizer):
+
+    config = OmegaConf.create(
+        {
+            "input_cfg": [
+                {
+                    "type": "multimodal_conversation",
+                    "manifest_filepath": multimodal_conversations_path,
+                    "audio_locator_tag": "[audio]",
+                },
+            ],
+            "prompt_format": "llama2",
+            "force_finite": True,
+            "shuffle": True,
+            "num_workers": 0,
+            "batch_size": 1,
+            "seed": 0,
+            "shard_seed": 0,
+        }
+    )
+
+    dl = get_lhotse_dataloader_from_config(
+        config=config, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer
+    )
+    batches = [batch for batch in dl]
+    assert len(batches) == 1
+
+    b = batches[0]
+    assert isinstance(b, lhotse.CutSet)
+    assert len(b) == 1
+    ex = b[0]
+    assert isinstance(ex, NeMoMultimodalConversation)
+
+    assert torch.is_tensor(ex.input_ids)
+    assert ex.input_ids.shape == (111,)
+    assert (
+        tokenizer.ids_to_text(ex.input_ids)
+        == "[INST] Can you help summarize the following? [/INST] [INST] [audio] [/INST] I'm glad to assist you with your request. Here's a summary: [audio] [INST] Can you further shorten it? [/INST] Of course!"
+    )
+
+    assert torch.is_tensor(ex.context_ids)
+    assert ex.context_ids.shape == (101,)
+    assert (
+        tokenizer.ids_to_text(ex.context_ids)
+        == "[INST] Can you help summarize the following? [/INST] [INST] [audio] [/INST] I'm glad to assist you with your request. Here's a summary: [audio] [INST] Can you further shorten it? [/INST]"
+    )
+
+    assert torch.is_tensor(ex.answer_ids)
+    assert ex.answer_ids.shape == (10,)
+    assert tokenizer.ids_to_text(ex.answer_ids) == "Of course!" 
+
+    assert torch.is_tensor(ex.mask)
+    assert ex.mask.shape == (111,)
+    assert (ex.mask[:35] == False).all()  # user turn
+    assert (ex.mask[35:78] == True).all()  # assistant turn
+    assert (ex.mask[78:101] == False).all()  # user turn
+    assert (ex.mask[101:] == True).all()  # assistant turn

From a5a21f40423a114510743454f80391ba1f826d4f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20=C5=BBelasko?=
Date: Tue, 1 Oct 2024 15:28:27 -0400
Subject: [PATCH 4/4] Collapse consecutive same-role turns into a single turn
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Piotr Żelasko

---
 .../common/data/lhotse/text_adapters.py       | 12 ++++++++++--
 .../test_lhotse_multimodal_dataloading.py     | 18 +++++++++---------
 2 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py
index 74cc4217c17a..79c54b89bc0d 100644
--- a/nemo/collections/common/data/lhotse/text_adapters.py
+++ b/nemo/collections/common/data/lhotse/text_adapters.py
@@ -15,6 +15,7 @@
 import copy
 import random
 from dataclasses import dataclass
+from itertools import groupby
 from pathlib import Path
 from typing import Iterator, Literal, Optional, Union
 
@@ -352,15 +353,22 @@ def tokenize(
         elif isinstance(prompt, str):
             prompt = PromptFormatter.resolve(prompt)(tokenizer)
 
-        ans = prompt.encode_dialog(
+        # Collapse consecutive same-role turns into a single turn for proper prompt formatting.
+        turns = groupby(
             [
                 {
                     "role": turn.role,
                     "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag},
                 }
                 for turn in self.turns
-            ]
+            ],
+            key=lambda turn: turn["role"],
         )
+        turns = [
+            {"role": role, "slots": {"message": " ".join(t["slots"]["message"] for t in turn_grp)}}
+            for role, turn_grp in turns
+        ]
+        ans = prompt.encode_dialog(turns)
         self.input_ids = ans["input_ids"]
         self.context_ids = ans["context_ids"]
         self.answer_ids = ans["answer_ids"]
diff --git a/tests/collections/common/test_lhotse_multimodal_dataloading.py b/tests/collections/common/test_lhotse_multimodal_dataloading.py
index 2e0c1f0d5ce3..4ded7c25d12a 100644
--- a/tests/collections/common/test_lhotse_multimodal_dataloading.py
+++ b/tests/collections/common/test_lhotse_multimodal_dataloading.py
@@ -189,17 +189,17 @@ def test_multimodal_conversation_input_with_prompt(multimodal_conversations_path
     assert isinstance(ex, NeMoMultimodalConversation)
 
     assert torch.is_tensor(ex.input_ids)
-    assert ex.input_ids.shape == (111,)
+    assert ex.input_ids.shape == (105,)
     assert (
         tokenizer.ids_to_text(ex.input_ids)
-        == "[INST] Can you help summarize the following? [/INST] [INST] [audio] [/INST] I'm glad to assist you with your request. Here's a summary: [audio] [INST] Can you further shorten it? [/INST] Of course!"
+        == "[INST] Can you help summarize the following? [audio] [/INST] I'm glad to assist you with your request. Here's a summary: [audio] [INST] Can you further shorten it? [/INST] Of course!"
     )
 
     assert torch.is_tensor(ex.context_ids)
-    assert ex.context_ids.shape == (101,)
+    assert ex.context_ids.shape == (95,)
     assert (
         tokenizer.ids_to_text(ex.context_ids)
-        == "[INST] Can you help summarize the following? [/INST] [INST] [audio] [/INST] I'm glad to assist you with your request. Here's a summary: [audio] [INST] Can you further shorten it? [/INST]"
+        == "[INST] Can you help summarize the following? [audio] [/INST] I'm glad to assist you with your request. Here's a summary: [audio] [INST] Can you further shorten it? 
[/INST]" ) assert torch.is_tensor(ex.answer_ids) @@ -207,8 +207,8 @@ def test_multimodal_conversation_input_with_prompt(multimodal_conversations_path assert tokenizer.ids_to_text(ex.answer_ids) == "Of course!" assert torch.is_tensor(ex.mask) - assert ex.mask.shape == (111,) - assert (ex.mask[:35] == False).all() # user turn - assert (ex.mask[35:78] == True).all() # assistant turn - assert (ex.mask[78:101] == False).all() # user turn - assert (ex.mask[101:] == True).all() # assistant turn + assert ex.mask.shape == (105,) + assert (ex.mask[:30] == False).all() # user turn + assert (ex.mask[30:72] == True).all() # assistant turn + assert (ex.mask[72:95] == False).all() # user turn + assert (ex.mask[95:] == True).all() # assistant turn