diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 79c54b89bc0d..9e64b37bffef 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -14,7 +14,8 @@ import copy import random -from dataclasses import dataclass +from collections import deque +from dataclasses import asdict, dataclass from itertools import groupby from pathlib import Path from typing import Iterator, Literal, Optional, Union @@ -25,7 +26,8 @@ from lhotse.cut import Cut from lhotse.dataset.dataloading import resolve_seed from lhotse.serialization import load_jsonl -from lhotse.utils import Pathlike +from lhotse.shar import AudioTarWriter, JsonlShardWriter, TarIterator, TarWriter +from lhotse.utils import Pathlike, asdict_nonull, is_valid_url from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths from nemo.collections.common.parts.preprocessing.manifest import get_full_path @@ -316,6 +318,9 @@ class TextTurn: value: str role: str + def to_dict(self): + return {"type": "text", "from": self.role.title(), "value": self.value} + @dataclass class AudioTurn: @@ -323,6 +328,18 @@ class AudioTurn: role: str audio_locator_tag: str + def to_dict(self): + assert self.cut.has_recording and self.cut.recording.sources[0].type not in { + "shar", + "memory", + }, "Cannot serialize AudioTurn to dict because it doesn't reference an audio file (the audio is stored in memory)." 
+ return { + "type": "audio", + "from": self.role.title(), + "duration": self.cut.duration, + "value": self.cut.recording.sources[0].source, + } + @dataclass class NeMoMultimodalConversation: @@ -376,6 +393,12 @@ def tokenize( return self + def to_dict(self): + return {"id": self.id, "conversations": [t.to_dict() for t in self.turns]} + + def list_cuts(self) -> list[Cut]: + return [turn.cut for turn in self.turns if isinstance(turn, AudioTurn)] + @dataclass class NeMoMultimodalConversationJsonlAdapter: @@ -408,12 +431,63 @@ class NeMoMultimodalConversationJsonlAdapter: def __post_init__(self): self.manifest_filepath = expand_sharded_filepaths(self.manifest_filepath) if self.tarred_audio_filepaths is not None: - raise NotImplementedError( - "Tarred manifests are currently not supported yet for NeMoMultimodalConversation." - ) self.tarred_audio_filepaths = expand_sharded_filepaths(self.tarred_audio_filepaths) + assert len(self.manifest_filepath) == len( + self.tarred_audio_filepaths + ), f"{len(self.manifest_filepath)} != {len(self.tarred_audio_filepaths)}" def __iter__(self) -> Iterator[NeMoMultimodalConversation]: + if self.tarred_audio_filepaths is not None: + yield from self._iter_tar() + else: + yield from self._iter_jsonl() + + def _iter_tar(self): + paths = list(zip(self.manifest_filepath, self.tarred_audio_filepaths)) + if self.shuffle_shards: + seed = resolve_seed(self.shard_seed) + random.Random(seed).shuffle(paths) + for jsonl_path, tar_path in paths: + tar = iter(TarIterator(tar_path)) + for data in load_jsonl(jsonl_path): + audio_turns = [t for t in data["conversations"] if t["type"] == "audio"] + cuts = [] + for turn in audio_turns: + recording, audio_path = next(tar) + audio_path = str(audio_path) + cut = recording.to_cut() + assert ( + audio_path == turn['value'] + ), f"Mismatch between JSONL and tar. 
JSONL defines audio path={turn['value']} but we got the following from tar {audio_path=}" + assert ( + cut.duration == turn["duration"] + ), f"Mismatch between JSONL and tar. JSONL defines audio duration={turn['duration']} but we got the following from tar {cut.duration=}" + cuts.append(cut) + cuts = deque(cuts) + yield NeMoMultimodalConversation( + id=data["id"], + turns=[ + ( + TextTurn( + value=turn["value"], + role=turn[ + "from" + ].lower(), # prompt formatter role's are typically lowercase: user/assistant + ) + if turn["type"] == "text" + else AudioTurn( + cut=cuts.popleft(), + role=turn[ + "from" + ].lower(), # prompt formatter role's are typically lowercase: user/assistant + audio_locator_tag=self.audio_locator_tag, + ) + ) + for turn in data["conversations"] + ], + ) + + def _iter_jsonl(self): paths = self.manifest_filepath if self.shuffle_shards: seed = resolve_seed(self.shard_seed) @@ -444,6 +518,57 @@ def __iter__(self) -> Iterator[NeMoMultimodalConversation]: ) +class NeMoMultimodalConversationTarWriter: + def __init__(self, output_dir: str, shard_size: int = 100): + self.output_dir = output_dir + self.shard_size = shard_size + self._reset() + self._setup_writers() + + def write(self, example: NeMoMultimodalConversation): + self._maybe_increment_shard() + serialized = example.to_dict() + for turn in serialized["conversations"]: + if turn["type"] == "audio": + turn["value"] = Path(turn["value"]).with_suffix(".flac").name + self.manifest_writer.write(serialized) + for cut in example.list_cuts(): + assert ( + cut.has_recording + ), f"Cannot serialize multimodal conversation with cuts that have no recordings. 
We got: {cut}" + self.tar_writer.write(cut.recording.id, cut.load_audio(), cut.sampling_rate, cut.recording) + self.item_cntr += 1 + + def close(self): + self.manifest_writer.close() + self.tar_writer.close() + + def __enter__(self): + self._reset() + self.manifest_writer.__enter__() + self.tar_writer.__enter__() + return self + + def __exit__(self, *args, **kwargs): + self.close() + + def _maybe_increment_shard(self): + if self.item_cntr > 0 and self.item_cntr % self.shard_size == 0: + self.item_cntr = 0 + self.shard_idx += 1 + self._setup_writers() + + def _reset(self): + self.item_cntr = 0 + self.shard_idx = 0 + + def _setup_writers(self): + if not is_valid_url(self.output_dir): # skip dir creation for URLs + Path(self.output_dir).mkdir(exist_ok=True) + self.manifest_writer = JsonlShardWriter(f"{self.output_dir}/manifest_{self.shard_idx}.jsonl", shard_size=None) + self.tar_writer = AudioTarWriter(f"{self.output_dir}/audio_{self.shard_idx}.tar", shard_size=None) + + """ The code below is copied from nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py with minimal modifications in order to avoid importing the NLP collection. 
diff --git a/scripts/speech_llm/export_conversations_to_tar.py b/scripts/speech_llm/export_conversations_to_tar.py new file mode 100644 index 000000000000..d8ee6c72a6c4 --- /dev/null +++ b/scripts/speech_llm/export_conversations_to_tar.py @@ -0,0 +1,28 @@ +from random import Random + +import click +from lhotse import CutSet + +from nemo.collections.common.data.lhotse.text_adapters import ( + NeMoMultimodalConversationJsonlAdapter, + NeMoMultimodalConversationTarWriter, +) + + +@click.command() +@click.argument("manifest", type=click.Path()) +@click.argument("output_dir", type=click.Path()) +@click.option("-n", "--shard_size", type=int, default=100, help="Number of conversations per shard.") +@click.option("--shuffle/--no-shuffle", default=False, help="Shuffle conversations.") +@click.option("-s", "--seed", type=int, default=42, help="Random seed.") +def export(manifest: str, output_dir: str, shard_size: int, shuffle: bool, seed: int): + with NeMoMultimodalConversationTarWriter(output_dir, shard_size=shard_size) as writer: + source = NeMoMultimodalConversationJsonlAdapter(manifest, audio_locator_tag="") + if shuffle: + source = CutSet(source).shuffle(buffer_size=50000, rng=Random(seed)) + for item in source: + writer.write(item) + + +if __name__ == '__main__': + export() diff --git a/tests/collections/common/test_lhotse_multimodal_dataloading.py b/tests/collections/common/test_lhotse_multimodal_dataloading.py index 4ded7c25d12a..f53a45a72971 100644 --- a/tests/collections/common/test_lhotse_multimodal_dataloading.py +++ b/tests/collections/common/test_lhotse_multimodal_dataloading.py @@ -2,13 +2,20 @@ from itertools import islice import lhotse +import numpy as np import pytest import torch from lhotse.testing.dummies import dummy_cut, dummy_recording from omegaconf import OmegaConf from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config -from nemo.collections.common.data.lhotse.text_adapters import AudioTurn, NeMoMultimodalConversation, 
TextTurn +from nemo.collections.common.data.lhotse.text_adapters import ( + AudioTurn, + NeMoMultimodalConversation, + NeMoMultimodalConversationJsonlAdapter, + NeMoMultimodalConversationTarWriter, + TextTurn, +) from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model @@ -212,3 +219,46 @@ def test_multimodal_conversation_input_with_prompt(multimodal_conversations_path assert (ex.mask[30:72] == True).all() # assistant turn assert (ex.mask[72:95] == False).all() # user turn assert (ex.mask[95:] == True).all() # assistant turn + + +def test_multimodal_conversation_tarred_format(multimodal_conversations_path, tmp_path_factory): + (conversation,) = list(NeMoMultimodalConversationJsonlAdapter(multimodal_conversations_path, "[audio]")) + tar_dir = tmp_path_factory.mktemp("multi_convo_tarred") + with NeMoMultimodalConversationTarWriter(tar_dir) as writer: + writer.write(conversation) + + (restored_conversation,) = list( + NeMoMultimodalConversationJsonlAdapter( + manifest_filepath=tar_dir / "manifest_0.jsonl", + audio_locator_tag="[audio]", + tarred_audio_filepaths=tar_dir / "audio_0.tar", + ) + ) + assert conversation.id == restored_conversation.id + assert len(conversation.turns) == len(restored_conversation.turns) + for lhs, rhs in zip(conversation.turns, restored_conversation.turns): + assert type(lhs) == type(rhs) + assert lhs.role == rhs.role + if isinstance(lhs, TextTurn): + assert lhs.value == rhs.value + else: + assert lhs.audio_locator_tag == rhs.audio_locator_tag + assert lhs.cut.id == rhs.cut.id + np.testing.assert_allclose(lhs.cut.load_audio(), rhs.cut.load_audio()) + + +def test_multimodal_conversation_tarred_format_sharding_works(multimodal_conversations_path, tmp_path_factory): + (conversation,) = list(NeMoMultimodalConversationJsonlAdapter(multimodal_conversations_path, "[audio]")) + tar_dir = 
tmp_path_factory.mktemp("multi_convo_tarred") + with NeMoMultimodalConversationTarWriter(tar_dir, shard_size=10) as writer: + for i in range(30): + writer.write(conversation) + + loader = NeMoMultimodalConversationJsonlAdapter( + manifest_filepath=tar_dir / "manifest_{0..2}.jsonl", + audio_locator_tag="[audio]", + tarred_audio_filepaths=tar_dir / "audio_{0..2}.tar", + ) + restored = list(loader) + assert len(restored) == 30 + assert all(c == restored[0] for c in restored[1:])