diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py
index d1f8c5ba03ef..24c0ffaf59b7 100644
--- a/nemo/collections/common/data/lhotse/cutset.py
+++ b/nemo/collections/common/data/lhotse/cutset.py
@@ -29,6 +29,7 @@
 from nemo.collections.common.data.lhotse.text_adapters import (
     LhotseTextAdapter,
     LhotseTextPairAdapter,
+    NeMoMultimodalConversationJsonlAdapter,
     NeMoSFTJsonlAdapter,
 )
 from nemo.collections.common.parts.preprocessing.manifest import get_full_path
@@ -62,7 +63,17 @@ def read_cutset_from_config(config: DictConfig) -> Tuple[CutSet, bool]:
 
 
 KNOWN_DATASET_CONFIG_TYPES = frozenset(
-    ("nemo", "nemo_tarred", "lhotse", "lhotse_shar", "txt", "txt_pair", "nemo_sft_jsonl", "group")
+    (
+        "nemo",
+        "nemo_tarred",
+        "lhotse",
+        "lhotse_shar",
+        "txt",
+        "txt_pair",
+        "nemo_sft_jsonl",
+        "multimodal_conversation",
+        "group",
+    )
 )
 
 
@@ -173,6 +184,9 @@ def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> [CutSet, bool]:
     elif grp_cfg.type == "nemo_sft_jsonl":
         is_tarred = True
         cuts = read_nemo_sft_jsonl(grp_cfg)
+    elif grp_cfg.type == "multimodal_conversation":
+        is_tarred = True
+        cuts = read_multimodal_conversation_jsonl(grp_cfg)
     elif grp_cfg.type == "group":
         cuts, is_tarred = parse_and_combine_datasets(
             grp_cfg.input_cfg,
@@ -223,6 +237,21 @@ def read_nemo_sft_jsonl(config: DictConfig) -> CutSet:
     ).repeat()
+
+
+def read_multimodal_conversation_jsonl(config: DictConfig) -> CutSet:
+    cuts = CutSet(
+        NeMoMultimodalConversationJsonlAdapter(
+            manifest_filepath=config.manifest_filepath,
+            tarred_audio_filepaths=config.get("tarred_audio_filepaths"),
+            audio_locator_tag=config.audio_locator_tag,
+            shuffle_shards=config.shuffle,
+            shard_seed=config.shard_seed,
+        )
+    )
+    if not config.get("force_finite", False):
+        cuts = cuts.repeat()
+    return cuts
 
 
 def attach_tags(cut, tags: dict):
     for key, val in tags.items():
         setattr(cut, key, val)
diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py
index 7c42767fd7b3..14379c328ab5 100644
--- a/nemo/collections/common/data/lhotse/dataloader.py
+++ b/nemo/collections/common/data/lhotse/dataloader.py
@@ -41,7 +41,12 @@
 from omegaconf import DictConfig, ListConfig, OmegaConf
 
 from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset, read_cutset_from_config
-from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample, TextExample
+from nemo.collections.common.data.lhotse.text_adapters import (
+    NeMoMultimodalConversation,
+    NeMoSFTExample,
+    SourceTargetTextExample,
+    TextExample,
+)
 from nemo.collections.common.prompts.fn import get_prompt_format_fn
 from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper
 from nemo.utils import logging
@@ -736,6 +741,8 @@ def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str) -> Exa
         example.tokenized_prompted_transcript = tokenized_prompted_transcript
         example.tokenized_prompt = tokenized_prompt
         example.tokenized_transcript = tokenized_transcript
+    elif isinstance(example, NeMoMultimodalConversation):
+        example = example.tokenize(tokenizer, prompt_format)
     else:
         raise RuntimeError(f"Currently we only support tokenization + prompting during sampling for audio modality.")
     return example
diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py
index 3d1138d427f2..79c54b89bc0d 100644
--- a/nemo/collections/common/data/lhotse/text_adapters.py
+++ b/nemo/collections/common/data/lhotse/text_adapters.py
@@ -15,16 +15,21 @@
 import copy
 import random
 from dataclasses import dataclass
+from itertools import groupby
 from pathlib import Path
 from typing import Iterator, Literal, Optional, Union
 
 import numpy as np
 import torch
+from lhotse import Recording
+from lhotse.cut import Cut
 from lhotse.dataset.dataloading import resolve_seed
 from lhotse.serialization import load_jsonl
 from lhotse.utils import Pathlike
 
 from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths
+from nemo.collections.common.parts.preprocessing.manifest import get_full_path
+from nemo.collections.common.prompts import PromptFormatter
 from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer, TokenizerWrapper
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.utils import logging
@@ -306,6 +311,139 @@ def __iter__(self) -> Iterator[NeMoSFTExample]:
                 yield NeMoSFTExample(data, language=self.language)
 
 
+@dataclass
+class TextTurn:
+    value: str
+    role: str
+
+
+@dataclass
+class AudioTurn:
+    cut: Cut
+    role: str
+    audio_locator_tag: str
+
+
+@dataclass
+class NeMoMultimodalConversation:
+    id: str
+    turns: list[TextTurn | AudioTurn]
+    input_ids: np.ndarray | None = None
+    context_ids: np.ndarray | None = None
+    answer_ids: np.ndarray | None = None
+    mask: np.ndarray | None = None
+
+    def tokenize(
+        self,
+        tokenizer: TokenizerWrapper | TokenizerSpec,
+        prompt: PromptFormatter = None,
+    ) -> "NeMoMultimodalConversation":
+        """
+        Create a tokenized variant of this example given a tokenizer (i.e. fill the optional fields).
+        Supports BPE tokenizers; aggregate tokenizers are not supported yet.
+
+        The tokenization is compatible with Megatron's :class:`GPTSFTChatDataset`.
+        """
+        if isinstance(tokenizer, TokenizerWrapper):
+            tokenizer = tokenizer._tokenizer
+        if isinstance(tokenizer, AggregateTokenizer):
+            raise NotImplementedError("NeMoMultimodalConversation does not support AggregateTokenizer yet.")
+        if prompt is None:
+            prompt = PromptFormatter.resolve("plain")(tokenizer)
+        elif isinstance(prompt, str):
+            prompt = PromptFormatter.resolve(prompt)(tokenizer)
+
+        # Collapse consecutive same-role turns into a single turn for proper prompt formatting.
+        turns = groupby(
+            [
+                {
+                    "role": turn.role,
+                    "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag},
+                }
+                for turn in self.turns
+            ],
+            key=lambda turn: turn["role"],
+        )
+        turns = [
+            {"role": role, "slots": {"message": " ".join(t["slots"]["message"] for t in turn_grp)}}
+            for role, turn_grp in turns
+        ]
+        ans = prompt.encode_dialog(turns)
+        self.input_ids = ans["input_ids"]
+        self.context_ids = ans["context_ids"]
+        self.answer_ids = ans["answer_ids"]
+        self.mask = ans["mask"]
+
+        return self
+
+
+@dataclass
+class NeMoMultimodalConversationJsonlAdapter:
+    """
+    ``NeMoMultimodalConversationJsonlAdapter`` is used to read a NeMo multimodal conversation JSONL
+    and yield objects of type ``NeMoMultimodalConversation`` that can be sampled with Lhotse.
+
+    We expect the following schema (contained in a single line per example)::
+
+        {
+            "id": str,
+            "conversations": [
+                {
+                    "value": str,  # text message or path to audio
+                    "from": "User" | "Assistant",
+                    "type": "text" | "audio",
+                    "duration": float,  # only for audio
+                },
+                ...
+ ], + } + """ + + manifest_filepath: str | list[str] + audio_locator_tag: str + tarred_audio_filepaths: str | list[str] = None + shuffle_shards: bool = False + shard_seed: Union[int, Literal["trng", "randomized"]] = "trng" + + def __post_init__(self): + self.manifest_filepath = expand_sharded_filepaths(self.manifest_filepath) + if self.tarred_audio_filepaths is not None: + raise NotImplementedError( + "Tarred manifests are currently not supported yet for NeMoMultimodalConversation." + ) + self.tarred_audio_filepaths = expand_sharded_filepaths(self.tarred_audio_filepaths) + + def __iter__(self) -> Iterator[NeMoMultimodalConversation]: + paths = self.manifest_filepath + if self.shuffle_shards: + seed = resolve_seed(self.shard_seed) + random.Random(seed).shuffle(paths) + for path in paths: + for data in load_jsonl(path): + yield NeMoMultimodalConversation( + id=data["id"], + turns=[ + ( + TextTurn( + value=turn["value"], + role=turn[ + "from" + ].lower(), # prompt formatter role's are typically lowercase: user/assistant + ) + if turn["type"] == "text" + else AudioTurn( + cut=Recording.from_file(get_full_path(turn["value"], path)).to_cut(), + role=turn[ + "from" + ].lower(), # prompt formatter role's are typically lowercase: user/assistant + audio_locator_tag=self.audio_locator_tag, + ) + ) + for turn in data["conversations"] + ], + ) + + """ The code below is copied from nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py with minimal modifications in order to avoid importing the NLP collection. diff --git a/tests/collections/common/test_lhotse_multimodal_dataloading.py b/tests/collections/common/test_lhotse_multimodal_dataloading.py new file mode 100644 index 000000000000..4ded7c25d12a --- /dev/null +++ b/tests/collections/common/test_lhotse_multimodal_dataloading.py @@ -0,0 +1,214 @@ +import json +from itertools import islice + +import lhotse +import pytest +import torch +from lhotse.testing.dummies import dummy_cut, dummy_recording +from omegaconf import OmegaConf + +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.data.lhotse.text_adapters import AudioTurn, NeMoMultimodalConversation, TextTurn +from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model + + +class Identity(torch.utils.data.Dataset): + def __getitem__(self, cuts: lhotse.CutSet) -> lhotse.CutSet: + return cuts + + +@pytest.fixture(scope="session") +def multimodal_conversations_path(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("text_data") + en_path = tmp_path / "manifest.json" + data = [ + { + "id": "convo_1", + "conversations": [ + { + "value": "Can you help summarize the following?", + "from": "User", + "type": "text", + }, + { + "value": "123.wav", + "from": "User", + "type": "audio", + "duration": 5.73, + }, + { + "value": "I'm glad to assist you with your request. 
Here's a summary:", + "from": "Assistant", + "type": "text", + }, + { + "value": "123_answer.wav", + "from": "Assistant", + "type": "audio", + "duration": 7.11, + }, + { + "value": "Can you further shorten it?", + "from": "User", + "type": "text", + }, + { + "value": "Of course!", + "from": "Assistant", + "type": "text", + }, + ], + } + ] + lhotse.serialization.save_to_jsonl(data, en_path) + dummy_recording(0, 5.73, with_data=True).to_cut().save_audio(tmp_path / "123.wav") + dummy_recording(0, 7.11, with_data=True).to_cut().save_audio(tmp_path / "123_answer.wav") + return en_path + + +def test_multimodal_conversation_input(multimodal_conversations_path): + + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "multimodal_conversation", + "manifest_filepath": multimodal_conversations_path, + "audio_locator_tag": "[audio]", + }, + ], + "force_finite": True, + "shuffle": True, + "num_workers": 0, + "batch_size": 1, + "seed": 0, + "shard_seed": 0, + } + ) + + # Note: this test does not need to pass a tokenizer because we use static batch sizes + dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity()) + batches = [batch for batch in dl] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, lhotse.CutSet) + assert len(b) == 1 + ex = b[0] + assert isinstance(ex, NeMoMultimodalConversation) + assert ex.id == "convo_1" + assert len(ex.turns) == 6 + t = ex.turns[0] + assert isinstance(t, TextTurn) + assert t.role == "user" + assert t.value == "Can you help summarize the following?" + t = ex.turns[1] + assert isinstance(t, AudioTurn) + assert t.role == "user" + assert t.audio_locator_tag == "[audio]" + assert t.cut.duration == 5.73 + assert t.cut.load_audio().shape == (1, 91680) + t = ex.turns[2] + assert isinstance(t, TextTurn) + assert t.role == "assistant" + assert t.value == "I'm glad to assist you with your request. Here's a summary:" + t = ex.turns[3] + assert isinstance(t, AudioTurn) + assert t.role == "assistant" + assert t.audio_locator_tag == "[audio]" + assert t.cut.duration == 7.11 + assert t.cut.load_audio().shape == (1, 113760) + t = ex.turns[4] + assert isinstance(t, TextTurn) + assert t.role == "user" + assert t.value == "Can you further shorten it?" + t = ex.turns[5] + assert isinstance(t, TextTurn) + assert t.role == "assistant" + assert t.value == "Of course!" 
+    for key in ("input_ids", "context_ids", "answer_ids", "mask"):
+        assert getattr(ex, key) is None  # not tokenized/prompted
+
+
+@pytest.fixture
+def tokenizer(tmp_path_factory, multimodal_conversations_path):
+    tmpdir = tmp_path_factory.mktemp("multi_convo_tokenizer")
+    text_path = tmpdir / "text.txt"
+    text_path.write_text(
+        "\n".join(
+            turn["value"]
+            for item in lhotse.serialization.load_jsonl(multimodal_conversations_path)
+            for turn in item["conversations"]
+        )
+    )
+    create_spt_model(
+        text_path,
+        vocab_size=128,
+        sample_size=-1,
+        do_lower_case=False,
+        output_dir=str(tmpdir),
+        bos=True,
+        eos=True,
+        user_defined_symbols=["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>", "[audio]"],
+    )
+    return SentencePieceTokenizer(str(tmpdir / "tokenizer.model"))
+
+
+def test_multimodal_conversation_input_with_prompt(multimodal_conversations_path, tokenizer):
+
+    config = OmegaConf.create(
+        {
+            "input_cfg": [
+                {
+                    "type": "multimodal_conversation",
+                    "manifest_filepath": multimodal_conversations_path,
+                    "audio_locator_tag": "[audio]",
+                },
+            ],
+            "prompt_format": "llama2",
+            "force_finite": True,
+            "shuffle": True,
+            "num_workers": 0,
+            "batch_size": 1,
+            "seed": 0,
+            "shard_seed": 0,
+        }
+    )
+
+    dl = get_lhotse_dataloader_from_config(
+        config=config, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer
+    )
+    batches = [batch for batch in dl]
+    assert len(batches) == 1
+
+    b = batches[0]
+    assert isinstance(b, lhotse.CutSet)
+    assert len(b) == 1
+    ex = b[0]
+    assert isinstance(ex, NeMoMultimodalConversation)
+
+    assert torch.is_tensor(ex.input_ids)
+    assert ex.input_ids.shape == (105,)
+    assert (
+        tokenizer.ids_to_text(ex.input_ids)
+        == "[INST] Can you help summarize the following? [audio] [/INST] I'm glad to assist you with your request. Here's a summary: [audio] [INST] Can you further shorten it? [/INST] Of course!"
+    )
+
+    assert torch.is_tensor(ex.context_ids)
+    assert ex.context_ids.shape == (95,)
+    assert (
+        tokenizer.ids_to_text(ex.context_ids)
+        == "[INST] Can you help summarize the following? [audio] [/INST] I'm glad to assist you with your request. Here's a summary: [audio] [INST] Can you further shorten it? [/INST]"
+    )
+
+    assert torch.is_tensor(ex.answer_ids)
+    assert ex.answer_ids.shape == (10,)
+    assert tokenizer.ids_to_text(ex.answer_ids) == "Of course!"
+
+    assert torch.is_tensor(ex.mask)
+    assert ex.mask.shape == (105,)
+    assert (ex.mask[:30] == False).all()  # user turn
+    assert (ex.mask[30:72] == True).all()  # assistant turn
+    assert (ex.mask[72:95] == False).all()  # user turn
+    assert (ex.mask[95:] == True).all()  # assistant turn
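
For reviewers, below is a minimal usage sketch (not part of the diff) of how the new "multimodal_conversation" input type is wired together end to end. The manifest path, audio filenames, conversation text, and the MyDataset/my_tokenizer placeholders are illustrative assumptions only; the config keys mirror the tests above.

# Illustrative sketch, not part of the patch. Paths, filenames, and the
# MyDataset/my_tokenizer placeholders are hypothetical; config keys follow
# tests/collections/common/test_lhotse_multimodal_dataloading.py in this diff.
import json

from omegaconf import OmegaConf

from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config

# One conversation per JSONL line, following the schema documented in
# NeMoMultimodalConversationJsonlAdapter above.
example = {
    "id": "convo_1",
    "conversations": [
        {"value": "Can you transcribe this?", "from": "User", "type": "text"},
        {"value": "audio/utt_001.wav", "from": "User", "type": "audio", "duration": 4.2},
        {"value": "Sure, here is the transcript.", "from": "Assistant", "type": "text"},
    ],
}
with open("manifest.jsonl", "w") as f:
    f.write(json.dumps(example) + "\n")

config = OmegaConf.create(
    {
        "input_cfg": [
            {
                "type": "multimodal_conversation",  # new type registered in cutset.py
                "manifest_filepath": "manifest.jsonl",  # placeholder path
                "audio_locator_tag": "[audio]",  # token substituted for audio turns during prompting
            },
        ],
        "prompt_format": "llama2",  # optional; triggers NeMoMultimodalConversation.tokenize() in the sampler
        "force_finite": True,  # otherwise read_multimodal_conversation_jsonl() repeats the manifest indefinitely
        "shuffle": True,
        "num_workers": 0,
        "batch_size": 2,
        "seed": 0,
        "shard_seed": 0,
    }
)

# dataset and tokenizer are user-provided; the dataloader yields CutSets of
# NeMoMultimodalConversation objects (with input_ids/context_ids/answer_ids/mask
# filled in when prompt_format and a tokenizer are set).
# dl = get_lhotse_dataloader_from_config(
#     config=config, global_rank=0, world_size=1, dataset=MyDataset(), tokenizer=my_tokenizer
# )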