Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion nemo/collections/common/data/lhotse/cutset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from nemo.collections.common.data.lhotse.text_adapters import (
LhotseTextAdapter,
LhotseTextPairAdapter,
NeMoMultimodalConversationJsonlAdapter,
NeMoSFTJsonlAdapter,
)
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
Expand Down Expand Up @@ -62,7 +63,17 @@ def read_cutset_from_config(config: DictConfig) -> Tuple[CutSet, bool]:


KNOWN_DATASET_CONFIG_TYPES = frozenset(
("nemo", "nemo_tarred", "lhotse", "lhotse_shar", "txt", "txt_pair", "nemo_sft_jsonl", "group")
(
"nemo",
"nemo_tarred",
"lhotse",
"lhotse_shar",
"txt",
"txt_pair",
"nemo_sft_jsonl",
"multimodal_conversation",
"group",
)
)


Expand Down Expand Up @@ -173,6 +184,9 @@ def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> [CutSet, bool]:
elif grp_cfg.type == "nemo_sft_jsonl":
is_tarred = True
cuts = read_nemo_sft_jsonl(grp_cfg)
elif grp_cfg.type == "multimodal_conversation":
is_tarred = True
cuts = read_multimodal_conversation_jsonl(grp_cfg)
elif grp_cfg.type == "group":
cuts, is_tarred = parse_and_combine_datasets(
grp_cfg.input_cfg,
Expand Down Expand Up @@ -223,6 +237,21 @@ def read_nemo_sft_jsonl(config: DictConfig) -> CutSet:
).repeat()


def read_multimodal_conversation_jsonl(config: DictConfig) -> CutSet:
    """
    Build a lazy ``CutSet`` of multimodal conversations from a JSONL manifest config.

    The returned CutSet is infinite (``.repeat()``) unless ``force_finite`` is set in the config.
    """
    adapter = NeMoMultimodalConversationJsonlAdapter(
        manifest_filepath=config.manifest_filepath,
        tarred_audio_filepaths=config.get("tarred_audio_filepaths"),
        audio_locator_tag=config.audio_locator_tag,
        shuffle_shards=config.shuffle,
        shard_seed=config.shard_seed,
    )
    conversations = CutSet(adapter)
    if config.get("force_finite", False):
        return conversations
    # Repeat indefinitely so the dataloader is epoch-less, matching the other jsonl readers.
    return conversations.repeat()


def attach_tags(cut, tags: dict):
for key, val in tags.items():
setattr(cut, key, val)
Expand Down
9 changes: 8 additions & 1 deletion nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@
from omegaconf import DictConfig, ListConfig, OmegaConf

from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset, read_cutset_from_config
from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample, TextExample
from nemo.collections.common.data.lhotse.text_adapters import (
NeMoMultimodalConversation,
NeMoSFTExample,
SourceTargetTextExample,
TextExample,
)
from nemo.collections.common.prompts.fn import get_prompt_format_fn
from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper
from nemo.utils import logging
Expand Down Expand Up @@ -736,6 +741,8 @@ def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str) -> Exa
example.tokenized_prompted_transcript = tokenized_prompted_transcript
example.tokenized_prompt = tokenized_prompt
example.tokenized_transcript = tokenized_transcript
elif isinstance(example, NeMoMultimodalConversation):
example = example.tokenize(tokenizer, prompt_format)
else:
raise RuntimeError(f"Currently we only support tokenization + prompting during sampling for audio modality.")
return example
Expand Down
138 changes: 138 additions & 0 deletions nemo/collections/common/data/lhotse/text_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,21 @@
import copy
import random
from dataclasses import dataclass
from itertools import groupby
from pathlib import Path
from typing import Iterator, Literal, Optional, Union

import numpy as np
import torch
from lhotse import Recording
from lhotse.cut import Cut
from lhotse.dataset.dataloading import resolve_seed
from lhotse.serialization import load_jsonl
from lhotse.utils import Pathlike

from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
from nemo.collections.common.prompts import PromptFormatter
from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer, TokenizerWrapper
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.utils import logging
Expand Down Expand Up @@ -306,6 +311,139 @@ def __iter__(self) -> Iterator[NeMoSFTExample]:
yield NeMoSFTExample(data, language=self.language)


@dataclass
class TextTurn:
    """A single text message in a :class:`NeMoMultimodalConversation`."""

    # The text content of the message.
    value: str
    # Speaker role; lowercase by convention of the prompt formatters (e.g. "user"/"assistant").
    role: str


@dataclass
class AudioTurn:
    """A single audio message in a :class:`NeMoMultimodalConversation`."""

    # Lhotse cut wrapping the recording for this turn.
    cut: Cut
    # Speaker role; lowercase by convention of the prompt formatters (e.g. "user"/"assistant").
    role: str
    # Placeholder string that stands in for the audio inside the tokenized text prompt.
    audio_locator_tag: str


