Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 130 additions & 5 deletions nemo/collections/common/data/lhotse/text_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import copy
import random
from dataclasses import dataclass
from collections import deque
from dataclasses import asdict, dataclass
from itertools import groupby
from pathlib import Path
from typing import Iterator, Literal, Optional, Union
Expand All @@ -25,7 +26,8 @@
from lhotse.cut import Cut
from lhotse.dataset.dataloading import resolve_seed
from lhotse.serialization import load_jsonl
from lhotse.utils import Pathlike
from lhotse.shar import AudioTarWriter, JsonlShardWriter, TarIterator, TarWriter
from lhotse.utils import Pathlike, asdict_nonull, is_valid_url

from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
Expand Down Expand Up @@ -316,13 +318,28 @@ class TextTurn:
value: str
role: str

def to_dict(self):
return {"type": "text", "from": self.role.title(), "value": self.value}


@dataclass
class AudioTurn:
    """A single audio turn in a multimodal conversation, backed by a lhotse Cut."""

    cut: Cut
    role: str
    audio_locator_tag: str

    def to_dict(self):
        """Serialize to a JSON-compatible dict.

        Only works when the cut references an audio file (path/URL source);
        in-memory or shar-backed audio cannot be serialized this way.
        """
        references_audio_file = self.cut.has_recording and self.cut.recording.sources[0].type not in {
            "shar",
            "memory",
        }
        assert (
            references_audio_file
        ), "Cannot serialize AudioTurn to dict because it doesn't reference an audio file (the audio is stored in memory)."
        source = self.cut.recording.sources[0]
        return {
            "type": "audio",
            "from": self.role.title(),
            "duration": self.cut.duration,
            "value": source.source,
        }


@dataclass
class NeMoMultimodalConversation:
Expand Down Expand Up @@ -376,6 +393,12 @@ def tokenize(

return self

def to_dict(self):
return {"id": self.id, "conversations": [t.to_dict() for t in self.turns]}

def list_cuts(self) -> list[Cut]:
return [turn.cut for turn in self.turns if isinstance(turn, AudioTurn)]


@dataclass
class NeMoMultimodalConversationJsonlAdapter:
Expand Down Expand Up @@ -408,12 +431,63 @@ class NeMoMultimodalConversationJsonlAdapter:
def __post_init__(self):
self.manifest_filepath = expand_sharded_filepaths(self.manifest_filepath)
if self.tarred_audio_filepaths is not None:
raise NotImplementedError(
"Tarred manifests are currently not supported yet for NeMoMultimodalConversation."
)
self.tarred_audio_filepaths = expand_sharded_filepaths(self.tarred_audio_filepaths)
assert len(self.manifest_filepath) == len(
self.tarred_audio_filepaths
), f"{len(self.manifest_filepath)} != {len(self.tarred_audio_filepaths)}"

def __iter__(self) -> Iterator[NeMoMultimodalConversation]:
if self.tarred_audio_filepaths is not None:
yield from self._iter_tar()
else:
yield from self._iter_jsonl()

    def _iter_tar(self):
        """Yield conversations from paired (JSONL manifest, audio tar) shards.

        Audio entries in each tar were written in the same order as the audio
        turns appear in the corresponding JSONL shard, so the two are consumed
        in lockstep and cross-checked (path and duration) as we go.
        """
        paths = list(zip(self.manifest_filepath, self.tarred_audio_filepaths))
        if self.shuffle_shards:
            # Shuffle shard order only; within-shard order must stay aligned with the tar.
            seed = resolve_seed(self.shard_seed)
            random.Random(seed).shuffle(paths)
        for jsonl_path, tar_path in paths:
            tar = iter(TarIterator(tar_path))
            for data in load_jsonl(jsonl_path):
                audio_turns = [t for t in data["conversations"] if t["type"] == "audio"]
                cuts = []
                for turn in audio_turns:
                    # Pull the next recording from the tar; it must match this audio turn.
                    recording, audio_path = next(tar)
                    audio_path = str(audio_path)
                    cut = recording.to_cut()
                    assert (
                        audio_path == turn['value']
                    ), f"Mismatch between JSONL and tar. JSONL defines audio path={turn['value']} but we got the following from tar {audio_path=}"
                    assert (
                        cut.duration == turn["duration"]
                    ), f"Mismatch between JSONL and tar. JSONL defines audio duration={turn['duration']} but we got the following from tar {cut.duration=}"
                    cuts.append(cut)
                # deque: each AudioTurn below pops its cut from the left in O(1), preserving order.
                cuts = deque(cuts)
                yield NeMoMultimodalConversation(
                    id=data["id"],
                    turns=[
                        (
                            TextTurn(
                                value=turn["value"],
                                role=turn[
                                    "from"
                                ].lower(),  # prompt formatter roles are typically lowercase: user/assistant
                            )
                            if turn["type"] == "text"
                            else AudioTurn(
                                cut=cuts.popleft(),
                                role=turn[
                                    "from"
                                ].lower(),  # prompt formatter roles are typically lowercase: user/assistant
                                audio_locator_tag=self.audio_locator_tag,
                            )
                        )
                        for turn in data["conversations"]
                    ],
                )

def _iter_jsonl(self):
paths = self.manifest_filepath
if self.shuffle_shards:
seed = resolve_seed(self.shard_seed)
Expand Down Expand Up @@ -444,6 +518,57 @@ def __iter__(self) -> Iterator[NeMoMultimodalConversation]:
)


class NeMoMultimodalConversationTarWriter:
    """Writes NeMoMultimodalConversation examples into a sharded tar format.

    Each shard consists of ``manifest_{i}.jsonl`` (conversation metadata) and
    ``audio_{i}.tar`` (FLAC-encoded audio for the audio turns), with at most
    ``shard_size`` conversations per shard.
    """

    def __init__(self, output_dir: str, shard_size: int = 100):
        self.output_dir = output_dir
        self.shard_size = shard_size
        self._reset()
        self._setup_writers()

    def write(self, example: NeMoMultimodalConversation):
        """Write one conversation: a manifest entry plus the audio of every audio turn."""
        self._maybe_increment_shard()
        serialized = example.to_dict()
        for turn in serialized["conversations"]:
            if turn["type"] == "audio":
                # AudioTarWriter stores audio as FLAC; point the manifest at the in-tar file name.
                turn["value"] = Path(turn["value"]).with_suffix(".flac").name
        self.manifest_writer.write(serialized)
        for cut in example.list_cuts():
            assert (
                cut.has_recording
            ), f"Cannot serialize multimodal conversation with cuts that have no recordings. We got: {cut}"
            self.tar_writer.write(cut.recording.id, cut.load_audio(), cut.sampling_rate, cut.recording)
        self.item_cntr += 1

    def close(self):
        self.manifest_writer.close()
        self.tar_writer.close()

    def __enter__(self):
        self._reset()
        self.manifest_writer.__enter__()
        self.tar_writer.__enter__()
        return self

    def __exit__(self, *args, **kwargs):
        self.close()

    def _maybe_increment_shard(self):
        # Roll over to a new shard once the current one holds `shard_size` items.
        if self.item_cntr > 0 and self.item_cntr % self.shard_size == 0:
            self.item_cntr = 0
            self.shard_idx += 1
            # Bug fix: close the current shard's writers before opening the next
            # shard's files; previously their open file handles were leaked.
            self.close()
            self._setup_writers()

    def _reset(self):
        # Item counter within the current shard and the index of the current shard.
        self.item_cntr = 0
        self.shard_idx = 0

    def _setup_writers(self):
        if not is_valid_url(self.output_dir):  # skip dir creation for URLs
            Path(self.output_dir).mkdir(exist_ok=True)
        self.manifest_writer = JsonlShardWriter(f"{self.output_dir}/manifest_{self.shard_idx}.jsonl", shard_size=None)
        self.tar_writer = AudioTarWriter(f"{self.output_dir}/audio_{self.shard_idx}.tar", shard_size=None)