@dataclass
class NeMoMultimodalConversation:
    """
    A multi-turn conversation mixing :class:`TextTurn` and :class:`AudioTurn` entries.

    The ``*_ids``/``mask`` fields are optional and filled in by :meth:`tokenize`.
    """

    id: str
    turns: list[TextTurn | AudioTurn]
    input_ids: np.ndarray | None = None
    context_ids: np.ndarray | None = None
    answer_ids: np.ndarray | None = None
    mask: np.ndarray | None = None

    def tokenize(
        self,
        tokenizer: TokenizerWrapper | TokenizerSpec,
        prompt: PromptFormatter = None,
    ) -> "NeMoMultimodalConversation":
        """
        Create a tokenized variant of this example given a tokenizer (i.e. fill the optional fields).
        Supports BPE tokenizers and aggregate tokenizers.

        The tokenization is compatible with Megatron's :class:`GPTSFTChatDataset`.
        """
        # Unwrap to the underlying tokenizer; aggregate tokenizers are not handled yet.
        if isinstance(tokenizer, TokenizerWrapper):
            tokenizer = tokenizer._tokenizer
        if isinstance(tokenizer, AggregateTokenizer):
            raise NotImplementedError("NeMoMultimodalConversation does not support AggregateTokenizer yet.")
        # Resolve the prompt formatter: default to "plain", or look up by name when given a string.
        if prompt is None:
            prompt = PromptFormatter.resolve("plain")(tokenizer)
        elif isinstance(prompt, str):
            prompt = PromptFormatter.resolve(prompt)(tokenizer)

        # Text turns contribute their message; audio turns contribute the locator placeholder.
        formatted = []
        for turn in self.turns:
            message = turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag
            formatted.append({"role": turn.role, "slots": {"message": message}})

        # Collapse consecutive same-role turns into single turn for proper prompt formatting.
        merged = []
        for role, run in groupby(formatted, key=lambda item: item["role"]):
            joined = " ".join(item["slots"]["message"] for item in run)
            merged.append({"role": role, "slots": {"message": joined}})

        encoded = prompt.encode_dialog(merged)
        self.input_ids = encoded["input_ids"]
        self.context_ids = encoded["context_ids"]
        self.answer_ids = encoded["answer_ids"]
        self.mask = encoded["mask"]
        return self


@dataclass
class NeMoMultimodalConversationJsonlAdapter:
    """
    ``NeMoMultimodalConversationJsonlAdapter`` is used to read a NeMo multimodal conversation JSONL
    and yield objects of type ``NeMoMultimodalConversation`` that can be sampled with Lhotse.

    We expect the following schema (contained in a single line per example)::

        {
            "id": str,
            "conversations": [
                {
                    "value": str,  # text message or path to audio
                    "from": "User" | "Assistant",
                    "type": "text" | "audio",
                    "duration": float,  # only for audio
                },
                ...
            ],
        }
    """

    # One or more JSONL manifest paths; sharded path patterns are expanded in __post_init__.
    manifest_filepath: str | list[str]
    # Placeholder inserted in the text prompt wherever an audio turn occurs.
    audio_locator_tag: str
    # Reserved for future tarred-audio support; currently must be left as None.
    tarred_audio_filepaths: str | list[str] | None = None
    shuffle_shards: bool = False
    shard_seed: Union[int, Literal["trng", "randomized"]] = "trng"

    def __post_init__(self):
        self.manifest_filepath = expand_sharded_filepaths(self.manifest_filepath)
        if self.tarred_audio_filepaths is not None:
            # When tarred reading is implemented, expand_sharded_filepaths(self.tarred_audio_filepaths)
            # should be applied here instead of raising.
            raise NotImplementedError(
                "Tarred manifests are currently not supported yet for NeMoMultimodalConversation."
            )
        # NOTE: previously expand_sharded_filepaths() was called unconditionally after the raise,
        # i.e. only ever with None — removed, as the helper does not accept None.

    def __iter__(self) -> Iterator[NeMoMultimodalConversation]:
        """Lazily parse each manifest line into a NeMoMultimodalConversation."""
        # Shuffle a copy so the instance's shard list is not mutated between iterations.
        paths = list(self.manifest_filepath)
        if self.shuffle_shards:
            seed = resolve_seed(self.shard_seed)
            random.Random(seed).shuffle(paths)
        for path in paths:
            for data in load_jsonl(path):
                yield NeMoMultimodalConversation(
                    id=data["id"],
                    turns=[self._parse_turn(turn, path) for turn in data["conversations"]],
                )

    def _parse_turn(self, turn: dict, manifest_path: str) -> TextTurn | AudioTurn:
        """Convert one manifest ``conversations`` entry into a TextTurn or AudioTurn."""
        # Prompt formatter roles are typically lowercase: user/assistant.
        role = turn["from"].lower()
        if turn["type"] == "text":
            return TextTurn(value=turn["value"], role=role)
        # Audio entry: "value" is a path, resolved relative to the manifest location.
        return AudioTurn(
            cut=Recording.from_file(get_full_path(turn["value"], manifest_path)).to_cut(),
            role=role,
            audio_locator_tag=self.audio_locator_tag,
        )


"""
The code below is copied from nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py
with minimal modifications in order to avoid importing the NLP collection.
Expand Down
Loading