"""
The code below is copied from nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
with minimal modifications in order to avoid importing the NLP collection.
Expand Down
28 changes: 28 additions & 0 deletions scripts/speech_llm/export_conversations_to_tar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from random import Random

import click
from lhotse import CutSet

from nemo.collections.common.data.lhotse.text_adapters import (
NeMoMultimodalConversationJsonlAdapter,
NeMoMultimodalConversationTarWriter,
)


@click.command()
@click.argument("manifest", type=click.Path())
@click.argument("output_dir", type=click.Path())
@click.option("-n", "--shard_size", type=int, default=100, help="Number of conversations per shard.")
@click.option("--shuffle/--no-shuffle", default=False, help="Shuffle conversations.")
@click.option("-s", "--seed", type=int, default=42, help="Random seed.")
def export(manifest: str, output_dir: str, shard_size: int, shuffle: bool, seed: int):
    """Export a multimodal conversation JSONL MANIFEST into sharded tar format under OUTPUT_DIR.

    Note: the source span contained an interleaved review-comment thread from the
    scraped diff page; the function is reconstructed here without that residue.
    """
    with NeMoMultimodalConversationTarWriter(output_dir, shard_size=shard_size) as writer:
        # The locator tag is irrelevant for export; a placeholder is sufficient.
        source = NeMoMultimodalConversationJsonlAdapter(manifest, audio_locator_tag="<dummy>")
        if shuffle:
            source = CutSet(source).shuffle(buffer_size=50000, rng=Random(seed))
        for item in source:
            writer.write(item)


if __name__ == '__main__':
    export()
52 changes: 51 additions & 1 deletion tests/collections/common/test_lhotse_multimodal_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@
from itertools import islice

import lhotse
import numpy as np
import pytest
import torch
from lhotse.testing.dummies import dummy_cut, dummy_recording
from omegaconf import OmegaConf

from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.collections.common.data.lhotse.text_adapters import AudioTurn, NeMoMultimodalConversation, TextTurn
from nemo.collections.common.data.lhotse.text_adapters import (
AudioTurn,
NeMoMultimodalConversation,
NeMoMultimodalConversationJsonlAdapter,
NeMoMultimodalConversationTarWriter,
TextTurn,
)
from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model

Expand Down Expand Up @@ -212,3 +219,46 @@ def test_multimodal_conversation_input_with_prompt(multimodal_conversations_path
assert (ex.mask[30:72] == True).all() # assistant turn
assert (ex.mask[72:95] == False).all() # user turn
assert (ex.mask[95:] == True).all() # assistant turn


def test_multimodal_conversation_tarred_format(multimodal_conversations_path, tmp_path_factory):
    """Round-trip test: export a conversation to tar format and verify it restores identically."""
    (conversation,) = list(NeMoMultimodalConversationJsonlAdapter(multimodal_conversations_path, "[audio]"))
    tar_dir = tmp_path_factory.mktemp("multi_convo_tarred")
    with NeMoMultimodalConversationTarWriter(tar_dir) as writer:
        writer.write(conversation)

    (restored_conversation,) = list(
        NeMoMultimodalConversationJsonlAdapter(
            manifest_filepath=tar_dir / "manifest_0.jsonl",
            audio_locator_tag="[audio]",
            tarred_audio_filepaths=tar_dir / "audio_0.tar",
        )
    )
    assert conversation.id == restored_conversation.id
    assert len(conversation.turns) == len(restored_conversation.turns)
    for lhs, rhs in zip(conversation.turns, restored_conversation.turns):
        assert type(lhs) == type(rhs)
        # Bug fix: compare against the restored turn (was `lhs.role == lhs.role`,
        # which is always true and verified nothing).
        assert lhs.role == rhs.role
        if isinstance(lhs, TextTurn):
            assert lhs.value == rhs.value
        else:
            assert lhs.audio_locator_tag == rhs.audio_locator_tag
            assert lhs.cut.id == rhs.cut.id
            np.testing.assert_allclose(lhs.cut.load_audio(), rhs.cut.load_audio())


def test_multimodal_conversation_tarred_format_sharding_works(multimodal_conversations_path, tmp_path_factory):
    """Writing 30 conversations with shard_size=10 must produce 3 readable shards."""
    (conversation,) = list(NeMoMultimodalConversationJsonlAdapter(multimodal_conversations_path, "[audio]"))
    tar_dir = tmp_path_factory.mktemp("multi_convo_tarred")
    with NeMoMultimodalConversationTarWriter(tar_dir, shard_size=10) as writer:
        for i in range(30):
            writer.write(conversation)

    # Read back all three shards via brace-expansion shard patterns.
    loader = NeMoMultimodalConversationJsonlAdapter(
        manifest_filepath=tar_dir / "manifest_{0..2}.jsonl",
        audio_locator_tag="[audio]",
        tarred_audio_filepaths=tar_dir / "audio_{0..2}.tar",
    )
    restored = list(loader)
    assert len(restored) == 30
    # All 30 restored items originate from the same written conversation.
    assert all(c == restored[0] for c in restored[1:])