From 38651d0679d2d8227b3b1d0ff784f3285f2eaa4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 16 Aug 2024 05:48:30 -0700 Subject: [PATCH 01/63] Preliminary support for oomptimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/models/modular_models.py | 54 +++++++++++++++++++ scripts/speech_recognition/oomptimizer.py | 12 ++++- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 6ef434929f58..43d4cb32309f 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -55,6 +55,7 @@ from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.mixins import adapter_mixins +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import AppState, logging, model_utils from nemo.utils.model_utils import inject_model_parallel_rank @@ -1705,6 +1706,59 @@ def find_frozen_submodules(model): self.perception = self.trainer.strategy._setup_model(self.perception) self.perception = self.perception.cuda(torch.cuda.current_device()) + @property + def oomptimizer_schema(self) -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. + """ + + # TODO: add support for text + # input_ids = text_batch["text_input_ids"][:, :-1] + # labels = text_batch["text_input_ids"][:, 1:] + # attention_mask = self._create_attention_mask(input_ids) + # loss_mask = text_batch["text_masks"][:, 1:] + + return { + "cls": dict, + "inputs": [ + {"name": "audio_signal", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + {"name": "audio_signal_length", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, + { + "name": "tokens", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "tokens_length", + "type": NeuralType(("B",), LengthsType()), + "seq_length": "output", + }, + { + "name": "labels", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "loss_mask", + "type": NeuralType(("B", "T"), MaskType()), + "seq_length": "output", + }, + { + "name": "num_audios", + "type": "constant", + "value": "batch", + }, + { + "name": "context_start_idx", + "type": "constant", + "value": 0, + }, + ], + } + class CrossAttendModularAudioGPTModel(ModularAudioGPTModel): """Modularized speech GPT model.""" diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index 165ac5ac692d..22b08a9cebee 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -12,7 +12,7 @@ from omegaconf import OmegaConf from nemo.collections.asr.models.asr_model import ASRModel -from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import logging @@ -111,7 +111,12 @@ def __call__(self, input_seq_length: int, output_seq_length: int): names = [] for item in self.schema["inputs"]: nt = item["type"] - if not isinstance(nt, NeuralType): # placeholder + if isinstance(nt, str) and nt == "constant": + if isinstance(val := item["value"], str) and val == "batch": + tnsr = torch.tensor([B], dtype=torch.long, device=self.device) + else: + tnsr = torch.tensor([val], dtype=torch.long, device=self.device) + elif not isinstance(nt, NeuralType): # placeholder tnsr = torch.tensor([]) elif isinstance(nt.elements_type, AudioSignal): seq_length = select_seq_length[item["seq_length"]] @@ -122,6 +127,9 @@ def __call__(self, input_seq_length: int, output_seq_length: int): elif isinstance(nt.elements_type, LabelsType): seq_length = select_seq_length[item["seq_length"]] tnsr = torch.randint(0, item["vocab_size"], size=(B, seq_length), device=self.device) + elif isinstance(nt.elements_type, MaskType): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.ones(B, seq_length, device=self.device) else: raise RuntimeError("Unexpected item in oomptimizer schema: {item}") batch.append(tnsr) From a1754beb03689105c3e6bf78a1397d5f98769f0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 09:14:30 -0700 Subject: [PATCH 02/63] OOMptimizer for SpeechLLM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/models/modular_models.py | 5 - .../oomptimizer-speechllm.py | 550 ++++++++++++++++++ 2 files changed, 550 insertions(+), 5 deletions(-) create mode 100755 scripts/speech_recognition/oomptimizer-speechllm.py diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 43d4cb32309f..3b522a3695a6 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -1746,11 +1746,6 @@ def oomptimizer_schema(self) -> dict: "type": NeuralType(("B", "T"), MaskType()), "seq_length": "output", }, - { - "name": "num_audios", - "type": "constant", - "value": "batch", - }, { "name": "context_start_idx", "type": "constant", diff --git a/scripts/speech_recognition/oomptimizer-speechllm.py b/scripts/speech_recognition/oomptimizer-speechllm.py new file mode 100755 index 000000000000..16e5ca0816a4 --- /dev/null +++ b/scripts/speech_recognition/oomptimizer-speechllm.py @@ -0,0 +1,550 @@ +#!/usr/bin/env python +import importlib +import math +import sys +from numbers import Number +from typing import Iterable, Literal + +import click +import pytorch_lightning as pl +import torch +from lhotse import compute_num_samples +from omegaconf import OmegaConf + +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType +from nemo.utils import logging + + +class ProfilingBatchGenerator: + """ + ProfilingBatchGenerator is used to generate artificial mini-batches for model training + and tracking the progress of batch size optimization. + + The high-level usage API is the following:: + + >>> gen = ProfilingBatchGenerator(schema) + ... finished = False + ... while not finished: + ... batch = gen(input_seq_len, output_seq_len) + ... try: + ... training_step(model, batch) + ... oom = False + ... except torch.cuda.OutOfMemoryError: + ... oom = True + ... finished = gen.advance(oom) + ... solution = gen.max_batch_size # The solution of the search problem. + ... gen.reset() # Can re-use for other sequence lengths now. + + The search terminates once the difference between max working batch size and min OOM batch size + divided by the latter is smaller than ``rel_gap_thresh`` that difference amounts to a single element. + For example, a max working batch size is 96 and min OOM batch size is 100 indicates a gap of 0.04, + which would terminate the search with threshold of 0.05. + + In order to generate mini-batches compatible with a given model, the generator: + + * accepts a ``schema`` argument in its constructor, and + + * accepts input/output sequence lengths in each call to generate a mini-batch. + + ``schema`` has the following structure:: + + + >>> { + ... "cls": tuple | MyBatchType, + ... "inputs": [ + ... { + ... "type": NeuralType(...) | Literal["dummy"], + ... "seq_length": Literal["input", "output"], + ... "vocab_size": int, # optional, required only for LabelsType + ... "name": str, # optional, indicates kwarg + ... }, + ... ..., + ... ] + ... } + + ``cls`` indicates how we should construct the mini-batch. Typically you can just use ``tuple`` for most + batch schemas. However, if the model expects a specific, e.g., dataclass, you can tell ``ProfilingBatchGenerator`` + to use it. The mini-batch object will be constructed using the items in ``inputs``. + + Each element of ``inputs`` specifies a NeMo NeuralType which needs to have a defined ``elements_type``. + The supported types are ``AudioSignal``, ``LengthsType`` and ``LabelsType``. + If "type" is not a NeuralType, we interpret that as a placeholder tensor that's not relevant but expected + by the model/batch constructor. In addition, ``"seq_length"`` key is used to determine whether we should apply + input or output sequence length to a given tensor. + + Optional keys: + + * ``vocab_size`` is required for ``LabelsType`` so that we can generate proper label values. + + * ``name`` is required if objects of ``cls`` have to be constructed using keyword arguments. + + A simple schema example for a model using audio/lengths tensor pair (unsupervised/self-supervised):: + + >>> { + ... "cls": tuple, + ... "inputs": [ + ... {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + ... {"type": NeuralType(("B"), LengthsType()), "seq_length": "input"}, + ... ] + ... } + + """ + + def __init__( + self, + schema: dict, + start_batch_size: int = 32, + rel_gap_thresh: float = 0.05, + device: str = "cuda", + ): + self.schema = schema + self.start_batch_size = start_batch_size + self.rel_gap_thresh = rel_gap_thresh + self.device = device + self.reset() + + def __call__(self, input_seq_length: int, output_seq_length: int): + B = self._current + select_seq_length = {"input": input_seq_length, "output": output_seq_length} + batch = [] + names = [] + for item in self.schema["inputs"]: + nt = item["type"] + if isinstance(nt, str) and nt == "constant": + if isinstance(val := item["value"], str) and val == "batch": + tnsr = torch.tensor([B], dtype=torch.long, device=self.device) + else: + tnsr = torch.tensor([val], dtype=torch.long, device=self.device) + elif not isinstance(nt, NeuralType): # placeholder + tnsr = torch.tensor([]) + elif isinstance(nt.elements_type, AudioSignal): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.randn(B, seq_length, dtype=torch.float32, device=self.device) + elif isinstance(nt.elements_type, LengthsType): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.ones(B, dtype=torch.long, device=self.device) * seq_length + elif isinstance(nt.elements_type, MaskType): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.ones(B, seq_length, device=self.device) + elif isinstance(nt.elements_type, LabelsType): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.randint(0, item["vocab_size"], size=(B, seq_length), device=self.device) + else: + raise RuntimeError("Unexpected item in oomptimizer schema: {item}") + batch.append(tnsr) + names.append(item.get("name")) + args = [elem for name, elem in zip(names, batch) if name is None] + kwargs = {name: elem for name, elem in zip(names, batch) if name is not None} + if not kwargs and self.schema["cls"] == tuple: + return tuple(args) + return self.schema["cls"](*args, **kwargs) + + @property + def max_batch_size(self) -> int | None: + """ + Return the solution of the batch size search problem. + It will keep returning None until the search is done. + """ + if ( + self._max_ok is not None + and self._min_err is not None + and (self.current_rel_gap <= self.rel_gap_thresh or self._min_err - self._max_ok <= 1) + ): + return self._max_ok + return None + + @property + def current_rel_gap(self) -> float | None: + """ + Return the current gap between the largest batch that works and the smallest batch that triggers OOM. + The gap is defined as the batch size difference divided by the larger element. + E.g., if the best found batch size is 95 and the smallest that triggers OOM is 100, the gap is 0.05. + """ + if self._min_err is None or self._max_ok is None: + return None + return (self._min_err - self._max_ok) / self._min_err + + def reset(self): + """Reset the generator to prepare it for a new search.""" + self._current = self.start_batch_size + self._max_ok = None # max batch size that works + self._min_err = None # min batch size that doesn't work + + def advance(self, oom: bool) -> bool: + """ + Adjusts the current batch size based on the outcome. + Returns a bool indicating whether the calibration is complete. + """ + if self.max_batch_size is not None: + return True + + if oom: + # Training step failed with OOM. + # Update the minimum known batch size that causes an error. + self._min_err = min(float("inf") if self._min_err is None else self._min_err, self._current) + # Training step failed on OOM + if self._max_ok is None: + # We haven't found a batch size that works yet, keep going 2x down. + self._current = round(self._current / 2) + else: + # Try the middle-point between the known extremes. + self._current = round((self._max_ok + self._min_err) / 2) + else: + # Training step successful. + # Update the maximum known batch size that works. + self._max_ok = max(-1 if self._max_ok is None else self._max_ok, self._current) + if self._min_err is None: + # We haven't found a batch size that causes an error yet, keep going 2x higher + self._current *= 2 + else: + # Try the middle-point between the known extremes. + self._current = round((self._max_ok + self._min_err) / 2) + + if self._current == 0: + raise RuntimeError( + "We diverged and arrived batch_size=0. Perhaps the input is too large for this model and hardware." + ) + + return False + + +class FloatList(click.Option): + """Support passing bucket duration bins as [1.1,2.5,5.6,...]""" + + name = "list[float]" + + def type_cast_value(self, ctx, value): + if isinstance(value, list) and all(isinstance(v, float) for v in value): + return value + try: + import ast + + ans = ast.literal_eval(value) + if isinstance(ans[0], list): + ans = [tuple(item) for item in ans] + return ans + except ValueError: + raise click.BadParameter(value) + + +@click.command(context_settings={'show_default': True}) +@click.option( + "-n", + "--pretrained-name", + type=str, + default=None, + help="Name of a pretrained model to use, e.g. 'nvidia/canary-1b'.", +) +@click.option( + "-m", + "--module-name", + type=str, + default=None, + help="Full path to NeMo's module corresponding to CONFIG_PATH, e.g. 'nemo.collections.asr.models.EncDecMultiTaskModel'.", +) +@click.option( + "-c", "--config-path", type=str, default=None, help="Path to the training configuration file for MODULE_NAME." +) +@click.option("-o", "--optimizer-name", type=str, default="adamw", help="Name of optimizer to use.") +@click.option( + "-b", + "--buckets", + cls=FloatList, + default=[5.0, 10.0, 15.0, 20.0, 25.0, 30.0], + help="List of upper-bound bucket bins (i.e. first bucket is [0.0 - item0), second bucket is [item0 - item1), etc.). " + "We also support a nested list for 2D bucketing, e.g. [[2.0, 10],[2.0,20],[4.5,15],[4.5,30],...], " + "where each item is a pair of (max_input_seq_len, max_output_seq_len) for a given bucket.", +) +@click.option( + "-t", + "--threshold", + type=float, + default=0.05, + help="Search stopping criterion in range [0, 1], lower is more precise. Interpret as the uncerainty gap, i.e. (min_oom_batch_size - max_ok_batch_size) / min_oom_batch_size.", +) +@click.option("-s", "--start-batch-size", type=int, default=32, help="Initial batch size to start the search from.") +@click.option( + "-r", + "--ratio", + type=int, + default=12, # conservative estimate towards longer transcripts + help="The output_sequence_length to input_sequence_length ratio for the purpose of determing the maximum output sequence lengths. " + "The interpretation depends on input and output modalities. Examples: for audio->text it's tokens per second. " + "For text->audio it's seconds per token. For audio->audio it's output seconds per input second. " + "For text->text it's output tokens per input token. " + "In general larger ratio means longer output sequences and increased memory consumption. " + "The default value is set adequately for automatic speech recognition. " + "This argument is ignored when 2D buckets are provided to --buckets option.", +) +@click.option( + "-f", + "--memory-fraction", + type=float, + default=0.9, + help="Limits the use of CUDA memory for this process to MEMORY_FRACTION of the total device memory. " + "By default we force 5% memory to be unused to account for non-training-loop related CUDA memory usage" + "in actual training scripts.", +) +@click.option( + "-d", + "--device", + default="cuda:0", + help="Device string to be passed to torch.device; due to MEMORY_FRACTION option, " + "it must specify the device index (e.g. cuda:0). " + "You can also leave the default index and select a specific GPU using env var CUDA_VISIBLE_DEVICES=", +) +@click.option( + "-y", + "--dtype", + default="bfloat16", + help="Float precision to use for computation (used together with autocast).", +) +@click.option( + "--ddp/--no-ddp", + type=bool, + default=True, + help="Whether we should simulate DDP GPU RAM usage. Stores an extra copy of the model in GPU memory. Enabled by default.", +) +def oomptimizer( + pretrained_name: str | None, + module_name: str | None, + config_path: str | None, + optimizer_name: str, + buckets: list[float], + threshold: float, + start_batch_size: int, + ratio: int, + memory_fraction: float, + device: str, + dtype: str, + ddp: bool, +): + """ + OOMptimizer finds the optimal batch sizes for training your model with bucketing dataloading. + It performs a search over batch sizes until it converges by measuring the GPU memory usage for + a model's training step and optimizer update. + + \b + There are two main usage patterns: for using a pretrained model or an untrained model configuration. + The latter is more flexible but requires the user to provide two separate arguments. Examples: + * python oomptimizer.py --pretrained-name nvidia/canary-1b + * python oomptimizer.py --module-name nemo.collections.asr.models.EncDecMultiTaskModel \ + --config-path examples/asr/conf/speech_multitask/fast-conformer_aed.yaml + + Dynamic bucketing is notoriously difficult to tune as you risk running into CUDA OOM many steps into the training. + In order to simplify finding the optimal settings, OOMptimizer scans each bucket to find the maximum possible + batch size that doesn't trigger a CUDA OOM. + + \b + The suggested workflow is the following: + 1) Run scripts/speech_recognition/estimate_duration_bins.py to get the duration distribution of your data. + (consider running estimate_duration_bins_2d.py for models with a strong dependency on output sequence length + such as attention-encoder-decoder models). + 2) Run OOMptimizer to find the optimal batch sizes for your specific model, optimizer, and GPU. + 3) Use these optimal settings in your actual training script and enjoy optimal GPU utilization OOM-free. + + In the unlikely event that OOMptimizer bucket batch sizes are still leading to OOMs, + please try a lower setting of the MEMORY_FRACTION option, e.g. 0.75 (75% of GPU memory). + This may be required in very complex setups where there are additional GPU RAM loads that can't be anticipated + through the combination of training_step and optimizer update. + """ + if all(opt is None for opt in (pretrained_name, module_name, config_path)): + click.secho( + "You need to provide either PRETRAINED_NAME or the pair of MODULE_NAME and CONFIG_PATH.", fg="yellow" + ) + sys.exit(1) + logging.setLevel(logging.CRITICAL) + torch.cuda.set_per_process_memory_fraction(memory_fraction, device) + + model_clones = [] + for _ in range(2 if ddp else 1): + if pretrained_name is not None: + assert ( + config_path is None and module_name is None + ), "--pretrained-name cannot be used together with --module-name/--config-path" + click.echo(f"Intializing ASR model from pretrained checkpoint {pretrained_name}.") + trainer = pl.Trainer(barebones=True) + trainer.log_every_n_steps = 1000000 + model = ASRModel.from_pretrained(pretrained_name, trainer=trainer).to(device) + else: + assert config_path is not None, "--module-name requires --config-path to be specified as well." + assert module_name is not None, "--config-path requires --module-name to be specified as well." + cfg = OmegaConf.load(config_path) + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + trainer.log_every_n_steps = 1000000 + namespace, name = module_name.rsplit('.', maxsplit=1) + model_cls = getattr(importlib.import_module(namespace), name) + model = model_cls.restore_from_pretrained_models(cfg, trainer=trainer).to(device) + model.log = lambda *args, **kwargs: None + model_clones.append(model) + model = model_clones[-1] + # model.setup(stage="fit") + model.init_consumed_samples = 0 + model._compute_consumed_samples_after_training_step = lambda *args, **kwargs: 1 + + if not hasattr(model, "oomptimizer_schema"): + click.secho( + f"We read model of type {type(model)} which doesn't seem to support OOMptimizer " + f"(we could not find the property .oomptimizer_schema).", + fg="red", + ) + sys.exit(1) + + schema = model.oomptimizer_schema + + click.echo("Setting up the optimizers.") + optimizer, _ = model.setup_optimization({"name": optimizer_name, "lr": 1e-7, "weight_decay": 0.0}) + + is_2d_bucketing = all( + isinstance(item, (list, tuple)) and len(item) == 2 and all(isinstance(v, Number) for v in item) + for item in buckets + ) + # Determine modality for input and output. + modalities = [ + ( + "text" + if any( + isinstance(item["type"].elements_type, LabelsType) and item["seq_length"] == direction + for item in schema["inputs"] + if not isinstance(item["type"], str) + ) + else "audio" + ) + for direction in ("input", "output") + ] + + def get_max_seq_lens(buckets): + + def _determine_lens_for_bucket(bin): + if is_2d_bucketing: + input_len, output_len = bin + else: + input_len = bin + output_len = math.ceil(ratio * input_len) + sampling_rate = getattr( + model, "sample_rate", 16000 + ) # TODO: may need to extend schema for broader model coverage + match modalities: + case "audio", "audio": + return ( + compute_num_samples(input_len, sampling_rate=sampling_rate), + compute_num_samples(output_len, sampling_rate=sampling_rate), + ) + case "audio", "text": + return (compute_num_samples(input_len, sampling_rate=sampling_rate), output_len) + case "text", "audio": + return ( + input_len, + compute_num_samples(output_len, sampling_rate=sampling_rate), + ) + case "text", "text": + return input_len, output_len + case _: + raise RuntimeError(f"Unexpected modality combination: {_}") + + return [_determine_lens_for_bucket(bin) for bin in buckets] + + click.echo("Starting profiling.") + max_seq_lens = get_max_seq_lens(buckets) + gen = ProfilingBatchGenerator(schema=schema, start_batch_size=start_batch_size, rel_gap_thresh=threshold) + profile = {} + + # Iterate buckets from the largest to the smallest sequences. This usually ends up creating + # a tiny bit smaller batches, likely due to worse memory fragmentation. + with torch.autocast("cuda", getattr(torch, dtype)): + for bucket, (seq_len_in, seq_len_out) in reversed(list(zip(buckets, max_seq_lens))): + click.echo(f"The current sequence lengths are: input={seq_len_in} output={seq_len_out}.") + gen.reset() + batch_idx = 0 + + def step(): + click.echo( + f"\t[BEGIN step] [CUDA RAM CURRENT: {torch.cuda.memory_allocated() / (1024 * 1024):.1f}MB] [CUDA RAM MAX: {torch.cuda.max_memory_allocated() / (1024*1024):.1f}MB]" + ) + batch = gen(seq_len_in, seq_len_out) + oom = False + try: + click.echo( + f"\tCurrent settings | batch_size={gen._current} | gap: {gen.current_rel_gap}... ", nl=False + ) + optimizer.zero_grad() + out = model.training_step(iter([batch])) + # out['loss'].sum().backward() + optimizer.step() + except torch.cuda.OutOfMemoryError as e: + click.secho(f"OOM!", fg="yellow") + oom = True + except RuntimeError as e: + if "cuFFT error: CUFFT_INTERNAL_ERROR" not in str(e): + raise + click.secho(f"OOM!", fg="yellow") + oom = True + else: + click.secho(f"OK!", fg="green") + finally: + click.echo( + f"\t[END step] [CUDA RAM CURRENT: {torch.cuda.memory_allocated() / (1024 * 1024):.1f}MB] [CUDA RAM MAX: {torch.cuda.max_memory_allocated() / (1024*1024):.1f}MB]" + ) + del batch + # Note: We could call empty_cache() to free up some more memory on the GPU, + # but we have found out empirically that this causes a mismatched condition + # between OOMptimizer and the actual training. During training, there is some + # degree of memory fragmentation and it's better to simulate that in OOMptimizer. + torch.cuda.memory.empty_cache() + torch.cuda.reset_max_memory_allocated() + return oom + + oom = step() + while not (finished := gen.advance(oom)): + click.echo("\t" + "=" * 80) + oom = step() + + click.secho( + f"=> Optimal setting for bucket={bucket} (input={seq_len_in} output={seq_len_out}) is max_batch_size={gen.max_batch_size}", + fg="green", + ) + profile[(bucket, seq_len_in, seq_len_out)] = gen.max_batch_size + gen.start_batch_size = gen.max_batch_size * 2 + + # Reverse the profile to be ascendingly sorted again. + profile = dict(reversed(list(profile.items()))) + + click.echo("The 1st stage profile is:") + for (bucket, seq_len_in, seq_len_out), bs in profile.items(): + click.echo(f"Bucket={bucket} (input={seq_len_in} output={seq_len_out}) => max_batch_size={bs}") + + if is_2d_bucketing: + # 2D bucketing doesn't support bucket merging. + final_profile = [["[" + ",".join(map(str, b)) + "]", bs] for (b, _, __), bs in profile.items()] + max_input_len, max_output_len = buckets[-1] + ratio = max_output_len / max_input_len + else: + click.echo("Bucket merging stage...") + final_profile = [] + for idx, ((bucket, seq_len_in, seq_len_out), bs) in enumerate(profile.items()): + if idx == 0: + final_profile.append([bucket, bs]) + continue + if bs == final_profile[-1][1]: + click.echo(f"Merging bucket {idx} with bucket {idx-1} due to identical batch sizes.") + final_profile[-1][0] = bucket + continue + final_profile.append([bucket, bs]) + max_input_len = final_profile[-1][0] + + click.secho(f"The profile was created with the following settings:") + click.secho(f"* using {memory_fraction:.1%} of available GPU RAM.") + click.secho(f"* {'' if ddp else 'not '}simulating DDP memory overhead.") + click.secho(f"* using AMP with dtype={dtype}.") + click.secho("The final profile is:", bold=True) + click.secho("\tbucket_duration_bins=[" + ",".join(str(seqlen) for seqlen, bs in final_profile) + "]", bold=True) + click.secho("\tbucket_batch_size=[" + ",".join(str(bs) for seqlen, bs in final_profile) + "]", bold=True) + click.secho("\t(The following flags are suitable for ASR/speech-to-text models):") + click.secho(f"\tmax_tps={ratio}", bold=True) + click.secho(f"\tmax_duration={max_input_len}", bold=True) + + +if __name__ == "__main__": + oomptimizer() From 97a543e2c9d7736b47cdb296b5154752b6b181cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 12:56:09 -0400 Subject: [PATCH 03/63] Initial version of estimate token bins script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_recognition/estimate_token_bins.py | 306 ++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 scripts/speech_recognition/estimate_token_bins.py diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py new file mode 100644 index 000000000000..6d01cf699843 --- /dev/null +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -0,0 +1,306 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import ast +import math +from functools import partial +from itertools import islice +from pathlib import Path +from typing import Callable, Iterable + +import numpy as np +import pandas as pd +from lhotse.cut import Cut +from omegaconf import OmegaConf + +from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config +from nemo.collections.common.data.lhotse.dataloader import ( + DurationFilter, + FixedBucketBatchSizeConstraint2D, + LhotseDataLoadingConfig, + MultimodalFixedBucketBatchSizeConstraint2D, + TokenPerSecondFilter, + tokenize, +) +from nemo.collections.common.prompts.formatter import PromptFormatter +from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Estimate token bins for Lhotse dynamic bucketing using a sample of the input dataset. " + "The dataset is read either from one or more manifest files and supports data weighting. " + "Unlike estimate_duration_bins.py, this script is intended for text data only. " + "It supports 2D bucketing. ", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "input", + help='Path to a data input configuration YAML file. ' + 'This is the only type of input specification supported for text data.', + ) + parser.add_argument( + "-t", + "--tokenizer", + nargs="+", + required=True, + help="Path to one or more SPE tokenizers. More than one means we'll use AggregateTokenizer and --langs argument must also be used. When provided, we'll estimate a 2D distribution for input and output sequence lengths.", + ) + parser.add_argument( + "-a", "--langs", nargs="+", help="Language names for each of AggregateTokenizer sub-tokenizers." + ) + parser.add_argument( + "-b", + "--buckets", + type=int, + default=30, + help="The desired number of buckets (dim0 => covers input sequence length / audio duration).", + ) + parser.add_argument( + "-s", + "--sub-buckets", + type=int, + default=2, + help="The desired number of sub-buckets (dim1 => covers output sequence length / num_tokens).", + ) + parser.add_argument( + "-n", + "--num_examples", + type=int, + default=-1, + help="The number of examples (utterances) to estimate the bins. -1 means use all data " + "(be careful: it could be iterated over infinitely).", + ) + # parser.add_argument( + # "-l", + # "--min_tokens", + # type=float, + # default=-float("inf"), + # help="If specified, we'll filter out examples with less tokens than this number.", + # ) + # parser.add_argument( + # "-u", + # "--max_tokens", + # type=float, + # default=float("inf"), + # help="If specified, we'll filter out examples with more tokens than this number.", + # ) + # parser.add_argument( + # "--max_tpt", + # type=float, + # default=float("inf"), + # help="If specified, we'll filter out examples with more output tokens per input token than this. " + # ) + parser.add_argument( + "-q", "--quiet", type=bool, default=False, help="When specified, only print the estimated duration bins." + ) + parser.add_argument( + "-f", + "--prompt-format", + type=str, + help="When specified, we'll use a prompt formatter in addition to the tokenizer for the purpose of estimating token count bins. " + "This is useful for accurate 2D bucket estimation with models such as EncDecMultiTaskModel (Canary-1B), " + "or any model where the label sequence consists of a user prompt and a model's response.", + ) + parser.add_argument( + "-p", + "--prompt", + type=str, + help="Prompt slots provided as a Python list of dicts. It is used together with --prompt-format option." + "For example, with Canary-1B you may use: [{'role':'user','slots':{'source_lang':'en','target_lang':'en','task':'asr','pnc':'yes'}]", + ) + return parser.parse_args() + + +def estimate_token_buckets( + cuts: Iterable[Cut], + num_buckets: int, + num_subbuckets: int, + quiet: bool, +) -> list[tuple[float, float]]: + """ + This function is based on lhotse.dataset.sampling.dynamic_bucketing.estimate_duration_buckets. + It extends it to a 2D bucketing case. + """ + assert num_buckets > 1 + + constraint = MultimodalFixedBucketBatchSizeConstraint2D([(0.0, 0.0)], [0]) + + # Gather the duration and token count statistics for the dataset. + num_input_tokens = [] + num_output_tokens = [] + for c in cuts: + itoks, otoks = constraint.measure_length(c) + num_input_tokens.append(itoks) + num_output_tokens.append(otoks) + num_input_tokens = np.array(num_input_tokens, dtype=np.int32) + num_output_tokens = np.array(num_output_tokens, dtype=np.int32) + joint = np.rec.fromarrays([num_input_tokens, num_output_tokens]) + joint.sort() + num_input_tokens = joint.f0 + num_output_tokens = joint.f1 + + # We are building buckets with equal duration (empirically leads to more even bucket exhaustion over time). + # We need to determine how much duration to allocate per bucket. + size_per_bucket = num_input_tokens.sum() / num_buckets + + if not quiet: + print("Duration distribution:") + print(pd.Series(num_input_tokens).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) + max_input_tokens = num_input_tokens[-1] + + tpt = num_output_tokens / num_input_tokens + if not quiet: + print("Output tokens per input token distribution:") + print(pd.Series(tpt).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) + max_tpt = tpt.max() + del tpt + + bins = [] + bin_indexes = [0] + tot = 0.0 + + def _estimate_output_token_buckets(max_bucket_duration): + # Since this is 2D bucketing, apply the same bin creation logic + # for the second dimension (i.e. token count) as for the first dimension (duration). + # That means we aim to have each bucket contain roughly the same number of tokens. + # Note that this estimation is biased towards more padding if you have + # a lot of zero-token examples (e.g. non-speech). + nonlocal bins + num_tokens_bucket = num_output_tokens[bin_indexes[-1] : binidx] + num_tokens_bucket.sort() + tokens_per_subbucket = num_tokens_bucket.sum() / num_subbuckets + tot_toks = 0 + # Iterate over token counts, and whenever we hit tokens_per_subbucket, create a new 2D bucket bin. + for num_toks in num_tokens_bucket: + # Threshold hit: we are creating a new (max_duration, max_num_tokens) bin. + if tot_toks > tokens_per_subbucket: + bins.append((max_bucket_duration, num_toks)) + tot_toks = 0 + tot_toks += num_toks + bins.append((size, math.ceil(size * max_tpt))) + + # Iterate over data, and whenever we hit size_per_bucket, create a new bucket bin. + for binidx, size in enumerate(num_input_tokens): + if tot > size_per_bucket: + # Threshold hit: we are creating a new duration bin (multiplied by number of token bins). + _estimate_output_token_buckets(max_bucket_duration=size) + tot = 0.0 + tot += size + + # Estimate an extra 2D bin set for global max duration. + _estimate_output_token_buckets(max_bucket_duration=max_input_tokens) + + return bins + + +def load_tokenizer(paths: list[str], langs: list[str] = None) -> TokenizerWrapper: + if len(paths) == 1: + tok = SentencePieceTokenizer(paths[0]) + else: + assert langs is not None and len(paths) == len( + langs + ), f"Cannot create AggregateTokenizer; each tokenizer must have assigned a language via --langs option (we got --tokenizers={paths} and --langs={langs})" + tok = AggregateTokenizer({lang: SentencePieceTokenizer(p) for lang, p in zip(langs, paths)}) + return TokenizerWrapper(tok) + + +def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None): + if prompt is not None: + turns = prompt.get_default_dialog_slots() + last_turn = {"role": prompt.OUTPUT_ROLE, "slots": prompt.get_slots(prompt.OUTPUT_ROLE)} + assert len(last_turn["slots"]) == 1 # TODO: not sure how to handle multi-slot for system output here + for key in last_turn["slots"]: + last_turn["slots"][key] = cut.supervisions[0].text + last_turn["slots"][prompt.PROMPT_LANGUAGE_SLOT] = cut.supervisions[0].language + turns.append(last_turn) + ans = prompt.encode_dialog(turns) + cut.supervisions[0].tokens = ans["input_ids"] + + elif tokenizer is not None: + cut = tokenize(cut, tokenizer) + + return cut + + +class RejectionsCounter: + def __init__(self, predicate: Callable, message: str): + self.predicate = predicate + self.message = message + self.total = 0 + self.rejected = 0 + + def __call__(self, example) -> bool: + ans = self.predicate(example) + self.total += 1 + if not ans: + self.rejected += 1 + return ans + + def print_report(self) -> None: + if self.rejected: + print(f"{self.message} | Rejected {self.rejected}/{self.total} examples.") + + +def main(): + args = parse_args() + + if not args.quiet: + pd.set_option('display.float_format', lambda x: '%.2f' % x) + + tokenizer = None + prompt = None + if args.tokenizer is not None: + tokenizer = load_tokenizer(args.tokenizer, args.langs) + if args.prompt_format is not None: + prompt_defaults = None + if args.prompt is not None: + prompt_defaults = ast.literal_eval(args.prompt) + prompt = PromptFormatter.resolve(args.prompt_format)(tokenizer._tokenizer, defaults=prompt_defaults) + + assert args.input.endswith(".yaml") + config = OmegaConf.merge( + OmegaConf.structured(LhotseDataLoadingConfig), + OmegaConf.from_dotlist([f"input_cfg={args.input}"]), + ) + cuts, _ = read_cutset_from_config(config) + # duration_filter = RejectionsCounter(DurationFilter(args.min_duration, args.max_duration), "Duration filtering") + # cuts = cuts.filter(duration_filter) + cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt)) + # tpt_filter = RejectionsCounter(TokensPerTokenFilter(-1, args.max_tpt), "Output tokens per input token filtering") + # cuts = cuts.filter(tpt_filter) + if (N := args.num_examples) > 0: + cuts = islice(cuts, N) + + token_bins = estimate_token_buckets( + cuts, + num_buckets=args.buckets, + num_subbuckets=args.sub_buckets, + quiet=args.quiet, + ) + token_bins = "[" + ','.join(f"[{b:d},{sb:d}]" for b, sb in token_bins) + "]" + if args.quiet: + print(token_bins) + return + # duration_filter.print_report() + # tps_filter.print_report() + print("Use the following options in your config:") + print(f"\tnum_buckets={args.buckets}") + print(f"\tbucket_duration_bins={token_bins}") + + +if __name__ == "__main__": + main() From c6f0b3d4f394d405642e917d9a166330b10d8286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 13:27:43 -0400 Subject: [PATCH 04/63] Initial support for multimodal 2d bucketing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 7c42767fd7b3..b1618262fccd 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -595,7 +595,7 @@ def make_structured_with_schema_warnings(config: DictConfig) -> DictConfig: @dataclass class MultimodalSamplingConstraint(SamplingConstraint): # how many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch - token_equivalent_duration: float + token_equivalent_duration: float | None = None # defines maximum batch size (may be lower than that if batch_length is also specified) batch_size: int | None = None @@ -693,15 +693,25 @@ def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = No class MultimodalFixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint2D): token_equivalent_duration: float | None = None - def measure_length(self, example: Any) -> float: - assert not self.bucketing_2d_enabled, "2D bucketing for multimodal sampling is not yet supported." - if hasattr(example, "num_tokens"): - return example.num_tokens + def measure_length(self, example: Any) -> float | tuple[float, float]: + # Case 1: audio if isinstance(example, Cut): assert ( self.token_equivalent_duration is not None ), "Cannot use MultimodalFixedBucketBatchSizeConstraint with speech data when token_equivalent_duration was not specified." - return example.duration / self.token_equivalent_duration + in_tokens = example.duration / self.token_equivalent_duration + if self.bucketing_2d_enabled: + return in_tokens, _measure_tokens(example) + else: + return in_tokens + # Case 2: text + if self.bucketing_2d_enabled: + if hasattr(example, "context_ids") and hasattr(example, "answer_ids"): + return len(example.context_ids), len(example.answer_ids) + else: + if hasattr(example, "num_tokens"): + return example.num_tokens + raise RuntimeError(f"Unsupported example type: {type(example)}") From 7b52d5b74ab2abf8001ad16477acfe8f94db3150 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 14:00:58 -0400 Subject: [PATCH 05/63] Extend to text-to-text oomptimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/models/modular_models.py | 101 ++++++++++-------- .../oomptimizer-speechllm.py | 10 +- 2 files changed, 68 insertions(+), 43 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 3b522a3695a6..ba22a8c7d731 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -1706,53 +1706,70 @@ def find_frozen_submodules(model): self.perception = self.trainer.strategy._setup_model(self.perception) self.perception = self.perception.cuda(torch.cuda.current_device()) - @property - def oomptimizer_schema(self) -> dict: + def oomptimizer_schema(self, schema: str = "audio") -> dict: """ Return a typing schema for optimal batch size calibration for various sequence lengths using OOMptimizer. """ - # TODO: add support for text - # input_ids = text_batch["text_input_ids"][:, :-1] - # labels = text_batch["text_input_ids"][:, 1:] - # attention_mask = self._create_attention_mask(input_ids) - # loss_mask = text_batch["text_masks"][:, 1:] - - return { - "cls": dict, - "inputs": [ - {"name": "audio_signal", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, - {"name": "audio_signal_length", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, - { - "name": "tokens", - "type": NeuralType(("B", "T"), LabelsType()), - "seq_length": "output", - "vocab_size": self.tokenizer.vocab_size, - }, - { - "name": "tokens_length", - "type": NeuralType(("B",), LengthsType()), - "seq_length": "output", - }, - { - "name": "labels", - "type": NeuralType(("B", "T"), LabelsType()), - "seq_length": "output", - "vocab_size": self.tokenizer.vocab_size, - }, - { - "name": "loss_mask", - "type": NeuralType(("B", "T"), MaskType()), - "seq_length": "output", - }, - { - "name": "context_start_idx", - "type": "constant", - "value": 0, - }, - ], - } + if schema == "audio": + return { + "cls": dict, + "inputs": [ + {"name": "audio_signal", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + {"name": "audio_signal_length", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, + { + "name": "tokens", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "tokens_length", + "type": NeuralType(("B",), LengthsType()), + "seq_length": "output", + }, + { + "name": "labels", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "loss_mask", + "type": NeuralType(("B", "T"), MaskType()), + "seq_length": "output", + }, + { + "name": "context_start_idx", + "type": "constant", + "value": 0, + }, + ], + } + elif schema == "text": + # TODO: add support for text + # input_ids = text_batch["text_input_ids"][:, :-1] + # labels = text_batch["text_input_ids"][:, 1:] + # attention_mask = self._create_attention_mask(input_ids) + # loss_mask = text_batch["text_masks"][:, 1:] + + return { + "cls": dict, + "inputs": [ + { + "name": "text_input_ids", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "input", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "text_masks", + "type": NeuralType(("B", "T"), MaskType()), + "seq_length": "input", + }, + ], + } class CrossAttendModularAudioGPTModel(ModularAudioGPTModel): diff --git a/scripts/speech_recognition/oomptimizer-speechllm.py b/scripts/speech_recognition/oomptimizer-speechllm.py index 16e5ca0816a4..d3f4b5a53d1d 100755 --- a/scripts/speech_recognition/oomptimizer-speechllm.py +++ b/scripts/speech_recognition/oomptimizer-speechllm.py @@ -248,6 +248,13 @@ def type_cast_value(self, ctx, value): "-c", "--config-path", type=str, default=None, help="Path to the training configuration file for MODULE_NAME." ) @click.option("-o", "--optimizer-name", type=str, default="adamw", help="Name of optimizer to use.") +@click.option( + "-s", + "--schema", + type=str, + default="audio", + help="Which schema to use (typically used for choosing the modality, i.e., 'audio' / 'text'", +) @click.option( "-b", "--buckets", @@ -312,6 +319,7 @@ def oomptimizer( module_name: str | None, config_path: str | None, optimizer_name: str, + schema: str, buckets: list[float], threshold: float, start_batch_size: int, @@ -392,7 +400,7 @@ def oomptimizer( ) sys.exit(1) - schema = model.oomptimizer_schema + schema = model.oomptimizer_schema(schema) click.echo("Setting up the optimizers.") optimizer, _ = model.setup_optimization({"name": optimizer_name, "lr": 1e-7, "weight_decay": 0.0}) From f63a110895b4792054a5d0395e455a685ac2e7ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 14:32:04 -0400 Subject: [PATCH 06/63] Preliminary support for Llama2 prompt format in ast+mt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 5 ++++- .../common/data/lhotse/text_adapters.py | 18 +++++++++++++++++- .../speech_recognition/estimate_token_bins.py | 15 ++++++++++----- .../oomptimizer-speechllm.py | 1 - 4 files changed, 31 insertions(+), 8 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index b1618262fccd..376363d3be5b 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -42,6 +42,7 @@ from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset, read_cutset_from_config from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample, TextExample +from nemo.collections.common.prompts import PromptFormatter from nemo.collections.common.prompts.fn import get_prompt_format_fn from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper from nemo.utils import logging @@ -747,7 +748,9 @@ def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str) -> Exa example.tokenized_prompt = tokenized_prompt example.tokenized_transcript = tokenized_transcript else: - raise RuntimeError(f"Currently we only support tokenization + prompting during sampling for audio modality.") + # TODO: need an equivalent of get_prompt_format_fn for text modality + # to be able to construct different kinds of turns specific to a given application + example = example.tokenize(tokenizer, prompt=PromptFormatter.resolve(prompt_format)(tokenizer)) return example diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 3d1138d427f2..894e485d3111 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -25,6 +25,7 @@ from lhotse.utils import Pathlike from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths +from nemo.collections.common.prompts import PromptFormatter from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer, TokenizerWrapper from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging @@ -107,7 +108,22 @@ def num_tokens(self) -> Optional[int]: return self.input_ids.shape[0] return None - def tokenize(self, tokenizer: TokenizerWrapper) -> "TextExample": + def tokenize(self, tokenizer: TokenizerWrapper, prompt: PromptFormatter = None) -> "TextExample": + + if prompt is not None: + # TODO(pzelasko): this is temporarily hardcoded and assumes LLama2 prompt format + ans = prompt.encode_dialog( + [ + {"role": "system_and_user", "slots": {"system": self.question.text, "message": self.source.text}}, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": self.target.text}}, + ] + ) + self.input_ids = ans["input_ids"] + self.context_ids = ans["context_ids"] + self.answer_ids = ans["answer_ids"] + self.mask = ans["mask"] + return self + input_ids = [] context_ids = [] if self.question: diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index 6d01cf699843..7c06f4061279 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -73,8 +73,9 @@ def parse_args(): "-s", "--sub-buckets", type=int, - default=2, - help="The desired number of sub-buckets (dim1 => covers output sequence length / num_tokens).", + default=None, + help="The desired number of sub-buckets (dim1 => covers output sequence length / num_tokens). " + "If not provided, we'll only perform 1D bucketing. ", ) parser.add_argument( "-n", @@ -128,7 +129,7 @@ def parse_args(): def estimate_token_buckets( cuts: Iterable[Cut], num_buckets: int, - num_subbuckets: int, + num_subbuckets: int | None, quiet: bool, ) -> list[tuple[float, float]]: """ @@ -197,12 +198,16 @@ def _estimate_output_token_buckets(max_bucket_duration): for binidx, size in enumerate(num_input_tokens): if tot > size_per_bucket: # Threshold hit: we are creating a new duration bin (multiplied by number of token bins). - _estimate_output_token_buckets(max_bucket_duration=size) + if num_subbuckets is not None: # 2D bucketing + _estimate_output_token_buckets(max_bucket_duration=size) + else: # 1D bucketing + bins.append(size) tot = 0.0 tot += size # Estimate an extra 2D bin set for global max duration. - _estimate_output_token_buckets(max_bucket_duration=max_input_tokens) + if num_subbuckets is not None: + _estimate_output_token_buckets(max_bucket_duration=max_input_tokens) return bins diff --git a/scripts/speech_recognition/oomptimizer-speechllm.py b/scripts/speech_recognition/oomptimizer-speechllm.py index d3f4b5a53d1d..97efa2806bd9 100755 --- a/scripts/speech_recognition/oomptimizer-speechllm.py +++ b/scripts/speech_recognition/oomptimizer-speechllm.py @@ -249,7 +249,6 @@ def type_cast_value(self, ctx, value): ) @click.option("-o", "--optimizer-name", type=str, default="adamw", help="Name of optimizer to use.") @click.option( - "-s", "--schema", type=str, default="audio", From ef96459f2396f551f1433200a039872cf4f60c22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 14:42:08 -0400 Subject: [PATCH 07/63] Support for 1D estimate token bins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_recognition/estimate_token_bins.py | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index 7c06f4061279..28ab36ccbff5 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -32,6 +32,7 @@ FixedBucketBatchSizeConstraint2D, LhotseDataLoadingConfig, MultimodalFixedBucketBatchSizeConstraint2D, + MultimodalSamplingConstraint, TokenPerSecondFilter, tokenize, ) @@ -137,22 +138,33 @@ def estimate_token_buckets( It extends it to a 2D bucketing case. """ assert num_buckets > 1 + is_2d = num_subbuckets is not None - constraint = MultimodalFixedBucketBatchSizeConstraint2D([(0.0, 0.0)], [0]) + if is_2d: + constraint = MultimodalFixedBucketBatchSizeConstraint2D([(0.0, 0.0)], [0]) + else: + constraint = MultimodalSamplingConstraint() # Gather the duration and token count statistics for the dataset. num_input_tokens = [] num_output_tokens = [] for c in cuts: - itoks, otoks = constraint.measure_length(c) - num_input_tokens.append(itoks) - num_output_tokens.append(otoks) + ans = constraint.measure_length(c) + if is_2d: + itoks, otoks = ans + num_input_tokens.append(itoks) + num_output_tokens.append(otoks) + else: + num_input_tokens.append(ans) num_input_tokens = np.array(num_input_tokens, dtype=np.int32) - num_output_tokens = np.array(num_output_tokens, dtype=np.int32) - joint = np.rec.fromarrays([num_input_tokens, num_output_tokens]) - joint.sort() - num_input_tokens = joint.f0 - num_output_tokens = joint.f1 + if is_2d: + num_output_tokens = np.array(num_output_tokens, dtype=np.int32) + joint = np.rec.fromarrays([num_input_tokens, num_output_tokens]) + joint.sort() + num_input_tokens = joint.f0 + num_output_tokens = joint.f1 + else: + num_input_tokens.sort() # We are building buckets with equal duration (empirically leads to more even bucket exhaustion over time). # We need to determine how much duration to allocate per bucket. @@ -163,12 +175,13 @@ def estimate_token_buckets( print(pd.Series(num_input_tokens).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) max_input_tokens = num_input_tokens[-1] - tpt = num_output_tokens / num_input_tokens - if not quiet: - print("Output tokens per input token distribution:") - print(pd.Series(tpt).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) - max_tpt = tpt.max() - del tpt + if is_2d: + tpt = num_output_tokens / num_input_tokens + if not quiet: + print("Output tokens per input token distribution:") + print(pd.Series(tpt).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) + max_tpt = tpt.max() + del tpt bins = [] bin_indexes = [0] @@ -198,9 +211,9 @@ def _estimate_output_token_buckets(max_bucket_duration): for binidx, size in enumerate(num_input_tokens): if tot > size_per_bucket: # Threshold hit: we are creating a new duration bin (multiplied by number of token bins). - if num_subbuckets is not None: # 2D bucketing + if is_2d: _estimate_output_token_buckets(max_bucket_duration=size) - else: # 1D bucketing + else: bins.append(size) tot = 0.0 tot += size From 2cae09bb0a0902d2ceaa1b2d7356204e2d0fc913 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 14:48:16 -0400 Subject: [PATCH 08/63] Support for 1D estimate token bins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/dataloader.py | 6 ++++-- scripts/speech_recognition/estimate_token_bins.py | 13 ++----------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 376363d3be5b..c5774d65703d 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -735,7 +735,7 @@ def tokenize(example: Example, tokenizer) -> Example: return example -def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str) -> Example: +def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str | PromptFormatter) -> Example: # TODO(pzelasko): This mechanism makes it possible to measure the actual output sequence length # for prompted models such as AED MultiTask (Canary), which includes the transcript and the prompt. # We intend to extend it for text modality in follow-up work. @@ -750,7 +750,9 @@ def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str) -> Exa else: # TODO: need an equivalent of get_prompt_format_fn for text modality # to be able to construct different kinds of turns specific to a given application - example = example.tokenize(tokenizer, prompt=PromptFormatter.resolve(prompt_format)(tokenizer)) + if isinstance(prompt_format, str): + prompt_format = PromptFormatter.resolve(prompt_format)(tokenizer) + example = example.tokenize(tokenizer, prompt=prompt_format) return example diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index 28ab36ccbff5..3fff098fbde1 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -35,6 +35,7 @@ MultimodalSamplingConstraint, TokenPerSecondFilter, tokenize, + tokenize_with_prompt, ) from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer @@ -238,19 +239,9 @@ def load_tokenizer(paths: list[str], langs: list[str] = None) -> TokenizerWrappe def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None): if prompt is not None: - turns = prompt.get_default_dialog_slots() - last_turn = {"role": prompt.OUTPUT_ROLE, "slots": prompt.get_slots(prompt.OUTPUT_ROLE)} - assert len(last_turn["slots"]) == 1 # TODO: not sure how to handle multi-slot for system output here - for key in last_turn["slots"]: - last_turn["slots"][key] = cut.supervisions[0].text - last_turn["slots"][prompt.PROMPT_LANGUAGE_SLOT] = cut.supervisions[0].language - turns.append(last_turn) - ans = prompt.encode_dialog(turns) - cut.supervisions[0].tokens = ans["input_ids"] - + cut = tokenize_with_prompt(cut, tokenizer, prompt) elif tokenizer is not None: cut = tokenize(cut, tokenizer) - return cut From bdec6181fed3bbeb4659765d0f8929c3c8170997 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 14:52:43 -0400 Subject: [PATCH 09/63] Fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- scripts/speech_recognition/estimate_token_bins.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index 3fff098fbde1..ce0949a4403b 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -288,7 +288,9 @@ def main(): cuts, _ = read_cutset_from_config(config) # duration_filter = RejectionsCounter(DurationFilter(args.min_duration, args.max_duration), "Duration filtering") # cuts = cuts.filter(duration_filter) - cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt)) + cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt), apply_fn=None) + if hasattr(cuts, "prefetch"): + cuts = cuts.prefetch() # to be released in lhotse 1.27 # tpt_filter = RejectionsCounter(TokensPerTokenFilter(-1, args.max_tpt), "Output tokens per input token filtering") # cuts = cuts.filter(tpt_filter) if (N := args.num_examples) > 0: From a7ce8b6a8478d42e7ee62c3a3daa22903a4edb69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 19 Aug 2024 15:20:12 -0400 Subject: [PATCH 10/63] Fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- scripts/speech_recognition/estimate_token_bins.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index ce0949a4403b..fca14bce4285 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -302,7 +302,10 @@ def main(): num_subbuckets=args.sub_buckets, quiet=args.quiet, ) - token_bins = "[" + ','.join(f"[{b:d},{sb:d}]" for b, sb in token_bins) + "]" + if args.sub_buckets is not None: + token_bins = "[" + ','.join(f"[{b:d},{sb:d}]" for b, sb in token_bins) + "]" + else: + token_bins = "[" + ','.join(f"{b:d}" for b in token_bins) + "]" if args.quiet: print(token_bins) return From b3ed44ca5eaffa5c879e287236589e9779fed9a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 20 Aug 2024 14:13:01 -0400 Subject: [PATCH 11/63] Minor tweaks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../collections/multimodal/speech_llm/models/modular_models.py | 2 ++ scripts/speech_recognition/oomptimizer-speechllm.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index ba22a8c7d731..45d78ace066b 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -1770,6 +1770,8 @@ def oomptimizer_schema(self, schema: str = "audio") -> dict: }, ], } + else: + raise RuntimeError(f"Unknown schema type for oomptimizer of class {type(self)}: '{schema}'") class CrossAttendModularAudioGPTModel(ModularAudioGPTModel): diff --git a/scripts/speech_recognition/oomptimizer-speechllm.py b/scripts/speech_recognition/oomptimizer-speechllm.py index 97efa2806bd9..f98d13281273 100755 --- a/scripts/speech_recognition/oomptimizer-speechllm.py +++ b/scripts/speech_recognition/oomptimizer-speechllm.py @@ -282,7 +282,8 @@ def type_cast_value(self, ctx, value): "For text->text it's output tokens per input token. " "In general larger ratio means longer output sequences and increased memory consumption. " "The default value is set adequately for automatic speech recognition. " - "This argument is ignored when 2D buckets are provided to --buckets option.", + "This argument is ignored when 2D buckets are provided to --buckets option. " + "For GPT-style models, use --ratio=1 ", ) @click.option( "-f", From f7809d61f048247cf12e06c082815079a05012a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 20 Aug 2024 16:47:12 -0400 Subject: [PATCH 12/63] Add min/max tokens filter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index c5774d65703d..f942ecb676a5 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -96,6 +96,8 @@ class LhotseDataLoadingConfig: token_equivalent_duration: float | None = None batch_tokens: int | None = None quadratic_factor: float | None = None + min_tokens: int | None = -1 + max_tokens: int | None = 1_000_000_000 # 3. Supported existing NeMo options. shuffle: bool = False @@ -432,6 +434,7 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # Duration filtering, same as native NeMo dataloaders. # We can filter after the augmentations because they are applied only when calling load_audio(). cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration)) + cuts = cuts.filter(TokenCountFilter(config.min_tokens, config.max_tokens)) bucket_duration_bins = determine_bucket_duration_bins(config) if config.use_multimodal_sampling: @@ -776,6 +779,21 @@ def __call__(self, example) -> bool: return True # does not apply to text etc. +class TokenCountFilter: + """Callable, returns ``True`` if an example's number of tokens is in range [t_min, t_max] and ``False`` otherwise.""" + + def __init__(self, t_min: float, t_max: float) -> None: + self.t_min = t_min + self.t_max = t_max + + def __call__(self, example) -> bool: + if isinstance(example, Cut): + return True # does not apply to Cuts + elif hasattr(example, "num_tokens"): + return self.t_min <= example.num_tokens <= self.t_max + return True # applies only to non-audio with num_tokens property + + class TokenPerSecondFilter: """ Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) From 2b26cb07818951e93588419c9c65f05bb7b0128e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 22 Aug 2024 11:48:35 -0400 Subject: [PATCH 13/63] Change to bisect_left for bucket idx selection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index f942ecb676a5..d2ebd9eadc88 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -671,7 +671,7 @@ def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = No return super().select_bucket(buckets=buckets, example=example, example_len=example_len) if example_len is None: example_len = self.measure_length(example) - bucket_idx = bisect.bisect_right(buckets, example_len) + bucket_idx = bisect.bisect_left(buckets, example_len) # For 2D bucketing we have to refine the initially found bucket_idx, as bisect # looks primarily at the first index of a tuple (i.e. duration). # For example, with buckets [(1, 1), (1, 2), (2, 2), (2, 4)] and example (1.5, 3) From 9589023ea44a67564d8f0dea61f659fd3a79e470 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 22 Aug 2024 14:01:51 -0400 Subject: [PATCH 14/63] Add reconfigure_num_microbatches_calculator at the start of train epoch for modular models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../multimodal/speech_llm/models/modular_models.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 45d78ace066b..5e53920ceec5 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -1182,6 +1182,16 @@ def load_state_dict(self, state_dict, strict: bool = True): else: super(MegatronGPTModel, self).load_state_dict(state_dict, strict=strict) + def on_train_epoch_start(self) -> None: + app_state = AppState() + reconfigure_num_microbatches_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.train_ds.global_batch_size, + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + def on_load_checkpoint(self, checkpoint) -> None: """LightningModule hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-load-checkpoint From 4f6d4fad53bc0e710f91eb0be0f9639c0b2d0765 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 22 Aug 2024 11:03:20 -0700 Subject: [PATCH 15/63] Update lhotse multi-sampler config and make validation datasets finite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- ...ular_audio_gpt_config_cross_llama_lhotse_multi.yaml | 10 ++++++---- nemo/collections/common/data/lhotse/cutset.py | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml index 52149da6a570..436ccf0dca4e 100644 --- a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml @@ -266,8 +266,8 @@ model: micro_batch_size: ${model.micro_batch_size} max_seq_length: 2048 min_seq_length: 1 - context_key: 'input' - label_key: 'output' + context_key: 'context' + answer_key: 'answer' add_eos: True # add_eos: False end_string: ${model.data.end_string} @@ -276,10 +276,11 @@ model: separate_prompt_and_response_with_newline: False truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: "[INST]\n<>\nPlease answer the following based on the previous speech feature.\n<>\n\n{input}[/INST] {output}" + prompt_template: "[INST]\n<>\nPlease answer the following based on the previous speech feature.\n<>\n\n{context}[/INST] {answer}" validation_ds: manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + force_finite: true # workaround to allow using input_cfg global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: False @@ -289,7 +290,7 @@ model: min_seq_length: 1 drop_last: False context_key: ${model.data.train_ds.context_key} - label_key: ${model.data.train_ds.label_key} + answer_key: ${model.data.train_ds.answer_key} add_eos: ${model.data.train_ds.add_eos} end_string: ${model.data.end_string} add_sep: ${model.data.train_ds.add_sep} @@ -312,6 +313,7 @@ model: # test_ds: # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + # force_finite: true # workaround to allow using input_cfg # names: null # Names of the corresponding datasets used to log metrics. # global_batch_size: ${model.global_batch_size} # micro_batch_size: ${model.micro_batch_size} diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index d1f8c5ba03ef..52a02fe0ede5 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -142,6 +142,8 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: # Resolve /path/to/input_cfg.yaml into config contents if needed. input_cfg = OmegaConf.load(input_cfg) cuts, is_tarred = parse_and_combine_datasets(input_cfg, propagate_attrs=propagate_attrs) + if propagate_attrs["force_finite"]: + is_tarred = False # TEMPORARY Disables IterableDatasetWrapper behaviour for finite datasets return cuts, is_tarred From 049bad5898566241d30dfdd3acd00431ffc51e9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 22 Aug 2024 15:05:56 -0400 Subject: [PATCH 16/63] Initial implementation of text+audio training for T5 modular models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/models/modular_t5_models.py | 354 +++++++++++++----- 1 file changed, 255 insertions(+), 99 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index fce31d031abd..07e604936dee 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -49,6 +49,7 @@ from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.core.classes.mixins import adapter_mixins +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import AppState, logging, model_utils try: @@ -364,7 +365,7 @@ def prepare_llm_input(self, audio_batch): def forward( self, - audio_batch, + batch, checkpoint_activations_all_layers, ): """Forward pass of the model. @@ -372,39 +373,64 @@ def forward( We prepend audio embeddings to the instruction and label text tokens as the LLM input. """ - if 'audio_ratio' in audio_batch: - self.log( - 'audio_ratio', audio_batch['audio_ratio'].mean(), prog_bar=True, batch_size=1, rank_zero_only=False + + audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} + text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} + + multimodal_output = {} + + if 'audio_signal' in audio_batch: + encoder_input, attention_mask, enc_mask = self.prepare_llm_input(audio_batch) + # enc_input = speech and text prompt + # dec_input and label = text output label + b = audio_batch['answers'].shape[0] + device = audio_batch['answers'].device + dec_input = ( + audio_batch['masked_answer_ids'] if 'masked_answer_ids' in audio_batch else audio_batch['answers'] ) - self.log( - 'local_batch_size', - audio_batch['audio_ratio'].shape[0], - prog_bar=True, - batch_size=1, - rank_zero_only=False, + dec_input = torch.cat([torch.full([b, 1], self.bos_id, device=device), dec_input[:, :-1]], dim=-1) + labels = audio_batch['answers'] + dec_mask = (dec_input != self.tokenizer.pad_id).long().contiguous() + output = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=encoder_input, + ) + loss_mask = dec_mask + multimodal_output['audio_text'] = (output, loss_mask) + + if text_batch: + b = text_batch['text_answer_ids'].shape[0] + encoder_input_ids = text_batch["text_context_ids"] + enc_mask = (encoder_input_ids != self.tokenizer.pad_id).long().contiguous() + decoder_input_ids = torch.cat( + [ + torch.full([b, 1], self.bos_id, device=encoder_input_ids.device), + text_batch["text_answer_ids"][:, :-1], + ], + dim=-1, ) + labels = text_batch["text_answer_ids"] + dec_mask = (decoder_input_ids != self.tokenizer.pad_id).long().contiguous() + loss_mask = dec_mask + output = self.frozen_model.enc_dec_model( + enc_input_ids=encoder_input_ids, + enc_attn_mask=enc_mask, + dec_input_ids=decoder_input_ids, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=None, + ) + multimodal_output['text'] = (output, loss_mask) - encoder_input, attention_mask, enc_mask = self.prepare_llm_input(audio_batch) - # enc_input = speech and text prompt - # dec_input and label = text output label - b = audio_batch['answers'].shape[0] - device = audio_batch['answers'].device - dec_input = audio_batch['masked_answer_ids'] if 'masked_answer_ids' in audio_batch else audio_batch['answers'] - dec_input = torch.cat([torch.full([b, 1], self.bos_id, device=device), dec_input[:, :-1]], dim=-1) - labels = audio_batch['answers'] - dec_mask = (dec_input != self.tokenizer.pad_id).long().contiguous() - output = self.frozen_model.enc_dec_model( - enc_input_ids=None, - enc_attn_mask=enc_mask, - dec_input_ids=dec_input, - dec_attn_mask=dec_mask, - token_type_ids=None, - labels=labels, - output_enc_hidden_only=False, - enc_input=encoder_input, - ) - loss_mask = dec_mask - return output, loss_mask + return multimodal_output def get_forward_output_only_func(self): def fwd_output_only_func(dataloader_iter, model): @@ -446,21 +472,42 @@ def get_forward_output_and_loss_func(self, validation_step=False): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): batch = next(dataloader_iter) batch = {key: val.cuda(non_blocking=True) for key, val in batch.items()} - output_tensor, loss_mask = self.forward( + multimodal_output = self.forward( batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers ) - def loss_func(output_tensor): + def loss_func(multimodal_output): # Loss for a micro-batch (ub) - if 'audio_ratio' in batch: - text_loss_weight = self.cfg.get('text_loss_weight', 1.0) - audio_ratio = batch['audio_ratio'] - scaled_loss_mask = loss_mask * torch.unsqueeze( - (1 * audio_ratio + text_loss_weight * (1 - audio_ratio)), 1 + loss_for_ub = None + + modality_weights = self.cfg.get("modality_loss_weights") + + for key, (output, loss_mask) in multimodal_output.items(): + cur_loss = self.loss_func(loss_mask.contiguous(), output.contiguous()) + if modality_weights is not None: + assert ( + key in modality_weights + ), f"Expected cfg.modality_loss_weights={modality_weights} to contain key {key}" + cur_loss = cur_loss * modality_weights[key] + if loss_for_ub is None: + loss_for_ub = cur_loss + else: + loss_for_ub += cur_loss + self.log( + f'{key}_loss', + cur_loss.mean(), + prog_bar=True, + batch_size=1, + rank_zero_only=False, ) - loss_for_ub = self.loss_func(scaled_loss_mask, output_tensor) - else: - loss_for_ub = self.loss_func(loss_mask, output_tensor) + self.log( + f'{key}_batch_size', + loss_mask.shape[0], + prog_bar=True, + batch_size=1, + rank_zero_only=False, + ) + if validation_step and not self.cfg.data.get('validation_drop_last', True): num_valid_tokens_in_ub = batch['loss_mask'].sum() if loss_for_ub.isnan(): @@ -484,10 +531,20 @@ def loss_func(output_tensor): reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) return loss_for_ub, {'avg': reduced_loss} - return output_tensor, loss_func + return multimodal_output, loss_func return fwd_output_and_loss_func + def on_train_epoch_start(self) -> None: + app_state = AppState() + reconfigure_num_microbatches_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.train_ds.global_batch_size, + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + def _build_dataset(self, data_cfg, is_train=True): return build_speechllm_dataset(self, data_cfg, is_train) @@ -920,6 +977,8 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): return outputs def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + # TODO: support text-only part of mini-batch + # the following supports STT (audio-text) inference batch = move_to_device(batch, device=self.device) encoder_input, attention_mask, enc_mask = self.prepare_llm_input(batch) @@ -1172,68 +1231,97 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): batch = next(dataloader_iter) # Pass only torch.Tensor to prevent errors when process get_iterator_k_split() batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)} - _, seq_length = batch['tokens'].shape - # handle the case where the batch size from dynamic bucketting is not divisible in lhotse - data_iter = get_iterator_k_split(batch, get_num_microbatches(), enforce_divisible_batch=False) - - # handle asynchronous grad reduction - no_sync_func = None - grad_sync_func = None - param_sync_func = None - if not forward_only and self.with_distributed_adam: - no_sync_func = partial( - self._optimizer.no_sync, - greedy_grad_copy=self.megatron_amp_O2, - ) - grad_sync_func = self.reduce_overlap_gradients - param_sync_func = self.sync_overlap_parameters - - self.model.config.no_sync_func = no_sync_func - self.model.config.grad_sync_func = grad_sync_func - self.model.config.param_sync_func = param_sync_func - - fwd_bwd_function = get_forward_backward_func() - - dec_seq_length = batch['answers'].shape[1] - - losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(), - data_iterator=data_iter, - model=[self.model], - num_microbatches=get_num_microbatches(), - forward_only=forward_only, - seq_length=seq_length, - micro_batch_size=get_micro_batch_size(), - decoder_seq_length=dec_seq_length, - ) - # only the last stages of the pipeline return losses - if losses_reduced_per_micro_batch: - if (not forward_only) or self.cfg.data.get('validation_drop_last', True): - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() + audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} + text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} + + # Note: We want to perform full fwd+bwd separately for each modality, + # as it allows us to save GPU memory. Otherwise, we'd have to + # hold the activations from one modality in memory while running + # forward for the other. + batch_losses = [] + for batch in (audio_batch, text_batch): + if not batch: + continue + + # Pass only torch.Tensor to prevent errors when process get_iterator_k_split() + batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)} + + # TODO(pzelasko): For the prototype, computing seq_length as a max from both modalities, + # but I feel like this needs larger refactoring + if 'tokens' in batch and 'text_input_ids' in batch: + seq_length = max(batch['tokens'].shape[1], batch['text_input_ids'].shape[1]) + dec_seq_length = max(batch['answers'].shape[1], batch['text_answer_ids'].shape[1]) + elif 'tokens' in batch: + seq_length = batch['tokens'].shape[1] + dec_seq_length = batch['answers'].shape[1] + elif 'text_input_ids' in batch: + seq_length = batch['text_input_ids'].shape[1] + dec_seq_length = batch['text_answer_ids'].shape[1] else: - # Get the total loss since micro batches sizes are not uniform - loss_sum_tensors_list = [ - loss_sum['loss_sum_and_ub_size'] - for loss_sum in losses_reduced_per_micro_batch - if loss_sum['loss_sum_and_ub_size'][1] > 0 - ] - loss_sum = ( - torch.vstack(loss_sum_tensors_list).sum(axis=0) - if len(loss_sum_tensors_list) > 0 - else torch.tensor([0.0, 0.0]).cuda() + seq_length = None # TODO(pzelasko): not sure if it is even needed ??? + dec_seq_length = None + + # handle the case where the batch size from dynamic bucketting is not divisible in lhotse + data_iter = get_iterator_k_split(batch, get_num_microbatches(), enforce_divisible_batch=False) + + # handle asynchronous grad reduction + no_sync_func = None + grad_sync_func = None + param_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, ) - return loss_sum - else: - # we're not on the last pipeline stage so no losses - if forward_only: - loss_mean = [] + grad_sync_func = self.reduce_overlap_gradients + param_sync_func = self.sync_overlap_parameters + + self.model.config.no_sync_func = no_sync_func + self.model.config.grad_sync_func = grad_sync_func + self.model.config.param_sync_func = param_sync_func + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(validation_step=forward_only), + data_iterator=data_iter, + model=[self.model], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=get_micro_batch_size(), + decoder_seq_length=dec_seq_length, + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # Get the total loss since micro batches sizes are not uniform + loss_sum_tensors_list = [ + loss_sum['loss_sum_and_ub_size'] + for loss_sum in losses_reduced_per_micro_batch + if loss_sum['loss_sum_and_ub_size'][1] > 0 + ] + loss_mean = ( + torch.vstack(loss_sum_tensors_list).sum(axis=0) + if len(loss_sum_tensors_list) > 0 + else torch.tensor([0.0, 0.0]).cuda() + ) else: - loss_mean = torch.tensor(0.0).cuda() + # we're not on the last pipeline stage so no losses + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0).cuda() + batch_losses.append(loss_mean) + loss_mean = torch.cat(batch_losses).mean() return loss_mean def loss_func(self, loss_mask, output_tensor): @@ -1268,6 +1356,74 @@ def setup_mcore_distributed_parallel(self): if self.with_distributed_adam and self.use_mcore_dist_optim: raise ValueError("T5 does not support both distributed adam and mcore distributed data parallel.") + def oomptimizer_schema(self, schema: str = "audio") -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. + """ + + if schema == "audio": + return { + "cls": dict, + "inputs": [ + {"name": "audio_signal", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + {"name": "audio_signal_length", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, + { + "name": "tokens", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "tokens_length", + "type": NeuralType(("B",), LengthsType()), + "seq_length": "output", + }, + { + "name": "labels", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "loss_mask", + "type": NeuralType(("B", "T"), MaskType()), + "seq_length": "output", + }, + { + "name": "context_start_idx", + "type": "constant", + "value": 0, + }, + ], + } + elif schema == "text": + # TODO: add support for text + # input_ids = text_batch["text_input_ids"][:, :-1] + # labels = text_batch["text_input_ids"][:, 1:] + # attention_mask = self._create_attention_mask(input_ids) + # loss_mask = text_batch["text_masks"][:, 1:] + + return { + "cls": dict, + "inputs": [ + { + "name": "text_context_ids", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "input", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "text_answer_ids", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + ], + } + else: + raise RuntimeError(f"Unknown schema type for oomptimizer of class {type(self)}: '{schema}'") + class DecoderTextPromptModularizedAudioT5Model(ModularizedAudioT5Model): """Modularized speech GPT model.""" From 8ca73d2da7d834404fe1026e6da3be186c3235ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 26 Aug 2024 11:46:41 -0400 Subject: [PATCH 17/63] megatron t5 nmt prompt formatter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 3 -- .../common/data/lhotse/text_adapters.py | 31 +++++++++++++------ nemo/collections/common/prompts/t5nmt.py | 24 ++++++++++++++ 3 files changed, 46 insertions(+), 12 deletions(-) create mode 100644 nemo/collections/common/prompts/t5nmt.py diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index d2ebd9eadc88..881000bb5c54 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -739,9 +739,6 @@ def tokenize(example: Example, tokenizer) -> Example: def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str | PromptFormatter) -> Example: - # TODO(pzelasko): This mechanism makes it possible to measure the actual output sequence length - # for prompted models such as AED MultiTask (Canary), which includes the transcript and the prompt. - # We intend to extend it for text modality in follow-up work. if isinstance(example, Cut): prompt_format_fn = get_prompt_format_fn(prompt_format) (tokenized_prompted_transcript,), (tokenized_prompt,), (tokenized_transcript,) = prompt_format_fn( diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 894e485d3111..0657ee4a3464 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -25,7 +25,8 @@ from lhotse.utils import Pathlike from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths -from nemo.collections.common.prompts import PromptFormatter +from nemo.collections.common.prompts import Llama2PromptFormatter, PromptFormatter +from nemo.collections.common.prompts.t5nmt import T5NMTPromptFormatter from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer, TokenizerWrapper from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging @@ -108,16 +109,28 @@ def num_tokens(self) -> Optional[int]: return self.input_ids.shape[0] return None - def tokenize(self, tokenizer: TokenizerWrapper, prompt: PromptFormatter = None) -> "TextExample": + def tokenize(self, tokenizer: TokenizerWrapper, prompt: PromptFormatter = None) -> "SourceTargetTextExample": if prompt is not None: - # TODO(pzelasko): this is temporarily hardcoded and assumes LLama2 prompt format - ans = prompt.encode_dialog( - [ - {"role": "system_and_user", "slots": {"system": self.question.text, "message": self.source.text}}, - {"role": prompt.OUTPUT_ROLE, "slots": {"message": self.target.text}}, - ] - ) + if isinstance(prompt, Llama2PromptFormatter): + ans = prompt.encode_dialog( + [ + { + "role": "system_and_user", + "slots": {"system": self.question.text, "message": self.source.text}, + }, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": self.target.text}}, + ] + ) + elif isinstance(prompt, T5NMTPromptFormatter): + ans = prompt.encode_dialog( + [ + {"role": "user", "slots": {"message": self.source.text}}, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": self.target.text}}, + ] + ) + else: + raise RuntimeError(f"Unexpected prompt formatter: {prompt}") self.input_ids = ans["input_ids"] self.context_ids = ans["context_ids"] self.answer_ids = ans["answer_ids"] diff --git a/nemo/collections/common/prompts/t5nmt.py b/nemo/collections/common/prompts/t5nmt.py new file mode 100644 index 000000000000..4d17993eddd5 --- /dev/null +++ b/nemo/collections/common/prompts/t5nmt.py @@ -0,0 +1,24 @@ +from nemo.collections.common.prompts.formatter import Modality, PromptFormatter + + +class T5NMTPromptFormatter(PromptFormatter): + """ + The default prompt format for Megatron T5 based neural machine translation models. + """ + + NAME = "t5nmt" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "user": { + "template": f"Q: |message|\n\n", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"A: |message|", + "slots": { + "message": Modality.Text, + }, + }, + } From b26b5dd7aae58d08a223524ffd87596433b72e78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 26 Aug 2024 12:41:11 -0700 Subject: [PATCH 18/63] Fixes for MT+AST T5 oomptimizer and training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/models/modular_t5_models.py | 22 ++++++++++--------- .../oomptimizer-speechllm.py | 6 +++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 07e604936dee..5567da28a429 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -345,7 +345,7 @@ def prepare_llm_input(self, audio_batch): input_ids, input_length, labels, loss_mask = ( audio_batch['contexts'], audio_batch['context_lengths'], - audio_batch['labels'], + audio_batch['answers'], audio_batch['loss_mask'], ) @@ -1319,6 +1319,8 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): loss_mean = [] else: loss_mean = torch.tensor(0.0).cuda() + if loss_mean.ndim == 0: + loss_mean = loss_mean.unsqueeze(0) batch_losses.append(loss_mean) loss_mean = torch.cat(batch_losses).mean() @@ -1348,7 +1350,12 @@ def test_step(self, dataloader_iter, dataloader_idx=0): return self.inference_step(dataloader_iter, 'test') def training_step(self, dataloader_iter): - batch, batch_idx, dataloader_idx = next(dataloader_iter) + ans = next(dataloader_iter) + if isinstance(ans, tuple) and len(ans) == 3: + batch, batch_idx, dataloader_idx = ans + else: + batch = ans + batch_idx = 0 return super().training_step(itertools.chain([batch]), batch_idx=batch_idx) def setup_mcore_distributed_parallel(self): @@ -1369,18 +1376,18 @@ def oomptimizer_schema(self, schema: str = "audio") -> dict: {"name": "audio_signal", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, {"name": "audio_signal_length", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, { - "name": "tokens", + "name": "contexts", "type": NeuralType(("B", "T"), LabelsType()), "seq_length": "output", "vocab_size": self.tokenizer.vocab_size, }, { - "name": "tokens_length", + "name": "context_lengths", "type": NeuralType(("B",), LengthsType()), "seq_length": "output", }, { - "name": "labels", + "name": "answers", "type": NeuralType(("B", "T"), LabelsType()), "seq_length": "output", "vocab_size": self.tokenizer.vocab_size, @@ -1390,11 +1397,6 @@ def oomptimizer_schema(self, schema: str = "audio") -> dict: "type": NeuralType(("B", "T"), MaskType()), "seq_length": "output", }, - { - "name": "context_start_idx", - "type": "constant", - "value": 0, - }, ], } elif schema == "text": diff --git a/scripts/speech_recognition/oomptimizer-speechllm.py b/scripts/speech_recognition/oomptimizer-speechllm.py index f98d13281273..b9313c06950c 100755 --- a/scripts/speech_recognition/oomptimizer-speechllm.py +++ b/scripts/speech_recognition/oomptimizer-speechllm.py @@ -392,6 +392,12 @@ def oomptimizer( model.init_consumed_samples = 0 model._compute_consumed_samples_after_training_step = lambda *args, **kwargs: 1 + from megatron.core.parallel_state import initialize_model_parallel + from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo + initialize_model_parallel_for_nemo(world_size=1, global_rank=0, local_rank=0, micro_batch_size=16, global_batch_size=16) + torch.distributed.init_process_group("nccl", world_size=1, rank=0) + initialize_model_parallel() + if not hasattr(model, "oomptimizer_schema"): click.secho( f"We read model of type {type(model)} which doesn't seem to support OOMptimizer " From 850e49446d2c6b5355c5626ea5e2ffd47edca359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 07:12:19 -0700 Subject: [PATCH 19/63] configs, fixes, token-per-token filtering --- ...o_gpt_config_cross_llama_lhotse_multi.yaml | 4 +- .../conf/modular_audio_gpt_config_eval.yaml | 2 +- .../salm/modular_audio_t5_multi_config.yaml | 333 ++++++++++++++++++ .../common/data/lhotse/dataloader.py | 23 +- .../speech_llm/data/build_dataset.py | 2 + .../speech_llm/models/modular_models.py | 9 +- 6 files changed, 365 insertions(+), 8 deletions(-) create mode 100644 examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml index 436ccf0dca4e..89fded05024b 100644 --- a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml @@ -29,7 +29,7 @@ trainer: devices: 1 accelerator: gpu num_nodes: 1 - precision: 16 + precision: bf16-mixed logger: False # logger provided by exp_manager enable_checkpointing: False use_distributed_sampler: False @@ -240,7 +240,7 @@ model: multi_config: true audio: input_cfg: ??? - sampler_fusion: round_robin + sampler_fusion: zip seed: 0 shard_seed: "trng" batch_size: null diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml index 62b9030b4708..658485aa6807 100644 --- a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml @@ -104,7 +104,7 @@ model: prompt_template: ${data.train_ds.prompt_template} # don't change, let hydra resolve from saved config tokens_to_generate: 512 log_every_n_steps: 1 - sample_rate: ${data.train_ds.sample_rate} # don't change, let hydra resolve from saved config + sample_rate: 16000 # don't change, let hydra resolve from saved config audio_locator: null # set it to allow multiple audios in a sample, e.g. '|audio|', and use it in the context field of manifest to specify the locations of audios (`audio_filepath` is a list of audios). metric: diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml new file mode 100644 index 000000000000..7adaa99b21cb --- /dev/null +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml @@ -0,0 +1,333 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This configuration is similar to modular_audio_t5_multi_config.yaml, +# with the difference being in how it performs multimodal sampling. +# The changes are in model.data.train_ds section. +# You'll notice that it defines two sub-sections: audio and text. +# Their names are arbitrary in the sense that you may define more subsections as you like, also with repeated modalities. +# We still set up a single dataloader, but each sub-section produces its own sampler with its own batch size related settings. +# That means each sub-section may decide about its own static/dynamic batch sizes, bucketing, etc. +# These different samplers are later combined into a single sampler using one of three available sampler fusion strategies: +# round_robin (taking turns), randomized_round_robin (at each step select a sampler according to weights), +# or zip (sample mini-batch from each and combine them). +name: megatron_audio_t5_salm_lhotse_multi_sampler + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16-mixed + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 1000000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + limit_train_batches : 1000 + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 1000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +model_target: nemo.collections.multimodal.speech_llm.models.modular_t5_models.ModularizedAudioT5Model + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + virtual_prompt_style: 'no-prompts' # make cls happy + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: stt_en_fastconformer_transducer_large + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + load_audio_encoder: True + + global_batch_size: 128 + micro_batch_size: 4 + language_model_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: True + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + # use_am_tokenizer: True + # override_vocab_size: 1024 + + lora_tuning: + kqv_adapter_dim: 128 + kv_adapter_dim: 64 + q_adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + + peft: + peft_scheme: "adapter" # can be either adapter,ia3, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + perception: + target: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule + use_multi_layer_feat: false + + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained AM: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + train_ds: + use_lhotse: true + multi_config: true + audio: + input_cfg: ??? + sampler_fusion: zip + seed: 0 + shard_seed: "trng" + batch_size: null + batch_duration: 360 + quadratic_factor: 15 + use_bucketing: true + num_buckets: 30 + bucket_buffer_size: 20000 + num_workers: 4 + shuffle: true + text: + input_cfg: ??? + use_multimodal_sampling: true + batch_tokens: 8000 + quadratic_factor: 192 + use_bucketing: true + num_buckets: 30 + bucket_buffer_size: 20000 + num_workers: 4 + shuffle: true + + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + max_seq_length: 2048 + min_seq_length: 1 + context_key: 'context' + answer_key: 'answer' + add_eos: True + # add_eos: False + add_sep: True + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + + validation_ds: + force_finite: true # workaround to allow using input_cfg + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + log_every_n_steps: 10 + metric: + name: "wer" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + # test_ds: + # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + # force_finite: true # workaround to allow using input_cfg + # names: null # Names of the corresponding datasets used to log metrics. + # global_batch_size: ${model.global_batch_size} + # micro_batch_size: ${model.micro_batch_size} + # shuffle: False + # num_workers: 4 + # pin_memory: True + # max_seq_length: 2048 + # min_seq_length: 1 + # drop_last: False + # context_key: 'input' + # label_key: 'output' + # add_eos: ${model.data.train_ds.add_eos} + # add_sep: ${model.data.train_ds.add_sep} + # add_bos: ${model.data.train_ds.add_bos} + # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + # write_predictions_to_file: False + # output_file_path_prefix: null # Prefix of the file to write predictions to. + # truncation_field: "context" # Options: ['context', 'answer'] + # index_mapping_dir: null # Path to a directory to write index mapping files. + # prompt_template: ${model.data.train_ds.prompt_template} + # # ASR configs + # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + # metric: + # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + # average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + # num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 881000bb5c54..d103ebbcb8be 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -108,8 +108,10 @@ class LhotseDataLoadingConfig: num_workers: int = 0 pin_memory: bool = False channel_selector: int | str | None = None - min_tps: int = -1 # allowed tokens per second + min_tps: int = -1 # allowed tokens per second (audio-only) max_tps: float = float("inf") + min_tpt: int = -1 # allowed tokens per token (text-only) + max_tpt: float = float("inf") # 4. Optional Lhotse data augmentation. # a. On-the-fly noise/audio mixing. @@ -393,6 +395,7 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No tokenizer = TokenizerWrapper(tokenizer) cuts = cuts.map(partial(tokenize, tokenizer=tokenizer), apply_fn=None) cuts = cuts.filter(TokenPerSecondFilter(config.min_tps, config.max_tps)) + cuts = cuts.filter(TokenPerTokenFilter(config.min_tpt, config.max_tpt)) # 2. Optional augmentations. # 2.a. Noise mixing. @@ -809,6 +812,24 @@ def __call__(self, example) -> bool: tps = _measure_tps(example) return self.tps_min <= tps <= self.tps_max +class TokenPerTokenFilter: + """ + Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) + is in range [tps_min, tps_max] and ``False`` otherwise. + """ + + def __init__(self, tpt_min: float, tpt_max: float) -> None: + assert tpt_min <= tpt_max + self.tpt_min = tpt_min + self.tpt_max = tpt_max + self.enabled = tpt_min > 0 or tpt_max < float("inf") + + def __call__(self, example) -> bool: + if isinstance(example, Cut) or not self.enabled: + return True # pass-through for non-text examples. + tpt = example.answer_ids.shape[0] / example.context_ids.shape[0] + return self.tpt_min <= tpt <= self.tpt_max + def _measure_tokens(cut: Cut) -> int: if hasattr(cut, "tokenized_prompted_transcript"): diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py index 4e471758ee11..906d7fe9ba0f 100644 --- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -138,6 +138,7 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict global_rank=parallel_state.get_data_parallel_rank(), world_size=parallel_state.get_data_parallel_world_size(), dataset=dataset, + tokenizer=dataset.text_processor.tokenizer, ) ) else: @@ -158,6 +159,7 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict global_rank=parallel_state.get_data_parallel_rank(), world_size=parallel_state.get_data_parallel_world_size(), dataset=dataset, + tokenizer=dataset.text_processor.tokenizer, ) ) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 5e53920ceec5..1f12bf1cc3c5 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -476,7 +476,8 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): else: seq_length = None # TODO(pzelasko): not sure if it is even needed ??? - data_iter = get_iterator_k_split(batch, get_num_microbatches()) + data_iter = get_iterator_k_split(batch, 1) + #data_iter = get_iterator_k_split(batch, get_num_microbatches()) # TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches # if log_token_counts: @@ -506,10 +507,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): forward_step_func=self.get_forward_output_and_loss_func(tuning=True, validation_step=forward_only), data_iterator=self._make_data_iterator_list(data_iter), model=self.model, - num_microbatches=get_num_microbatches(), + num_microbatches=1, #get_num_microbatches(), forward_only=forward_only, seq_length=seq_length, - micro_batch_size=get_micro_batch_size(), + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, #get_micro_batch_size(), first_val_step=first_val_step, ) @@ -1355,7 +1356,7 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int inference_config['end_strings'] = [self.cfg.data.end_string] global_batch_size_per_gpu = batch['tokens'].size(0) - num_micro_batches_before_decode = get_num_microbatches() + num_micro_batches_before_decode = 1 #get_num_microbatches() compute_logprob = inference_config.get('compute_logprob', False) if compute_logprob: From ffd32b17b4d56de671039a1a8516f466b3ae7420 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 16:13:09 +0200 Subject: [PATCH 20/63] Support text modality in predict_step MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/modular_audio_gpt_text_eval.py | 118 ++++++++++++++++++ .../speech_llm/models/modular_t5_models.py | 43 ++++--- 2 files changed, 147 insertions(+), 14 deletions(-) create mode 100644 examples/multimodal/speech_llm/modular_audio_gpt_text_eval.py diff --git a/examples/multimodal/speech_llm/modular_audio_gpt_text_eval.py b/examples/multimodal/speech_llm/modular_audio_gpt_text_eval.py new file mode 100644 index 000000000000..d76e479829fa --- /dev/null +++ b/examples/multimodal/speech_llm/modular_audio_gpt_text_eval.py @@ -0,0 +1,118 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from pathlib import Path + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.utils import logging + +mp.set_start_method("spawn", force=True) + +""" +This is the script to run inference with a ModularAudioGPTModel. + +If you want to evaluate an ModularAudioGPTModel: + +MEGATRON_CKPT=/path/to/megatron-llm.nemo +ALM_DIR=/path/to/nemo_experiments/job_name +ALM_YAML=$ALM_DIR/version_0/hparams.yaml +ALM_CKPT="$ALM_DIR/checkpoints/AudioGPT--validation_wer\=0.5-step\=103-epoch\=0-last.ckpt" + +VAL_MANIFESTS="[/data/libri-test-other.json,/data/MCV_7.1_test.json,/data/wsj-test.json]" +VAL_NAMES="[ls-test-other,mcv7.1-test,wsj-test]" + +HYDRA_FULL_ERROR=1 \ +CUDA_VISIBLE_DEVICES=0 python modular_audio_gpt_eval.py \ + model.restore_from_path=$MEGATRON_CKPT \ + model.peft.restore_from_path=$ALM_CKPT \ + model.peft.restore_from_hparams_path=$ALM_YAML \ + model.data.test_ds.manifest_filepath=$VAL_MANIFESTS \ + model.data.test_ds.names=$VAL_NAMES \ + model.data.test_ds.global_batch_size=8 \ + model.data.test_ds.micro_batch_size=8 \ + model.data.test_ds.tokens_to_generate=256 \ + ++inference.greedy=False \ + ++inference.top_k=50 \ + ++inference.top_p=0.95 \ + ++inference.temperature=0.4 \ + ++inference.repetition_penalty=1.2 \ + ++model.data.test_ds.output_dir=${ALM_DIR} +""" + + +@hydra_runner(config_path="conf", config_name="modular_audio_gpt_config_eval") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + logging.info("**************************************************\n\n") + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + + if cfg.model.from_pretrained: + # Load model from NGC or HuggingFace + logging.info(f"Loading model from cloud: {cfg.model.from_pretrained}") + model_cfg = ModularAudioGPTModel.from_pretrained( + cfg.model.from_pretrained, trainer=trainer, return_config=True + ) + model_cfg = ModularAudioGPTModel.merge_inference_cfg(cfg, trainer, model_cfg) + model_file = ModularAudioGPTModel.from_pretrained( + cfg.model.from_pretrained, trainer=trainer, return_model_file=True + ) + model = ModularAudioGPTModel.restore_from( + restore_path=model_file, + trainer=trainer, + override_config_path=model_cfg, + strict=False, + map_location="cpu", + ) + if "peft" in model_cfg and model_cfg.peft.get("peft_scheme", None): + # need this due to the way that MegatronGPTSFTModel doesn't load adapters in model initialization + model.load_adapters(model_file, map_location="cpu") + else: + # Load model from a local file + model_cfg = ModularAudioGPTModel.merge_inference_cfg(cfg, trainer) + model = ModularAudioGPTModel.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + override_config_path=model_cfg, + strict=False, + map_location="cpu", + ) + model = ModularAudioGPTModel.load_adapters_for_inference(cfg, model_cfg, model) + model = ModularAudioGPTModel.load_audio_encoder_for_inference(cfg, model_cfg, model) + + model.freeze() + if cfg.get("save_as_nemo", None): + model.setup("predict") # need to call setup() to load adapters and prepare for saving + model.save_to(cfg.save_as_nemo) + logging.info(f"Model saved to {Path(cfg.save_as_nemo).absolute()}, exiting...") + exit(0) + + if not cfg.model.get('use_flash_attention', False): + cfg.inference.compute_attention_mask = True + config = OmegaConf.to_container(cfg.inference, resolve=True) + model.set_inference_config(config) + + # run inference + trainer.test(model) + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 5567da28a429..d1b5098ffb27 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -977,27 +977,42 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): return outputs def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - # TODO: support text-only part of mini-batch # the following supports STT (audio-text) inference batch = move_to_device(batch, device=self.device) - encoder_input, attention_mask, enc_mask = self.prepare_llm_input(batch) - # enc_input = speech and text prompt - # dec_input and label = text output label - predicted_token_ids, log_probs = self.frozen_model.decode( - tokens_enc=None, - enc_mask=enc_mask, - num_tokens_to_generate=self._inference_config['tokens_to_generate'], - encoder_input=encoder_input, - tokenizer=self.tokenizer, - bos_id=self.bos_id, - ) + audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} + text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} + assert ( + audio_batch or text_batch and not (audio_batch or text_batch) + ), f"Expecting only text or audio batch, got {len(text_batch)=} and {len(audio_batch)=}" + + if 'audio_signal' in audio_batch: + input_text = audio_batch['contexts'] + labels = audio_batch['answers'] + encoder_input, attention_mask, enc_mask = self.prepare_llm_input(audio_batch) + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=None, + enc_mask=enc_mask, + num_tokens_to_generate=self._inference_config['tokens_to_generate'], + encoder_input=encoder_input, + tokenizer=self.tokenizer, + bos_id=self.bos_id, + ) + if text_batch: + input_text = text_batch['text_context_ids'] + labels = text_batch["text_answer_ids"] + enc_mask = (input_text != self.tokenizer.pad_id).long().contiguous() + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=input_text, + enc_mask=enc_mask, + num_tokens_to_generate=self._inference_config['tokens_to_generate'], + tokenizer=self.tokenizer, + bos_id=self.bos_id, + ) # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. - input_text = batch['contexts'] preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) input_text = MegatronT5SFTModel.ids_to_text(input_text, self.tokenizer) - labels = batch['answers'] if labels is not None: labels_text = MegatronT5SFTModel.ids_to_text(labels, self.tokenizer) From 024701f2d45498b6e2a40fba88e7ea6461af57e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 16:41:40 +0200 Subject: [PATCH 21/63] Support text data in val/test dl MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../multimodal/speech_llm/data/build_dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py index 906d7fe9ba0f..7b4f29d3872f 100644 --- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -149,7 +149,13 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict assert len(input_cfg) == 1, "Only one dataset with multiple manifest paths is supported for eval" data_cfg.input_cfg = input_cfg # for getting names - manifest_filepath = [ic.manifest_filepath for ic in input_cfg[0].input_cfg] + manifest_filepath = [] + for ic in input_cfg[0].input_cfg: + if hasattr(ic, "manifest_filepath"): + manifest_filepath.append(ic.manifest_filepath) + else: + assert ic.type == "txt_pair" + manifest_filepath.append(ic.target_path) for cur_input_cfg in input_cfg[0].input_cfg: conf = copy.deepcopy(data_cfg) conf.input_cfg[0].input_cfg = [cur_input_cfg] From f574e70a752045409d6ee8353a51b289bf756126 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 17:15:56 +0200 Subject: [PATCH 22/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/multimodal/speech_llm/data/build_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py index 7b4f29d3872f..e6bd0172b4e3 100644 --- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -155,7 +155,7 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict manifest_filepath.append(ic.manifest_filepath) else: assert ic.type == "txt_pair" - manifest_filepath.append(ic.target_path) + manifest_filepath.append(ic.target_paths) for cur_input_cfg in input_cfg[0].input_cfg: conf = copy.deepcopy(data_cfg) conf.input_cfg[0].input_cfg = [cur_input_cfg] From 2e2b3961af722f7aa9e1a9154fd2ca7c755fc723 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 17:46:46 +0200 Subject: [PATCH 23/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/models/modular_t5_models.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index d1b5098ffb27..f343d7d16a26 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -930,9 +930,14 @@ def _validation_step_internal( def inference_step(self, dataloader_iter, mode, dataloader_idx=0): batch, batch_idx, dataloader_idx = next(dataloader_iter) data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds - self._reconfigure_and_process_inference_batch(batch, data_cfg) - # Meta data from dataset - metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + if "tokens" in batch: + self._reconfigure_and_process_inference_batch(batch, data_cfg) + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + else: + batch["tokens"] = batch["context_input_ids"] + self._reconfigure_and_process_inference_batch(batch, data_cfg) + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + batch.pop("tokens") loss = self._validation_step_internal(itertools.chain([batch]), batch_idx, dataloader_idx, result_mode=mode) # We need _inference_config to get generation params @@ -945,8 +950,8 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): output = self.predict_step(batch, batch_idx, dataloader_idx) - inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] - labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] + inputs_text = output["inputs_text"] # [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] + labels_text = output["labels_text"] # [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] preds_text = output['preds_text'] if data_cfg.get("log_every_n_steps", None) is not None: if batch_idx % data_cfg.log_every_n_steps == 0: From dfdac5e025459642292169da2fa1b0a6f209ad35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 18:26:17 +0200 Subject: [PATCH 24/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../multimodal/speech_llm/models/modular_t5_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index f343d7d16a26..17db7e31fd0b 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -934,7 +934,7 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): self._reconfigure_and_process_inference_batch(batch, data_cfg) metadata = batch.get('metadata', [{}] * len(batch['tokens'])) else: - batch["tokens"] = batch["context_input_ids"] + batch["tokens"] = batch["text_context_ids"] self._reconfigure_and_process_inference_batch(batch, data_cfg) metadata = batch.get('metadata', [{}] * len(batch['tokens'])) batch.pop("tokens") From 81bd7327da85d8e6bfdc5b0c5e3a1dea88da15e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 18:34:45 +0200 Subject: [PATCH 25/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../multimodal/speech_llm/models/modular_t5_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 17db7e31fd0b..556bea15713e 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -988,7 +988,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} assert ( - audio_batch or text_batch and not (audio_batch or text_batch) + audio_batch or text_batch and not (audio_batch and text_batch) ), f"Expecting only text or audio batch, got {len(text_batch)=} and {len(audio_batch)=}" if 'audio_signal' in audio_batch: From 00edb6ddfc6d1646541c3164f7ba0dde79bbd632 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 18:45:52 +0200 Subject: [PATCH 26/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../multimodal/speech_llm/models/modular_t5_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 556bea15713e..19b4ba2da2b8 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -950,7 +950,7 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): output = self.predict_step(batch, batch_idx, dataloader_idx) - inputs_text = output["inputs_text"] # [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] + inputs_text = output["input_text"] # [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] labels_text = output["labels_text"] # [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] preds_text = output['preds_text'] if data_cfg.get("log_every_n_steps", None) is not None: From 2eb63318138ea882fa376a40c3899746971de78d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 19:32:50 +0200 Subject: [PATCH 27/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/multimodal/speech_llm/data/build_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py index e6bd0172b4e3..b707756ca804 100644 --- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -159,6 +159,7 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict for cur_input_cfg in input_cfg[0].input_cfg: conf = copy.deepcopy(data_cfg) conf.input_cfg[0].input_cfg = [cur_input_cfg] + conf.force_finite = True dls.append( get_lhotse_dataloader_from_config( conf, From cbaed3cc6afb5c0d1a5a373662708cd727657c87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 19:54:58 +0200 Subject: [PATCH 28/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/multimodal/speech_llm/data/build_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py index b707756ca804..16188300d05f 100644 --- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -14,6 +14,7 @@ import copy from pathlib import Path +import omegaconf import torch from megatron.core import parallel_state from omegaconf.omegaconf import OmegaConf @@ -159,7 +160,8 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict for cur_input_cfg in input_cfg[0].input_cfg: conf = copy.deepcopy(data_cfg) conf.input_cfg[0].input_cfg = [cur_input_cfg] - conf.force_finite = True + with omegaconf.open_dict(conf): + conf.force_finite = True dls.append( get_lhotse_dataloader_from_config( conf, From e8ec5a42e0965aa0235b146b39f35e67f2220dfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 6 Sep 2024 20:11:14 +0200 Subject: [PATCH 29/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/multimodal/speech_llm/data/build_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py index 16188300d05f..40120ca1511b 100644 --- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -160,8 +160,8 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict for cur_input_cfg in input_cfg[0].input_cfg: conf = copy.deepcopy(data_cfg) conf.input_cfg[0].input_cfg = [cur_input_cfg] - with omegaconf.open_dict(conf): - conf.force_finite = True + OmegaConf.set_struct(conf, False) + conf.force_finite = True dls.append( get_lhotse_dataloader_from_config( conf, From 8c597b5347d3d6c15655c9726bb3b863e45e11c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 7 Sep 2024 05:59:58 +0200 Subject: [PATCH 30/63] fix infinite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/cutset.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 52a02fe0ede5..838ecf9e5a7e 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -189,18 +189,21 @@ def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> [CutSet, bool]: def read_txt_paths(config: DictConfig) -> CutSet: - return CutSet( + cuts = CutSet( LhotseTextAdapter( paths=config.paths, language=config.language, shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) - ).repeat() + ) + if not config.get("force_finite", False): + cuts = cuts.repeat() + return cuts def read_txt_pair_paths(config: DictConfig) -> CutSet: - return CutSet( + cuts = CutSet( LhotseTextPairAdapter( source_paths=config.source_paths, target_paths=config.target_paths, @@ -211,18 +214,24 @@ def read_txt_pair_paths(config: DictConfig) -> CutSet: shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) - ).repeat() + ) + if not config.get("force_finite", False): + cuts = cuts.repeat() + return cuts def read_nemo_sft_jsonl(config: DictConfig) -> CutSet: - return CutSet( + cuts = CutSet( NeMoSFTJsonlAdapter( paths=config.paths, language=config.language, shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) - ).repeat() + ) + if not config.get("force_finite", False): + cuts = cuts.repeat() + return cuts def attach_tags(cut, tags: dict): From 14a1896b1ce327531f4bab754f0289f4e7db8641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 7 Sep 2024 11:04:00 +0200 Subject: [PATCH 31/63] prompt format fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/text_adapters.py | 2 +- nemo/collections/common/prompts/t5nmt.py | 65 ++++++++++++++++++- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 0657ee4a3464..8a99f256f8da 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -125,7 +125,7 @@ def tokenize(self, tokenizer: TokenizerWrapper, prompt: PromptFormatter = None) elif isinstance(prompt, T5NMTPromptFormatter): ans = prompt.encode_dialog( [ - {"role": "user", "slots": {"message": self.source.text}}, + {"role": "user", "slots": {"message": self.source.text, "target_lang": self.target.language}}, {"role": prompt.OUTPUT_ROLE, "slots": {"message": self.target.text}}, ] ) diff --git a/nemo/collections/common/prompts/t5nmt.py b/nemo/collections/common/prompts/t5nmt.py index 4d17993eddd5..cb8dc69e1541 100644 --- a/nemo/collections/common/prompts/t5nmt.py +++ b/nemo/collections/common/prompts/t5nmt.py @@ -1,24 +1,85 @@ +import torch +from lhotse import CutSet, MonoCut +from lhotse.cut import MixedCut + +from nemo.collections.common.prompts import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import Modality, PromptFormatter +from nemo.collections.common.tokenizers import TokenizerSpec class T5NMTPromptFormatter(PromptFormatter): """ The default prompt format for Megatron T5 based neural machine translation models. + Based on: https://github.com/NVIDIA/NeMo/blob/ad5ef750e351edbb5eeb7eb6df2d0c804819600f/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py#L790 """ NAME = "t5nmt" OUTPUT_ROLE = "assistant" TEMPLATE = { "user": { - "template": f"Q: |message|\n\n", + "template": f"|target_lang||message|", "slots": { + "target_lang": Modality.Text, "message": Modality.Text, }, }, OUTPUT_ROLE: { - "template": f"A: |message|", + "template": f"|message|", "slots": { "message": Modality.Text, }, }, } + + def encode_turn(self, prompt_template: str, expected_slots: dict, slot_values: dict) -> list[int]: + # Automatically adds "<" and ">" to target lang token for T5 NMT. + # Based on: https://github.com/NVIDIA/NeMo/blob/ad5ef750e351edbb5eeb7eb6df2d0c804819600f/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py#L307 + if (val := slot_values.get("target_lang")) is not None: + if not val.startswith("<") or not val.endswith(">"): + slot_values["target_lang"] = f"<{val}>" + return super().encode_turn( + prompt_template=prompt_template, expected_slots=expected_slots, slot_values=slot_values + ) + + +@registered_prompt_format_fn +def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: + formatter = T5NMTPromptFormatter(tokenizer) + + prompts_with_answers, prompts, answers = [], [], [] + for cut in cuts: + if isinstance(cut, MixedCut): + cut = cut._first_non_padding_cut + if not isinstance(cut, MonoCut): + raise TypeError( + f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" + ) + + assert hasattr(cut, "context"), cut + + turns = [ + dict( + role="user", + # "message" slot is the audio portion of the cut; currently it is populated inside model's forward. + slots={"target_lang": cut.context, "message": ""}, + ), + ] + if len(cut.supervisions) > 1 and cut.supervisions[0].text: + turns.append( + dict( + role="assistant", + slots={"message": cut.supervisions[0].text}, + ) + ) + encoded = formatter.encode_dialog(turns) + prompts_with_answers.append(encoded["input_ids"]) + prompts.append(encoded["context_ids"]) + if "answer_ids" in encoded: + assert ( + encoded["answer_ids"][-1].item() == formatter.tokenizer.eos + ), f"Expected the last token in answer_ids to be EOS, but we got {encoded['answer_ids']=}" + answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS + else: + answers.append([]) + + return prompts_with_answers, prompts, answers From 5c382cd93d2e0bbd2d48b4ce0eada5450c420860 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 10 Sep 2024 07:32:41 -0700 Subject: [PATCH 32/63] Fixes in audio supervision MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../conf/salm/modular_audio_t5_multi_config.yaml | 10 ++++++++-- nemo/collections/common/prompts/t5nmt.py | 7 ++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml index 7adaa99b21cb..094fb49f9d7d 100644 --- a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml @@ -75,6 +75,7 @@ exp_manager: model: virtual_prompt_style: 'no-prompts' # make cls happy + audio_prompt_first: False seed: 1234 tensor_model_parallel_size: 1 # intra-layer model parallelism pipeline_model_parallel_size: 1 # inter-layer model parallelism @@ -218,7 +219,8 @@ model: multi_config: true audio: input_cfg: ??? - sampler_fusion: zip + sampler_fusion: round_robin + prompt_format: t5nmt seed: 0 shard_seed: "trng" batch_size: null @@ -231,6 +233,7 @@ model: shuffle: true text: input_cfg: ??? + prompt_format: t5nmt use_multimodal_sampling: true batch_tokens: 8000 quadratic_factor: 192 @@ -253,10 +256,12 @@ model: separate_prompt_and_response_with_newline: False truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + prompt_template: "{context}{answer}" + #prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" validation_ds: force_finite: true # workaround to allow using input_cfg + prompt_format: t5nmt global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: False @@ -289,6 +294,7 @@ model: # test_ds: # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. # force_finite: true # workaround to allow using input_cfg + # prompt_format: t5nmt # names: null # Names of the corresponding datasets used to log metrics. # global_batch_size: ${model.global_batch_size} # micro_batch_size: ${model.micro_batch_size} diff --git a/nemo/collections/common/prompts/t5nmt.py b/nemo/collections/common/prompts/t5nmt.py index cb8dc69e1541..429c5340a14a 100644 --- a/nemo/collections/common/prompts/t5nmt.py +++ b/nemo/collections/common/prompts/t5nmt.py @@ -64,7 +64,7 @@ def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> tuple[list[torch.Tensor], l slots={"target_lang": cut.context, "message": ""}, ), ] - if len(cut.supervisions) > 1 and cut.supervisions[0].text: + if len(cut.supervisions) > 0 and cut.supervisions[0].text: turns.append( dict( role="assistant", @@ -75,10 +75,7 @@ def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> tuple[list[torch.Tensor], l prompts_with_answers.append(encoded["input_ids"]) prompts.append(encoded["context_ids"]) if "answer_ids" in encoded: - assert ( - encoded["answer_ids"][-1].item() == formatter.tokenizer.eos - ), f"Expected the last token in answer_ids to be EOS, but we got {encoded['answer_ids']=}" - answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS + answers.append(encoded["answer_ids"]) else: answers.append([]) From 6e276caf244ae60e2b9eac7009e6d96447801181 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 10 Sep 2024 11:17:27 -0400 Subject: [PATCH 33/63] remove superficial padding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/data/lhotse_dataset.py | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py index 985aa241ce93..230164f9558d 100644 --- a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py @@ -150,32 +150,38 @@ def collate_text_data( def get_max_len(input_list): return max([len(x) for x in input_list]) - max_length = tokens_to_generate + max( - get_max_len(fields["input_ids"]), get_max_len(fields["context_ids"]), get_max_len(fields["answer_ids"]) - ) + # max_length = tokens_to_generate + max( + # get_max_len(fields["input_ids"]), get_max_len(fields["context_ids"]), get_max_len(fields["answer_ids"]) + # ) + input_id_maxlen = get_max_len(fields["input_ids"]) + context_id_maxlen = get_max_len(fields["context_ids"]) + answer_id_maxlen = get_max_len(fields["answer_ids"]) # increase max length to nearest multiple of 4 or 8 if pad_to_max_length: - max_length = max_seq_length - else: - max_length = min(max_seq_length, ceil_to_nearest(max_length, 8)) - - all_tokens = collate_vectors(fields["input_ids"], max_length=max_length, padding_value=pad_id) + input_id_maxlen = max_seq_length + context_id_maxlen = max_seq_length + answer_id_maxlen = max_seq_length + # max_length = max_seq_length + # else: + # max_length = min(max_seq_length, ceil_to_nearest(max_length, 8)) + + all_tokens = collate_vectors(fields["input_ids"], max_length=input_id_maxlen, padding_value=pad_id) full_lengths = torch.LongTensor([len(item) for item in fields["input_ids"]]) - assert max_length <= max_seq_length, f"{max_length=} <= {max_seq_length=}" + assert input_id_maxlen <= max_seq_length, f"{input_id_maxlen=} <= {max_seq_length=}" return { "tokens": all_tokens[:, :-1], "tokens_length": full_lengths - 1, "labels": all_tokens[:, 1:], "loss_mask": collate_vectors( - [torch.as_tensor(build_loss_mask(item)) for item in examples], max_length=max_length, padding_value=0 + [torch.as_tensor(build_loss_mask(item)) for item in examples], max_length=input_id_maxlen, padding_value=0 )[:, 1:], - "position_ids": torch.arange(max_length, dtype=torch.long).repeat(batch_size, 1), - "contexts": collate_vectors(fields["context_ids"], max_length=max_length, padding_value=pad_id), + "position_ids": torch.arange(input_id_maxlen, dtype=torch.long).repeat(batch_size, 1), + "contexts": collate_vectors(fields["context_ids"], max_length=context_id_maxlen, padding_value=pad_id), "context_lengths": torch.LongTensor([len(seq) for seq in fields["context_ids"]]), - "answers": collate_vectors(fields["answer_ids"], max_length=max_length, padding_value=pad_id), - "max_length": torch.LongTensor([max_length] * batch_size), + "answers": collate_vectors(fields["answer_ids"], max_length=answer_id_maxlen, padding_value=pad_id), + "max_length": torch.LongTensor([input_id_maxlen] * batch_size), } From 2fae9f9a0fd1640dbef4f4793c90a41b2394fc77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 10 Sep 2024 08:25:29 -0700 Subject: [PATCH 34/63] test config and prompt context fetching fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../salm/modular_audio_t5_multi_config.yaml | 52 +++++++++---------- nemo/collections/common/prompts/t5nmt.py | 9 +++- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml index 094fb49f9d7d..b8472fdf53fd 100644 --- a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml @@ -291,32 +291,32 @@ model: average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. num_classes: null - # test_ds: - # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. - # force_finite: true # workaround to allow using input_cfg - # prompt_format: t5nmt - # names: null # Names of the corresponding datasets used to log metrics. - # global_batch_size: ${model.global_batch_size} - # micro_batch_size: ${model.micro_batch_size} - # shuffle: False - # num_workers: 4 - # pin_memory: True - # max_seq_length: 2048 - # min_seq_length: 1 - # drop_last: False - # context_key: 'input' - # label_key: 'output' - # add_eos: ${model.data.train_ds.add_eos} - # add_sep: ${model.data.train_ds.add_sep} - # add_bos: ${model.data.train_ds.add_bos} - # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} - # write_predictions_to_file: False - # output_file_path_prefix: null # Prefix of the file to write predictions to. - # truncation_field: "context" # Options: ['context', 'answer'] - # index_mapping_dir: null # Path to a directory to write index mapping files. - # prompt_template: ${model.data.train_ds.prompt_template} - # # ASR configs - # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + test_ds: + manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + force_finite: true # workaround to allow using input_cfg + prompt_format: t5nmt + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 4 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} # metric: # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] diff --git a/nemo/collections/common/prompts/t5nmt.py b/nemo/collections/common/prompts/t5nmt.py index 429c5340a14a..8a91b7ef56bc 100644 --- a/nemo/collections/common/prompts/t5nmt.py +++ b/nemo/collections/common/prompts/t5nmt.py @@ -55,13 +55,18 @@ def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> tuple[list[torch.Tensor], l f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" ) - assert hasattr(cut, "context"), cut + if hasattr(cut, "context"): + context = cut.context + elif hasattr(cut, "default_context"): + context = cut.default_context + else: + raise RuntimeError("Missing context/default_context custom field in cut: {cut}") turns = [ dict( role="user", # "message" slot is the audio portion of the cut; currently it is populated inside model's forward. - slots={"target_lang": cut.context, "message": ""}, + slots={"target_lang": context, "message": ""}, ), ] if len(cut.supervisions) > 0 and cut.supervisions[0].text: From 34f85262a1d6a05a6b0c179fbffcda98ea4ea432 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 10 Sep 2024 13:24:07 -0400 Subject: [PATCH 35/63] support text-only decoding for salm/bestow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_llm/data/lhotse_dataset.py | 3 ++ .../speech_llm/models/modular_models.py | 50 +++++++++++++------ 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py index 230164f9558d..33cdcd054130 100644 --- a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py @@ -115,8 +115,11 @@ def __getitem__(self, all_cuts: CutSet) -> dict[str, torch.Tensor | list[str] | pad_id = self.text_processor.pad_id text_minibatch = dict( text_input_ids=collate_vectors_lhotse([e.input_ids for e in text_examples], padding_value=pad_id), + text_input_lens=torch.tensor([len(e.input_ids) for e in text_examples], dtype=torch.int64), text_answer_ids=collate_vectors_lhotse([e.answer_ids for e in text_examples], padding_value=pad_id), + text_answer_lens=torch.tensor([len(e.answer_ids) for e in text_examples], dtype=torch.int64), text_context_ids=collate_vectors_lhotse([e.context_ids for e in text_examples], padding_value=pad_id), + text_context_lens=torch.tensor([len(e.context_ids) for e in text_examples], dtype=torch.int64), text_masks=collate_vectors_lhotse([e.mask for e in text_examples], padding_value=0), ) ans.update(text_minibatch) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 1f12bf1cc3c5..457acb014352 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -477,7 +477,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): seq_length = None # TODO(pzelasko): not sure if it is even needed ??? data_iter = get_iterator_k_split(batch, 1) - #data_iter = get_iterator_k_split(batch, get_num_microbatches()) + # data_iter = get_iterator_k_split(batch, get_num_microbatches()) # TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches # if log_token_counts: @@ -507,10 +507,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): forward_step_func=self.get_forward_output_and_loss_func(tuning=True, validation_step=forward_only), data_iterator=self._make_data_iterator_list(data_iter), model=self.model, - num_microbatches=1, #get_num_microbatches(), + num_microbatches=1, # get_num_microbatches(), forward_only=forward_only, seq_length=seq_length, - micro_batch_size=self.cfg.data.train_ds.micro_batch_size, #get_micro_batch_size(), + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, # get_micro_batch_size(), first_val_step=first_val_step, ) @@ -1254,9 +1254,14 @@ def inference_step(self, dataloader_iter, mode): # Evaluation of multimodal data follows the same pattern as training except predict_step batch, batch_idx, dataloader_idx = next(dataloader_iter) data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds - self._reconfigure_and_process_inference_batch(batch, data_cfg) - # Meta data from dataset - metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + if "tokens" in batch: + self._reconfigure_and_process_inference_batch(batch, data_cfg) + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + else: + batch["tokens"] = batch["text_input_ids"] + self._reconfigure_and_process_inference_batch(batch, data_cfg) + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + batch.pop("tokens") loss = super(MegatronGPTSFTModel, self).validation_step(itertools.chain([batch]), dataloader_idx) # We need _inference_config to get generation params @@ -1269,12 +1274,22 @@ def inference_step(self, dataloader_iter, mode): output = self.predict_step(batch, batch_idx, dataloader_idx) - inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] - labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] - preds_text = [ - self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')]) - for t, l in zip(output['token_ids'], batch['context_lengths']) - ] + audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} + text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} + if audio_batch: + inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in audio_batch['contexts']] + labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in audio_batch['answers']] + preds_text = [ + self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')]) + for t, l in zip(output['token_ids'], audio_batch['context_lengths']) + ] + else: + inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in text_batch['text_context_ids']] + labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in text_batch['text_answer_ids']] + preds_text = [ + self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')]) + for t, l in zip(output['token_ids'], text_batch['text_context_lens']) + ] if data_cfg.get("end_string", None): # sometimes data_cfg.end_string != self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(data_cfg.end_string)) @@ -1356,7 +1371,7 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int inference_config['end_strings'] = [self.cfg.data.end_string] global_batch_size_per_gpu = batch['tokens'].size(0) - num_micro_batches_before_decode = 1 #get_num_microbatches() + num_micro_batches_before_decode = 1 # get_num_microbatches() compute_logprob = inference_config.get('compute_logprob', False) if compute_logprob: @@ -1371,6 +1386,12 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int # for megatron_gpt_eval.py if isinstance(batch, list): inference_config['inputs'] = batch + elif "text_context_ids" in batch: + # Text mini-batch + inference_config['inputs'] = ( + batch['text_context_ids'].cuda(), + batch['text_context_lens'].cuda(), + ) elif 'num_audios' in batch: # peft_eval.py inference_config['inputs'] = ( @@ -1401,7 +1422,8 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int ) # add audio offsets to context lengths for properly decoding only the response - batch['context_lengths'] = batch['context_lengths'].cuda() + response['audio_feat_lens'] + if 'context_lengths' in batch: + batch['context_lengths'] = batch['context_lengths'].cuda() + response['audio_feat_lens'] return response From 949f1efdfe501123498f935dd01f9dc09b75f1e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 4 Oct 2024 15:14:05 +0000 Subject: [PATCH 36/63] Add unit tests for EMMETT / refactor prompt_format_fn MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 9 ++--- .../common/data/lhotse/dataloader.py | 19 +++------ nemo/collections/common/prompts/canary.py | 40 +++++++++---------- nemo/collections/common/prompts/llama.py | 3 +- nemo/collections/common/prompts/t5nmt.py | 20 ++++------ .../multimodal/test_speechllm_dataset.py | 11 +++-- 6 files changed, 43 insertions(+), 59 deletions(-) diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 89dcc61655e8..912188cb9c02 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -76,14 +76,13 @@ def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: audio, audio_lens, cuts = self.load_audio(cuts) # Fast-path: the tokenization and prompt formatting was already done before sampling. - attrs = ("tokenized_prompt", "tokenized_transcript", "tokenized_prompted_transcript") + attrs = ("input_ids", "context_ids", "answer_ids") pre_formatted = all(hasattr(c, a) for c in cuts for a in attrs) if pre_formatted: - prompts_with_answers, prompts, answers = zip( - *((c.tokenized_prompted_transcript, c.tokenized_prompt, c.tokenized_transcript) for c in cuts) - ) + prompts_with_answers, prompts, answers = zip(*((c.input_ids, c.context_ids, c.answer_ids) for c in cuts)) else: - prompts_with_answers, prompts, answers = self.prompt_format_fn(cuts, self.tokenizer) + ans = self.prompt_format_fn(cuts, self.tokenizer) + prompts_with_answers, prompts, answers = ans["input_ids"], ans["context_ids"], ans["answer_ids"] transcript, transcript_lens = self._collate_tokens(answers) prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 709c5782c46a..4587573b3b3b 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -750,17 +750,10 @@ def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str | Promp if isinstance(example, Cut): prompt_format_fn = get_prompt_format_fn(prompt_format) ans = prompt_format_fn(CutSet([example]), tokenizer) - if isinstance(ans, tuple): - (tokenized_prompted_transcript,), (tokenized_prompt,), (tokenized_transcript,) = ans - example.tokenized_prompted_transcript = tokenized_prompted_transcript - example.tokenized_prompt = tokenized_prompt - example.tokenized_transcript = tokenized_transcript - elif isinstance(ans, dict): - example.tokenized_prompted_transcript = ans["input_ids"][0] - example.tokenized_prompt = ans["context_ids"][0] - example.tokenized_transcript = ans["answer_ids"][0] - else: - raise RuntimeError(f"Unexpected return type from prompt_format_fn (must be dict or tuple): {ans}") + example.input_ids = ans["input_ids"][0] + example.context_ids = ans["context_ids"][0] + example.answer_ids = ans["answer_ids"][0] + example.answer_mask = ans["mask"][0] elif isinstance(example, NeMoMultimodalConversation): example = example.tokenize(tokenizer, prompt_format) else: @@ -846,8 +839,8 @@ def __call__(self, example) -> bool: def _measure_tokens(cut: Cut) -> int: - if hasattr(cut, "tokenized_prompted_transcript"): - return len(cut.tokenized_prompted_transcript) # tokenized with prompt formatter + if hasattr(cut, "input_ids"): + return len(cut.input_ids) # tokenized with prompt formatter supervisions_with_tokens = [s for s in cut.supervisions if hasattr(s, "tokens")] assert len(supervisions_with_tokens) > 0, ( "Cannot measure tokens-per-second with untokenized supervisions. " diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py index 0eb3296bcff9..d06e01b50666 100644 --- a/nemo/collections/common/prompts/canary.py +++ b/nemo/collections/common/prompts/canary.py @@ -1,3 +1,4 @@ +from collections import defaultdict from typing import Any import torch @@ -95,9 +96,7 @@ def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[s @registered_prompt_format_fn -def canary( - cuts: CutSet, tokenizer: TokenizerSpec -) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: +def canary(cuts: CutSet, tokenizer: TokenizerSpec) -> dict[str, torch.Tensor]: """ Prepend and append control tokens to the token sequence as per Canary format. @@ -120,9 +119,9 @@ def canary( (i.e., spoken language in the recording) and the second occurrence is for the "target" language (i.e., the language in which we are going to output the text). """ - formatter = CanaryPromptFormatter(tokenizer) + prompt = CanaryPromptFormatter(tokenizer) - prompts_with_answers, prompts, answers = [], [], [] + ans = defaultdict(list) for cut in cuts: if isinstance(cut, MixedCut): cut = cut._first_non_padding_cut @@ -132,7 +131,7 @@ def canary( ) # first, validate the utterance - expected_slots = set(formatter.get_slots("user")) + expected_slots = set(prompt.get_slots("user")) missing_keys = expected_slots - set(cut.custom) if "task" in missing_keys and "taskname" in cut.custom: # Compatibility with "old" Canary manifest format. @@ -150,7 +149,7 @@ def canary( role="user", slots={ **{slot: cut.custom[slot] for slot in expected_slots}, - formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, + prompt.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, }, ) ] @@ -161,21 +160,18 @@ def canary( role="assistant", slots={ "text": text, - formatter.PROMPT_LANGUAGE_SLOT: ifnone( - cut.supervisions[0].language, cut.custom.get("target_lang") - ), + prompt.PROMPT_LANGUAGE_SLOT: ifnone(cut.supervisions[0].language, cut.custom.get("target_lang")), }, ), ) - encoded = formatter.encode_dialog(turns) - prompts_with_answers.append(encoded["input_ids"]) - prompts.append(encoded["context_ids"]) - if "answer_ids" in encoded: - assert ( - encoded["answer_ids"][-1].item() == formatter.tokenizer.eos - ), f"Expected the last token in answer_ids to be EOS, but we got {encoded['answer_ids']=}" - answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS - else: - answers.append([]) - - return prompts_with_answers, prompts, answers + + for k, v in prompt.encode_dialog(turns).items(): + if k == "answer_ids": + assert ( + v[-1].item() == prompt.tokenizer.eos + ), f"Expected the last token in answer_ids to be EOS, but we got {v}" + ans[k].append(v[:-1]) # Strip Canary's EOS + else: + ans[k].append(v) + + return ans diff --git a/nemo/collections/common/prompts/llama.py b/nemo/collections/common/prompts/llama.py index 7b2e1fe1d758..4364af608497 100644 --- a/nemo/collections/common/prompts/llama.py +++ b/nemo/collections/common/prompts/llama.py @@ -1,5 +1,6 @@ from collections import defaultdict +import torch from lhotse import CutSet from lhotse.cut import MixedCut from lhotse.utils import ifnone @@ -41,7 +42,7 @@ class Llama2PromptFormatter(PromptFormatter): @registered_prompt_format_fn -def llama2(cuts: CutSet, tokenizer: TokenizerSpec): +def llama2(cuts: CutSet, tokenizer: TokenizerSpec) -> dict[str, torch.Tensor]: prompt = Llama2PromptFormatter(tokenizer) ans = defaultdict(list) for cut in cuts: diff --git a/nemo/collections/common/prompts/t5nmt.py b/nemo/collections/common/prompts/t5nmt.py index 8a91b7ef56bc..b9acb050ccb9 100644 --- a/nemo/collections/common/prompts/t5nmt.py +++ b/nemo/collections/common/prompts/t5nmt.py @@ -1,3 +1,4 @@ +from collections import defaultdict import torch from lhotse import CutSet, MonoCut from lhotse.cut import MixedCut @@ -17,7 +18,7 @@ class T5NMTPromptFormatter(PromptFormatter): OUTPUT_ROLE = "assistant" TEMPLATE = { "user": { - "template": f"|target_lang||message|", + "template": f"|target_lang| |message|", "slots": { "target_lang": Modality.Text, "message": Modality.Text, @@ -43,10 +44,10 @@ def encode_turn(self, prompt_template: str, expected_slots: dict, slot_values: d @registered_prompt_format_fn -def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: - formatter = T5NMTPromptFormatter(tokenizer) +def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> dict[str, torch.Tensor]: + prompt = T5NMTPromptFormatter(tokenizer) - prompts_with_answers, prompts, answers = [], [], [] + ans = defaultdict(list) for cut in cuts: if isinstance(cut, MixedCut): cut = cut._first_non_padding_cut @@ -76,12 +77,7 @@ def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> tuple[list[torch.Tensor], l slots={"message": cut.supervisions[0].text}, ) ) - encoded = formatter.encode_dialog(turns) - prompts_with_answers.append(encoded["input_ids"]) - prompts.append(encoded["context_ids"]) - if "answer_ids" in encoded: - answers.append(encoded["answer_ids"]) - else: - answers.append([]) + for k, v in prompt.encode_dialog(turns).items(): + ans[k].append(v) - return prompts_with_answers, prompts, answers + return ans diff --git a/tests/collections/multimodal/test_speechllm_dataset.py b/tests/collections/multimodal/test_speechllm_dataset.py index de554a219ca4..b4c51c4fc978 100644 --- a/tests/collections/multimodal/test_speechllm_dataset.py +++ b/tests/collections/multimodal/test_speechllm_dataset.py @@ -84,7 +84,6 @@ def test_speechllm_dataset(tokenizer, cuts): ) batch = dataset[cuts] - print(batch) expected_keys = { "sample_ids", @@ -368,8 +367,8 @@ def test_speechllm_dataset_tokens_to_generate_increases_seq_len(llama_tokenizer, max_seq_length=512, ) batch = dataset[cuts] - assert batch["tokens"].shape == (1, 347) # was 351 before padding optimization - assert batch["labels"].shape == (1, 347) # was 351 before padding optimization - assert batch["contexts"].shape == (1, 337) # was 352 before padding optimization - assert batch["answers"].shape == (1, 267) # was 352 before padding optimization - assert batch["position_ids"].shape == (1, 348) # was 352 before padding optimization + assert batch["tokens"].shape == (1, 91) + assert batch["labels"].shape == (1, 91) + assert batch["contexts"].shape == (1, 337) + assert batch["answers"].shape == (1, 11) + assert batch["position_ids"].shape == (1, 92) From d91348eab76dd88122ca8eff170ed898efec9605 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 4 Oct 2024 18:17:05 +0000 Subject: [PATCH 37/63] make t5nmt prompt formatter auto discoverable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/prompts/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/common/prompts/__init__.py b/nemo/collections/common/prompts/__init__.py index 99950f2b3a98..e4b7785c6243 100644 --- a/nemo/collections/common/prompts/__init__.py +++ b/nemo/collections/common/prompts/__init__.py @@ -10,3 +10,4 @@ Phi2QAPromptFormatter, ) from nemo.collections.common.prompts.plain import PlainPromptFormatter +from nemo.collections.common.prompts.t5nmt import T5NMTPromptFormatter From 39543a9527bf7b6a62e9905779888e1a186ac2d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 4 Oct 2024 19:06:10 +0000 Subject: [PATCH 38/63] include token count / tpt filtering in estimate_token_bins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../speech_recognition/estimate_token_bins.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index fca14bce4285..881dfcc9fdb0 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -33,7 +33,9 @@ LhotseDataLoadingConfig, MultimodalFixedBucketBatchSizeConstraint2D, MultimodalSamplingConstraint, + TokenCountFilter, TokenPerSecondFilter, + TokenPerTokenFilter, tokenize, tokenize_with_prompt, ) @@ -87,26 +89,26 @@ def parse_args(): help="The number of examples (utterances) to estimate the bins. -1 means use all data " "(be careful: it could be iterated over infinitely).", ) - # parser.add_argument( - # "-l", - # "--min_tokens", - # type=float, - # default=-float("inf"), - # help="If specified, we'll filter out examples with less tokens than this number.", - # ) - # parser.add_argument( - # "-u", - # "--max_tokens", - # type=float, - # default=float("inf"), - # help="If specified, we'll filter out examples with more tokens than this number.", - # ) - # parser.add_argument( - # "--max_tpt", - # type=float, - # default=float("inf"), - # help="If specified, we'll filter out examples with more output tokens per input token than this. " - # ) + parser.add_argument( + "-l", + "--min_tokens", + type=float, + default=-float("inf"), + help="If specified, we'll filter out examples with less tokens than this number.", + ) + parser.add_argument( + "-u", + "--max_tokens", + type=float, + default=float("inf"), + help="If specified, we'll filter out examples with more tokens than this number.", + ) + parser.add_argument( + "--max_tpt", + type=float, + default=float("inf"), + help="If specified, we'll filter out examples with more output tokens per input token than this. ", + ) parser.add_argument( "-q", "--quiet", type=bool, default=False, help="When specified, only print the estimated duration bins." ) @@ -286,13 +288,13 @@ def main(): OmegaConf.from_dotlist([f"input_cfg={args.input}"]), ) cuts, _ = read_cutset_from_config(config) - # duration_filter = RejectionsCounter(DurationFilter(args.min_duration, args.max_duration), "Duration filtering") - # cuts = cuts.filter(duration_filter) + token_filter = RejectionsCounter(TokenCountFilter(args.min_tokens, args.max_tokens), "Token count filtering") + cuts = cuts.filter(token_filter) cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt), apply_fn=None) - if hasattr(cuts, "prefetch"): - cuts = cuts.prefetch() # to be released in lhotse 1.27 - # tpt_filter = RejectionsCounter(TokensPerTokenFilter(-1, args.max_tpt), "Output tokens per input token filtering") - # cuts = cuts.filter(tpt_filter) + # if hasattr(cuts, "prefetch"): + # cuts = cuts.prefetch() # to be released in lhotse 1.27 + tpt_filter = RejectionsCounter(TokenPerTokenFilter(-1, args.max_tpt), "Output tokens per input token filtering") + cuts = cuts.filter(tpt_filter) if (N := args.num_examples) > 0: cuts = islice(cuts, N) @@ -309,8 +311,8 @@ def main(): if args.quiet: print(token_bins) return - # duration_filter.print_report() - # tps_filter.print_report() + token_filter.print_report() + tpt_filter.print_report() print("Use the following options in your config:") print(f"\tnum_buckets={args.buckets}") print(f"\tbucket_duration_bins={token_bins}") From b684750cdd753da3c03d62b2cf71e5b64a4b639e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 4 Oct 2024 19:13:11 +0000 Subject: [PATCH 39/63] fix max token filter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- scripts/speech_recognition/estimate_token_bins.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index 881dfcc9fdb0..5a6d8ea23272 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -288,11 +288,11 @@ def main(): OmegaConf.from_dotlist([f"input_cfg={args.input}"]), ) cuts, _ = read_cutset_from_config(config) + cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt), apply_fn=None) + if hasattr(cuts, "prefetch"): + cuts = cuts.prefetch() # to be released in lhotse 1.27 token_filter = RejectionsCounter(TokenCountFilter(args.min_tokens, args.max_tokens), "Token count filtering") cuts = cuts.filter(token_filter) - cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt), apply_fn=None) - # if hasattr(cuts, "prefetch"): - # cuts = cuts.prefetch() # to be released in lhotse 1.27 tpt_filter = RejectionsCounter(TokenPerTokenFilter(-1, args.max_tpt), "Output tokens per input token filtering") cuts = cuts.filter(tpt_filter) if (N := args.num_examples) > 0: From 6064bb42925fbc2217038ab394123acd0381f526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 9 Oct 2024 13:20:57 -0700 Subject: [PATCH 40/63] some fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/dataloader.py | 5 +++-- nemo/collections/common/prompts/t5nmt.py | 7 ++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 4587573b3b3b..b4221c4b1d91 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -752,8 +752,9 @@ def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str | Promp ans = prompt_format_fn(CutSet([example]), tokenizer) example.input_ids = ans["input_ids"][0] example.context_ids = ans["context_ids"][0] - example.answer_ids = ans["answer_ids"][0] - example.answer_mask = ans["mask"][0] + if "answer_ids" in ans: + example.answer_ids = ans["answer_ids"][0] + example.answer_mask = ans["mask"][0] elif isinstance(example, NeMoMultimodalConversation): example = example.tokenize(tokenizer, prompt_format) else: diff --git a/nemo/collections/common/prompts/t5nmt.py b/nemo/collections/common/prompts/t5nmt.py index b9acb050ccb9..6ec69862dc4a 100644 --- a/nemo/collections/common/prompts/t5nmt.py +++ b/nemo/collections/common/prompts/t5nmt.py @@ -70,14 +70,15 @@ def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> dict[str, torch.Tensor]: slots={"target_lang": context, "message": ""}, ), ] - if len(cut.supervisions) > 0 and cut.supervisions[0].text: + if len(cut.supervisions) > 0 and cut.supervisions[0].text is not None: turns.append( dict( - role="assistant", + role=prompt.OUTPUT_ROLE, slots={"message": cut.supervisions[0].text}, ) ) - for k, v in prompt.encode_dialog(turns).items(): + enc = prompt.encode_dialog(turns) + for k, v in enc.items(): ans[k].append(v) return ans From 92c81bbb7d561fbe84ce5c08b8c3a17e1b4901de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 14 Oct 2024 20:01:36 -0400 Subject: [PATCH 41/63] custom mixin for text adapters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/cutset.py | 4 ++-- .../common/data/lhotse/text_adapters.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index ad276957eac2..796a4dc589b7 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -223,8 +223,8 @@ def read_txt_pair_paths(config: DictConfig) -> CutSet: target_paths=config.target_paths, source_language=config.source_language, target_language=config.target_language, - questions_path=config.questions_path, - questions_language=config.questions_language, + questions_path=config.get("questions_path"), + questions_language=config.get("questions_language"), shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 76ed74740a2a..352844bbb518 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -22,6 +22,7 @@ import numpy as np import torch from lhotse import Recording +from lhotse.custom import CustomFieldMixin from lhotse.cut import Cut from lhotse.dataset.dataloading import resolve_seed from lhotse.serialization import load_jsonl @@ -41,7 +42,7 @@ @dataclass -class TextExample: +class TextExample(CustomFieldMixin): """ Represents a single text example. Useful e.g. for language modeling. """ @@ -93,7 +94,7 @@ def __iter__(self) -> Iterator[TextExample]: @dataclass -class SourceTargetTextExample: +class SourceTargetTextExample(CustomFieldMixin): """ Represents a pair of text examples. Useful e.g. for sequence-to-sequence tasks. Supports a ``question`` field, used as the prompt for LLM. @@ -127,9 +128,12 @@ def tokenize(self, tokenizer: TokenizerWrapper, prompt: PromptFormatter = None) ] ) elif isinstance(prompt, T5NMTPromptFormatter): + ctx = f"<{self.target.language}>" + if self.has_custom("extra_prompt"): + ctx = f"{ctx} {self.extra_prompt}" ans = prompt.encode_dialog( [ - {"role": "user", "slots": {"message": self.source.text, "target_lang": self.target.language}}, + {"role": "user", "slots": {"message": self.source.text, "target_lang": ctx}}, {"role": prompt.OUTPUT_ROLE, "slots": {"message": self.target.text}}, ] ) @@ -221,7 +225,7 @@ def __iter__(self) -> Iterator[SourceTargetTextExample]: @dataclass -class NeMoSFTExample: +class NeMoSFTExample(CustomFieldMixin): data: dict language: str | None = None input_ids: np.ndarray | None = None @@ -353,7 +357,7 @@ class AudioTurn: @dataclass -class NeMoMultimodalConversation: +class NeMoMultimodalConversation(CustomFieldMixin): id: str turns: list[TextTurn | AudioTurn] input_ids: np.ndarray | None = None From 68e27dbe955b4f33457b9defb68a26d967baa976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 16 Oct 2024 06:45:36 -0700 Subject: [PATCH 42/63] Warmup in oomptimizer-speechlm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../oomptimizer-speechllm.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/scripts/speech_recognition/oomptimizer-speechllm.py b/scripts/speech_recognition/oomptimizer-speechllm.py index b9313c06950c..3761099b614d 100755 --- a/scripts/speech_recognition/oomptimizer-speechllm.py +++ b/scripts/speech_recognition/oomptimizer-speechllm.py @@ -247,7 +247,6 @@ def type_cast_value(self, ctx, value): @click.option( "-c", "--config-path", type=str, default=None, help="Path to the training configuration file for MODULE_NAME." ) -@click.option("-o", "--optimizer-name", type=str, default="adamw", help="Name of optimizer to use.") @click.option( "--schema", type=str, @@ -318,7 +317,6 @@ def oomptimizer( pretrained_name: str | None, module_name: str | None, config_path: str | None, - optimizer_name: str, schema: str, buckets: list[float], threshold: float, @@ -388,7 +386,6 @@ def oomptimizer( model.log = lambda *args, **kwargs: None model_clones.append(model) model = model_clones[-1] - # model.setup(stage="fit") model.init_consumed_samples = 0 model._compute_consumed_samples_after_training_step = lambda *args, **kwargs: 1 @@ -409,7 +406,21 @@ def oomptimizer( schema = model.oomptimizer_schema(schema) click.echo("Setting up the optimizers.") - optimizer, _ = model.setup_optimization({"name": optimizer_name, "lr": 1e-7, "weight_decay": 0.0}) + optimizer = model.configure_optimizers() + if isinstance(optimizer, tuple): + optimizer = optimizer[0][0] + + # warmup - preallocate model/optimizer memory for all modality modules + for sch_ in ("text", "audio"): + gen_ = ProfilingBatchGenerator(model.oomptimizer_schema(sch_), start_batch_size=1) + with torch.autocast("cuda", getattr(torch, dtype)): + if sch_ == "audio": + batch_ = gen_(17519, 13) + else: + batch_ = gen_(9, 7) + optimizer.zero_grad() + out = model.training_step(iter([batch_])) + optimizer.step() is_2d_bucketing = all( isinstance(item, (list, tuple)) and len(item) == 2 and all(isinstance(v, Number) for v in item) @@ -484,8 +495,8 @@ def step(): f"\tCurrent settings | batch_size={gen._current} | gap: {gen.current_rel_gap}... ", nl=False ) optimizer.zero_grad() + # In SpeechLLM training_step performs both forward and backward; no need for manual backward out = model.training_step(iter([batch])) - # out['loss'].sum().backward() optimizer.step() except torch.cuda.OutOfMemoryError as e: click.secho(f"OOM!", fg="yellow") @@ -506,7 +517,7 @@ def step(): # but we have found out empirically that this causes a mismatched condition # between OOMptimizer and the actual training. During training, there is some # degree of memory fragmentation and it's better to simulate that in OOMptimizer. - torch.cuda.memory.empty_cache() + #torch.cuda.memory.empty_cache() torch.cuda.reset_max_memory_allocated() return oom From 0c3314642584c9a2d0806d3697007a000372b525 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 16 Oct 2024 06:46:58 -0700 Subject: [PATCH 43/63] Move oomptimizer-speechllm to separate directory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../oomptimizer-speechllm.py => speech_llm/oomptimizer.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{speech_recognition/oomptimizer-speechllm.py => speech_llm/oomptimizer.py} (100%) diff --git a/scripts/speech_recognition/oomptimizer-speechllm.py b/scripts/speech_llm/oomptimizer.py similarity index 100% rename from scripts/speech_recognition/oomptimizer-speechllm.py rename to scripts/speech_llm/oomptimizer.py From c3ea064607b1acb09ae8d4955d03c6d2404d783b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 16 Oct 2024 10:16:08 -0400 Subject: [PATCH 44/63] Initial cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- ...o_gpt_config_cross_llama_lhotse_multi.yaml | 59 +++++---- .../salm/modular_audio_t5_multi_config.yaml | 12 +- .../speech_llm/modular_audio_gpt_text_eval.py | 118 ------------------ nemo/collections/common/data/lhotse/cutset.py | 2 - .../common/data/lhotse/dataloader.py | 16 ++- .../speech_llm/models/modular_models.py | 15 +-- .../speech_llm/models/modular_t5_models.py | 10 +- 7 files changed, 49 insertions(+), 183 deletions(-) rename examples/multimodal/speech_llm/conf/{ => bestow}/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml (89%) delete mode 100644 examples/multimodal/speech_llm/modular_audio_gpt_text_eval.py diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml similarity index 89% rename from examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml rename to examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml index 89fded05024b..3d0c1c43bf4a 100644 --- a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml +++ b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml @@ -240,7 +240,7 @@ model: multi_config: true audio: input_cfg: ??? - sampler_fusion: zip + sampler_fusion: round_robin seed: 0 shard_seed: "trng" batch_size: null @@ -251,6 +251,7 @@ model: bucket_buffer_size: 20000 num_workers: 4 shuffle: true + prompt_format: llama2 text: input_cfg: ??? use_multimodal_sampling: true @@ -261,6 +262,7 @@ model: bucket_buffer_size: 20000 num_workers: 4 shuffle: true + prompt_format: llama2 global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} @@ -276,7 +278,6 @@ model: separate_prompt_and_response_with_newline: False truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: "[INST]\n<>\nPlease answer the following based on the previous speech feature.\n<>\n\n{context}[/INST] {answer}" validation_ds: manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. @@ -284,7 +285,7 @@ model: global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: False - num_workers: 0 + num_workers: 1 pin_memory: True max_seq_length: 2048 min_seq_length: 1 @@ -300,7 +301,6 @@ model: output_file_path_prefix: null # Prefix of the file to write predictions to. truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" tokens_to_generate: 128 # ASR configs sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} @@ -311,32 +311,31 @@ model: average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. num_classes: null - # test_ds: - # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. - # force_finite: true # workaround to allow using input_cfg - # names: null # Names of the corresponding datasets used to log metrics. - # global_batch_size: ${model.global_batch_size} - # micro_batch_size: ${model.micro_batch_size} - # shuffle: False - # num_workers: 4 - # pin_memory: True - # max_seq_length: 2048 - # min_seq_length: 1 - # drop_last: False - # context_key: 'input' - # label_key: 'output' - # add_eos: ${model.data.train_ds.add_eos} - # end_string: ${model.data.end_string} - # add_sep: ${model.data.train_ds.add_sep} - # add_bos: ${model.data.train_ds.add_bos} - # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} - # write_predictions_to_file: False - # output_file_path_prefix: null # Prefix of the file to write predictions to. - # truncation_field: "context" # Options: ['context', 'answer'] - # index_mapping_dir: null # Path to a directory to write index mapping files. - # prompt_template: ${model.data.train_ds.prompt_template} - # # ASR configs - # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + test_ds: + manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + force_finite: true # workaround to allow using input_cfg + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 1 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} # metric: # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml index b8472fdf53fd..e0b262db3adb 100644 --- a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml @@ -256,8 +256,6 @@ model: separate_prompt_and_response_with_newline: False truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: "{context}{answer}" - #prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" validation_ds: force_finite: true # workaround to allow using input_cfg @@ -265,7 +263,7 @@ model: global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: False - num_workers: 0 + num_workers: 1 pin_memory: True max_seq_length: 2048 min_seq_length: 1 @@ -280,10 +278,9 @@ model: output_file_path_prefix: null # Prefix of the file to write predictions to. truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" tokens_to_generate: 128 # ASR configs - sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + sample_rate: 16000 log_every_n_steps: 10 metric: @@ -299,7 +296,7 @@ model: global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: False - num_workers: 4 + num_workers: 1 pin_memory: True max_seq_length: 2048 min_seq_length: 1 @@ -314,9 +311,8 @@ model: output_file_path_prefix: null # Prefix of the file to write predictions to. truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: ${model.data.train_ds.prompt_template} # ASR configs - sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + sample_rate: 16000 # metric: # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] diff --git a/examples/multimodal/speech_llm/modular_audio_gpt_text_eval.py b/examples/multimodal/speech_llm/modular_audio_gpt_text_eval.py deleted file mode 100644 index d76e479829fa..000000000000 --- a/examples/multimodal/speech_llm/modular_audio_gpt_text_eval.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from pathlib import Path - -import torch.multiprocessing as mp -from omegaconf.omegaconf import OmegaConf - -from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel -from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder -from nemo.core.config import hydra_runner -from nemo.utils import logging - -mp.set_start_method("spawn", force=True) - -""" -This is the script to run inference with a ModularAudioGPTModel. - -If you want to evaluate an ModularAudioGPTModel: - -MEGATRON_CKPT=/path/to/megatron-llm.nemo -ALM_DIR=/path/to/nemo_experiments/job_name -ALM_YAML=$ALM_DIR/version_0/hparams.yaml -ALM_CKPT="$ALM_DIR/checkpoints/AudioGPT--validation_wer\=0.5-step\=103-epoch\=0-last.ckpt" - -VAL_MANIFESTS="[/data/libri-test-other.json,/data/MCV_7.1_test.json,/data/wsj-test.json]" -VAL_NAMES="[ls-test-other,mcv7.1-test,wsj-test]" - -HYDRA_FULL_ERROR=1 \ -CUDA_VISIBLE_DEVICES=0 python modular_audio_gpt_eval.py \ - model.restore_from_path=$MEGATRON_CKPT \ - model.peft.restore_from_path=$ALM_CKPT \ - model.peft.restore_from_hparams_path=$ALM_YAML \ - model.data.test_ds.manifest_filepath=$VAL_MANIFESTS \ - model.data.test_ds.names=$VAL_NAMES \ - model.data.test_ds.global_batch_size=8 \ - model.data.test_ds.micro_batch_size=8 \ - model.data.test_ds.tokens_to_generate=256 \ - ++inference.greedy=False \ - ++inference.top_k=50 \ - ++inference.top_p=0.95 \ - ++inference.temperature=0.4 \ - ++inference.repetition_penalty=1.2 \ - ++model.data.test_ds.output_dir=${ALM_DIR} -""" - - -@hydra_runner(config_path="conf", config_name="modular_audio_gpt_config_eval") -def main(cfg) -> None: - logging.info("\n\n************** Experiment configuration ***********") - logging.info(f"\n{OmegaConf.to_yaml(cfg)}") - logging.info("**************************************************\n\n") - - trainer = MegatronTrainerBuilder(cfg).create_trainer() - - if cfg.model.from_pretrained: - # Load model from NGC or HuggingFace - logging.info(f"Loading model from cloud: {cfg.model.from_pretrained}") - model_cfg = ModularAudioGPTModel.from_pretrained( - cfg.model.from_pretrained, trainer=trainer, return_config=True - ) - model_cfg = ModularAudioGPTModel.merge_inference_cfg(cfg, trainer, model_cfg) - model_file = ModularAudioGPTModel.from_pretrained( - cfg.model.from_pretrained, trainer=trainer, return_model_file=True - ) - model = ModularAudioGPTModel.restore_from( - restore_path=model_file, - trainer=trainer, - override_config_path=model_cfg, - strict=False, - map_location="cpu", - ) - if "peft" in model_cfg and model_cfg.peft.get("peft_scheme", None): - # need this due to the way that MegatronGPTSFTModel doesn't load adapters in model initialization - model.load_adapters(model_file, map_location="cpu") - else: - # Load model from a local file - model_cfg = ModularAudioGPTModel.merge_inference_cfg(cfg, trainer) - model = ModularAudioGPTModel.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - override_config_path=model_cfg, - strict=False, - map_location="cpu", - ) - model = ModularAudioGPTModel.load_adapters_for_inference(cfg, model_cfg, model) - model = ModularAudioGPTModel.load_audio_encoder_for_inference(cfg, model_cfg, model) - - model.freeze() - if cfg.get("save_as_nemo", None): - model.setup("predict") # need to call setup() to load adapters and prepare for saving - model.save_to(cfg.save_as_nemo) - logging.info(f"Model saved to {Path(cfg.save_as_nemo).absolute()}, exiting...") - exit(0) - - if not cfg.model.get('use_flash_attention', False): - cfg.inference.compute_attention_mask = True - config = OmegaConf.to_container(cfg.inference, resolve=True) - model.set_inference_config(config) - - # run inference - trainer.test(model) - - -if __name__ == "__main__": - main() diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 796a4dc589b7..c38ab8daeddd 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -153,8 +153,6 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: # Resolve /path/to/input_cfg.yaml into config contents if needed. input_cfg = OmegaConf.load(input_cfg) cuts, is_tarred = parse_and_combine_datasets(input_cfg, propagate_attrs=propagate_attrs) - if propagate_attrs["force_finite"]: - is_tarred = False # TEMPORARY Disables IterableDatasetWrapper behaviour for finite datasets return cuts, is_tarred diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index b4221c4b1d91..d8418f74966c 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -101,22 +101,26 @@ class LhotseDataLoadingConfig: token_equivalent_duration: float | None = None batch_tokens: int | None = None quadratic_factor: float | None = None + + # 2.2 Filters on sequence lengths. + # * Speech input + min_duration: float | None = -1 + max_duration: float | None = float("inf") + min_tps: int = -1 # allowed tokens per second (audio-only) + max_tps: float = float("inf") + # * Text input min_tokens: int | None = -1 max_tokens: int | None = 1_000_000_000 + min_tpt: int = -1 # allowed tokens per token (text-only) + max_tpt: float = float("inf") # 3. Supported existing NeMo options. shuffle: bool = False sample_rate: int = 16000 - min_duration: float | None = -1 - max_duration: float | None = float("inf") seed: int | str = 0 num_workers: int = 0 pin_memory: bool = False channel_selector: int | str | None = None - min_tps: int = -1 # allowed tokens per second (audio-only) - max_tps: float = float("inf") - min_tpt: int = -1 # allowed tokens per token (text-only) - max_tpt: float = float("inf") # 4. Optional Lhotse data augmentation. # a. On-the-fly noise/audio mixing. diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 0364a4212a28..b4cfea49ccc0 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -495,8 +495,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): else: seq_length = None # TODO(pzelasko): not sure if it is even needed ??? - data_iter = get_iterator_k_split(batch, 1) - # data_iter = get_iterator_k_split(batch, get_num_microbatches()) + data_iter = get_iterator_k_split(batch, get_num_microbatches()) # TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches # if log_token_counts: @@ -526,10 +525,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): forward_step_func=self.get_forward_output_and_loss_func(tuning=True, validation_step=forward_only), data_iterator=self._make_data_iterator_list(data_iter), model=self.model, - num_microbatches=1, # get_num_microbatches(), + num_microbatches=get_num_microbatches(), forward_only=forward_only, seq_length=seq_length, - micro_batch_size=self.cfg.data.train_ds.micro_batch_size, # get_micro_batch_size(), + micro_batch_size=get_micro_batch_size(), first_val_step=first_val_step, ) @@ -1392,7 +1391,7 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int inference_config['end_strings'] = [self.cfg.data.end_string] global_batch_size_per_gpu = batch['tokens'].size(0) - num_micro_batches_before_decode = 1 # get_num_microbatches() + num_micro_batches_before_decode = get_num_microbatches() compute_logprob = inference_config.get('compute_logprob', False) if compute_logprob: @@ -1802,12 +1801,6 @@ def oomptimizer_schema(self, schema: str = "audio") -> dict: ], } elif schema == "text": - # TODO: add support for text - # input_ids = text_batch["text_input_ids"][:, :-1] - # labels = text_batch["text_input_ids"][:, 1:] - # attention_mask = self._create_attention_mask(input_ids) - # loss_mask = text_batch["text_masks"][:, 1:] - return { "cls": dict, "inputs": [ diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 75439b775d1e..00cda52539a4 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -953,8 +953,8 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): output = self.predict_step(batch, batch_idx, dataloader_idx) - inputs_text = output["input_text"] # [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] - labels_text = output["labels_text"] # [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] + inputs_text = output["input_text"] + labels_text = output["labels_text"] preds_text = output['preds_text'] if data_cfg.get("log_every_n_steps", None) is not None: if batch_idx % data_cfg.log_every_n_steps == 0: @@ -1423,12 +1423,6 @@ def oomptimizer_schema(self, schema: str = "audio") -> dict: ], } elif schema == "text": - # TODO: add support for text - # input_ids = text_batch["text_input_ids"][:, :-1] - # labels = text_batch["text_input_ids"][:, 1:] - # attention_mask = self._create_attention_mask(input_ids) - # loss_mask = text_batch["text_masks"][:, 1:] - return { "cls": dict, "inputs": [ From 2a16008196560110ef155b3f6dabc1cc53a290aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 16 Oct 2024 16:32:27 -0400 Subject: [PATCH 45/63] Refactoring of prompt format fn and length measurement and filtering for data types; improved unit test coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../asr/data/audio_to_text_lhotse_prompted.py | 14 +- .../asr/models/aed_multitask_models.py | 4 +- nemo/collections/common/data/__init__.py | 2 + .../common/data/lhotse/__init__.py | 13 +- nemo/collections/common/data/lhotse/cutset.py | 172 +++-- .../common/data/lhotse/dataloader.py | 185 +++-- .../common/data/lhotse/text_adapters.py | 658 +++++------------- nemo/collections/common/data/prompt_fn.py | 77 ++ nemo/collections/common/prompts/__init__.py | 1 - nemo/collections/common/prompts/canary.py | 108 ++- nemo/collections/common/prompts/fn.py | 38 - nemo/collections/common/prompts/gemma.py | 46 +- nemo/collections/common/prompts/llama.py | 97 ++- nemo/collections/common/prompts/plain.py | 35 +- nemo/collections/common/prompts/t5nmt.py | 81 ++- .../asr/test_asr_multitask_model_bpe.py | 8 +- .../common/test_lhotse_dataloading.py | 47 +- .../test_lhotse_multimodal_dataloading.py | 2 - .../test_lhotse_prompt_format_data_types.py | 283 ++++++++ .../common/test_lhotse_seqlen_filters.py | 171 +++++ 20 files changed, 1155 insertions(+), 887 deletions(-) create mode 100644 nemo/collections/common/data/prompt_fn.py delete mode 100644 nemo/collections/common/prompts/fn.py create mode 100644 tests/collections/common/test_lhotse_prompt_format_data_types.py create mode 100644 tests/collections/common/test_lhotse_seqlen_filters.py diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 912188cb9c02..51935bbbfdcd 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -19,6 +19,8 @@ from lhotse.dataset import AudioSamples from lhotse.dataset.collation import collate_vectors +from nemo.collections.common.data import apply_prompt_format_fn +from nemo.collections.common.prompts import CanaryPromptFormatter, PromptFormatter from nemo.collections.common.tokenizers import TokenizerSpec @@ -62,15 +64,13 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): def __init__( self, tokenizer: TokenizerSpec, - prompt_format_fn: Callable[ - [CutSet, TokenizerSpec], tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]] - ], + prompt: PromptFormatter, ): super().__init__() self.tokenizer = tokenizer self.load_audio = AudioSamples(fault_tolerant=True) self.padding_value = self.tokenizer.pad - self.prompt_format_fn = prompt_format_fn + self.prompt = prompt def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: audio, audio_lens, cuts = self.load_audio(cuts) @@ -81,8 +81,10 @@ def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: if pre_formatted: prompts_with_answers, prompts, answers = zip(*((c.input_ids, c.context_ids, c.answer_ids) for c in cuts)) else: - ans = self.prompt_format_fn(cuts, self.tokenizer) - prompts_with_answers, prompts, answers = ans["input_ids"], ans["context_ids"], ans["answer_ids"] + formatted = [apply_prompt_format_fn(cut, self.prompt) for cut in cuts] + prompts_with_answers = [ex["input_ids"] for ex in formatted] + prompts = [ex["context_ids"] for ex in formatted] + answers = [ex["answer_ids"] for ex in formatted] transcript, transcript_lens = self._collate_tokens(answers) prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index d2d2213be6e6..454c79ee4e87 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -44,10 +44,10 @@ from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common import tokenizers from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config +from nemo.collections.common.data.prompt_fn import get_prompt_format_fn from nemo.collections.common.metrics import GlobalAverageLossMetric from nemo.collections.common.parts import transformer_weights_init from nemo.collections.common.parts.preprocessing.manifest import get_full_path -from nemo.collections.common.prompts.fn import get_prompt_format_fn from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.core.classes.common import typecheck from nemo.core.neural_types import ( @@ -510,7 +510,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): world_size=self.world_size, dataset=PromptedAudioToTextLhotseDataset( tokenizer=self.tokenizer, - prompt_format_fn=get_prompt_format_fn(self.prompt_format), + prompt=self.prompt, ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/common/data/__init__.py b/nemo/collections/common/data/__init__.py index ecc67ef05ea5..d4b43d2b4edc 100644 --- a/nemo/collections/common/data/__init__.py +++ b/nemo/collections/common/data/__init__.py @@ -13,3 +13,5 @@ # limitations under the License. from nemo.collections.common.data.dataset import CodeSwitchedDataset, ConcatDataset, ConcatMapDataset +from nemo.collections.common.data.lhotse import * +from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn, get_prompt_format_fn diff --git a/nemo/collections/common/data/lhotse/__init__.py b/nemo/collections/common/data/lhotse/__init__.py index 6bbe9e991236..95f0d01db297 100644 --- a/nemo/collections/common/data/lhotse/__init__.py +++ b/nemo/collections/common/data/lhotse/__init__.py @@ -13,4 +13,15 @@ # limitations under the License. from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config -from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config +from nemo.collections.common.data.lhotse.dataloader import ( + LhotseDataLoadingConfig, + get_lhotse_dataloader_from_config, + get_lhotse_sampler_from_config, +) +from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator, LazyNeMoTarredIterator +from nemo.collections.common.data.lhotse.text_adapters import ( + NeMoMultimodalConversation, + NeMoSFTExample, + SourceTargetTextExample, + TextExample, +) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index c38ab8daeddd..6c314d4b9de8 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -17,7 +17,7 @@ from functools import partial from itertools import repeat from pathlib import Path -from typing import Sequence, Tuple, Union +from typing import KeysView, Sequence, Tuple, Union import omegaconf from lhotse import CutSet, Features, Recording @@ -35,13 +35,15 @@ from nemo.collections.common.parts.preprocessing.manifest import get_full_path -def read_cutset_from_config(config: DictConfig) -> Tuple[CutSet, bool]: +def read_cutset_from_config(config: DictConfig | dict) -> Tuple[CutSet, bool]: """ Reads NeMo configuration and creates a CutSet either from Lhotse or NeMo manifests. Returns a tuple of ``CutSet`` and a boolean indicating whether the data is tarred (True) or not (False). """ # First, check if the dataset is specified in the new configuration format and use it if possible. + if not isinstance(config, DictConfig): + config = DictConfig(config) if config.get("input_cfg") is not None: return read_dataset_config(config) # Now, we'll figure out if we should read Lhotse manifest or NeMo manifest. @@ -50,31 +52,64 @@ def read_cutset_from_config(config: DictConfig) -> Tuple[CutSet, bool]: assert ( config.get("manifest_filepath") is not None ), "You must specify either: manifest_filepath, cuts_path, or shar_path" - is_tarred = config.get("tarred_audio_filepaths") is not None + cuts, is_tarred = read_nemo_manifest(config) else: - is_tarred = config.get("shar_path") is not None - if use_nemo_manifest: - # Read NeMo manifest -- use the right wrapper depending on tarred/non-tarred. - cuts = read_nemo_manifest(config, is_tarred) - else: - # Read Lhotse manifest (again handle both tarred(shar)/non-tarred). - cuts = read_lhotse_manifest(config, is_tarred) + cuts, is_tarred = read_lhotse_manifest(config) return cuts, is_tarred -KNOWN_DATASET_CONFIG_TYPES = frozenset( - ( - "nemo", - "nemo_tarred", - "lhotse", - "lhotse_shar", - "txt", - "txt_pair", - "nemo_sft_jsonl", - "multimodal_conversation", - "group", - ) -) +KNOWN_DATA_CONFIG_TYPES = {} + + +def get_known_config_data_types() -> KeysView[str]: + """ + Return the names of all registered data type parsers. + + Example: + + >>> get_known_config_data_types() + ["nemo", "nemo_tarred", "lhotse", ...] + """ + return KNOWN_DATA_CONFIG_TYPES.keys() + + +def get_parser_fn(data_type_name: str): + """ + Return the parsing function for a given data type name. + Parsing function reads a dataloading config and returns a tuple + of lhotse ``CutSet`` and boolean indicating whether we should use + iterable dataset (True) or map dataset (False) mechanism ("is tarred"). + """ + return KNOWN_DATA_CONFIG_TYPES[data_type_name] + + +def data_type_parser(name: str | list[str]): + """ + Decorator used to register data type parser functions. + Parsing function reads a dataloading config and returns a tuple + of lhotse ``CutSet`` and boolean indicating whether we should use + iterable dataset (True) or map dataset (False) mechanism ("is tarred"). + + Example: + + >>> @data_type_parser("my_new_format") + ... def my_new_format(config): + ... return CutSet(read_my_format(**config)), True + ... + ... fn = get_parser_fn("my_new_format") + ... cuts, is_tarred = fn({"my_arg_0": ..., "my_arg_1": ..., ...}) + """ + + def _decorator(fn): + global KNOWN_DATA_CONFIG_TYPES + if isinstance(name, str): + KNOWN_DATA_CONFIG_TYPES[name] = fn + else: + for n in name: + KNOWN_DATA_CONFIG_TYPES[n] = fn + return fn + + return _decorator def read_dataset_config(config) -> tuple[CutSet, bool]: @@ -142,11 +177,12 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: propagate_attrs = { "shuffle": config.shuffle, "shard_seed": config.shard_seed, - "text_field": config.text_field, - "lang_field": config.lang_field, - "metadata_only": config.metadata_only, - "force_finite": config.force_finite, - "max_open_streams": config.max_open_streams, + "text_field": config.get("text_field", "text"), + "lang_field": config.get("lang_field", "lang"), + "metadata_only": config.get("metadata_only", False), + "force_finite": config.get("force_finite", False), + "max_open_streams": config.get("max_open_streams", None), + "token_equivalent_duration": config.get("token_equivalent_duration", None), } input_cfg = config.input_cfg if isinstance(input_cfg, (str, Path)): @@ -157,50 +193,27 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> [CutSet, bool]: - assert grp_cfg.type in KNOWN_DATASET_CONFIG_TYPES, f"Unknown item type in dataset config list: {grp_cfg.type=}" - if grp_cfg.type == "nemo_tarred": - is_tarred = True - cuts = read_nemo_manifest(grp_cfg, is_tarred=is_tarred) - elif grp_cfg.type == "nemo": - is_tarred = False - cuts = read_nemo_manifest(grp_cfg, is_tarred=is_tarred) - elif grp_cfg.type == "lhotse_shar": - is_tarred = True - cuts = read_lhotse_manifest(grp_cfg, is_tarred=is_tarred) - elif grp_cfg.type == "lhotse": - is_tarred = False - cuts = read_lhotse_manifest(grp_cfg, is_tarred=is_tarred) - # Note: "txt" and "txt_pair" have "is_tarred" set to True. - # The main reason is to enable combination of tarred audio and text dataloading, - # since we don't allow combination of tarred and non-tarred datasets. - # We choose to treat text as-if it was tarred, which also tends to be more + assert grp_cfg.type in get_known_config_data_types(), f"Unknown item type in dataset config list: {grp_cfg.type=}" + + # Note: Text data types will return is_tarred=True. + # We choose to treat text as-if it was tarred, which tends to be more # efficient as it moves the text file iteration into dataloading subprocess. - elif grp_cfg.type == "txt": - is_tarred = True - cuts = read_txt_paths(grp_cfg) - elif grp_cfg.type == "txt_pair": - is_tarred = True - cuts = read_txt_pair_paths(grp_cfg) - elif grp_cfg.type == "nemo_sft_jsonl": - is_tarred = True - cuts = read_nemo_sft_jsonl(grp_cfg) - elif grp_cfg.type == "multimodal_conversation": - is_tarred = True - cuts = read_multimodal_conversation_jsonl(grp_cfg) - elif grp_cfg.type == "group": + if grp_cfg.type != "group": + parser_fn = get_parser_fn(grp_cfg.type) + cuts, is_tarred = parser_fn(grp_cfg) + else: cuts, is_tarred = parse_and_combine_datasets( grp_cfg.input_cfg, propagate_attrs=propagate_attrs, ) - else: - raise ValueError(f"Unrecognized group: {grp_cfg.type}") # Attach extra tags to every utterance dynamically, if provided. if (extra_tags := grp_cfg.get("tags")) is not None: cuts = cuts.map(partial(attach_tags, tags=extra_tags), apply_fn=None) return cuts, is_tarred -def read_txt_paths(config: DictConfig) -> CutSet: +@data_type_parser("txt") +def read_txt_paths(config: DictConfig) -> tuple[CutSet, bool]: cuts = CutSet( LhotseTextAdapter( paths=config.paths, @@ -211,16 +224,17 @@ def read_txt_paths(config: DictConfig) -> CutSet: ) if not config.get("force_finite", False): cuts = cuts.repeat() - return cuts + return cuts, True -def read_txt_pair_paths(config: DictConfig) -> CutSet: +@data_type_parser("txt_pair") +def read_txt_pair_paths(config: DictConfig) -> tuple[CutSet, bool]: cuts = CutSet( LhotseTextPairAdapter( source_paths=config.source_paths, target_paths=config.target_paths, - source_language=config.source_language, - target_language=config.target_language, + source_language=config.get("source_language"), + target_language=config.get("target_language"), questions_path=config.get("questions_path"), questions_language=config.get("questions_language"), shuffle_shards=config.shuffle, @@ -229,36 +243,39 @@ def read_txt_pair_paths(config: DictConfig) -> CutSet: ) if not config.get("force_finite", False): cuts = cuts.repeat() - return cuts + return cuts, True -def read_nemo_sft_jsonl(config: DictConfig) -> CutSet: +@data_type_parser("nemo_sft_jsonl") +def read_nemo_sft_jsonl(config: DictConfig) -> tuple[CutSet, bool]: cuts = CutSet( NeMoSFTJsonlAdapter( paths=config.paths, - language=config.language, + language=config.get("language"), shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) ) if not config.get("force_finite", False): cuts = cuts.repeat() - return cuts + return cuts, True -def read_multimodal_conversation_jsonl(config: DictConfig) -> CutSet: +@data_type_parser("multimodal_conversation") +def read_multimodal_conversation_jsonl(config: DictConfig) -> tuple[CutSet, bool]: cuts = CutSet( NeMoMultimodalConversationJsonlAdapter( manifest_filepath=config.manifest_filepath, tarred_audio_filepaths=config.get("tarred_audio_filepaths"), audio_locator_tag=config.audio_locator_tag, + token_equivalent_duration=config.get("token_equivalent_duration"), shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) ) if not config.get("force_finite", False): cuts = cuts.repeat() - return cuts + return cuts, True def attach_tags(cut, tags: dict): @@ -267,6 +284,7 @@ def attach_tags(cut, tags: dict): return cut +@data_type_parser("group") def parse_and_combine_datasets( config_list: Union[list[DictConfig], ListConfig], propagate_attrs: dict ) -> tuple[CutSet, bool]: @@ -312,7 +330,9 @@ def parse_and_combine_datasets( return cuts, tarred_status[0] -def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: +@data_type_parser(["lhotse", "lhotse_shar"]) +def read_lhotse_manifest(config) -> tuple[CutSet, bool]: + is_tarred = config.shar_path is not None if is_tarred: # Lhotse Shar is the equivalent of NeMo's native "tarred" dataset. # The combination of shuffle_shards, and repeat causes this to @@ -381,7 +401,7 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: # Regular Lhotse manifest points to individual audio files (like native NeMo manifest). path = config.cuts_path cuts = CutSet.from_file(path).map(partial(resolve_relative_paths, manifest_path=path)) - return cuts + return cuts, is_tarred def _resolve_shar_inputs(path: str | Path, only_metadata: bool) -> dict: @@ -439,7 +459,8 @@ def resolve_array(value): return cut -def read_nemo_manifest(config, is_tarred: bool) -> CutSet: +@data_type_parser(["nemo", "nemo_tarred"]) +def read_nemo_manifest(config) -> tuple[CutSet, bool]: common_kwargs = { "text_field": config.text_field, "lang_field": config.lang_field, @@ -456,6 +477,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: notar_kwargs = {"metadata_only": config.metadata_only} metadata_only = config.metadata_only force_finite = config.force_finite + is_tarred = config.get("tarred_audio_filepaths") is not None if isinstance(config.manifest_filepath, (str, Path)): logging.info(f"Initializing Lhotse CutSet from a single NeMo manifest (tarred): '{config.manifest_filepath}'") if is_tarred and not metadata_only: @@ -535,7 +557,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: seed=config.shard_seed, force_finite=force_finite or metadata_only, ) - return cuts + return cuts, is_tarred def mux( diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index d8418f74966c..2881ffd2e24b 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import bisect +import math import os import random import warnings from dataclasses import dataclass from functools import partial -from typing import Any, List, Optional, Sequence, TypeVar, Union +from typing import Any, Optional, Sequence import numpy as np import torch @@ -38,17 +39,17 @@ from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint from lhotse.lazy import LazyFlattener from lhotse.utils import fastcopy, fix_random_seed -from omegaconf import DictConfig, ListConfig, OmegaConf +from omegaconf import DictConfig, OmegaConf from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset, read_cutset_from_config from nemo.collections.common.data.lhotse.text_adapters import ( - NeMoMultimodalConversation, + Formattable, NeMoSFTExample, SourceTargetTextExample, TextExample, ) +from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn from nemo.collections.common.prompts import PromptFormatter -from nemo.collections.common.prompts.fn import get_prompt_format_fn from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper from nemo.utils import logging @@ -109,8 +110,9 @@ class LhotseDataLoadingConfig: min_tps: int = -1 # allowed tokens per second (audio-only) max_tps: float = float("inf") # * Text input - min_tokens: int | None = -1 - max_tokens: int | None = 1_000_000_000 + min_tokens: int | None = None + max_tokens: int | None = None + measure_total_length: bool = True min_tpt: int = -1 # allowed tokens per token (text-only) max_tpt: float = float("inf") @@ -446,7 +448,9 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # Duration filtering, same as native NeMo dataloaders. # We can filter after the augmentations because they are applied only when calling load_audio(). cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration)) - cuts = cuts.filter(TokenCountFilter(config.min_tokens, config.max_tokens)) + cuts = cuts.filter( + TokenCountFilter(config.min_tokens, config.max_tokens, measure_total_length=config.measure_total_length) + ) bucket_duration_bins = determine_bucket_duration_bins(config) if config.use_multimodal_sampling: @@ -586,13 +590,15 @@ def determine_bucket_duration_bins(config): return None -def make_structured_with_schema_warnings(config: DictConfig) -> DictConfig: +def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfig: """ Checks the schema and fills missing default option values. Warns the user if any of the fields are not supported by the current schema but does not raise exceptions. """ default = OmegaConf.structured(LhotseDataLoadingConfig) + if not isinstance(config, DictConfig): + config = DictConfig(config) # Remove unsupported keys and warn about them. supported_keys = set(OmegaConf.to_container(default).keys()) @@ -610,23 +616,29 @@ def make_structured_with_schema_warnings(config: DictConfig) -> DictConfig: @dataclass class MultimodalSamplingConstraint(SamplingConstraint): - # how many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch + # How many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch. + # Generally set this to frame_shift * total_subsampling_factor of your audio encoder. token_equivalent_duration: float | None = None - # defines maximum batch size (may be lower than that if batch_length is also specified) + # Defines maximum batch size (may be lower than that if batch_length is also specified). batch_size: int | None = None - # defines the total number of tokens in a mini-batch - # setting this enables dynamic batch sizes - # we will use ``token_equivalent_duration`` to convert audio examples to token sizes + # Defines the total number of tokens in a mini-batch. + # Setting this enables dynamic batch sizes. + # We will use ``token_equivalent_duration`` to convert audio examples to token sizes. batch_tokens: int | None = None - # when specified, this value is inversely proportional to the penalty we assign + # When specified, this value is inversely proportional to the penalty we assign # to longer examples when measuring their length/duration; - # i.e. large quadratic factor is a small penalty, small quadratic factor is a large penalty - # tweaking this helps equalize the GPU memory usage for dynamic batch sizes when using bucketing + # i.e. large quadratic factor is a small penalty, small quadratic factor is a large penalty. + # Tweaking this helps equalize the GPU memory usage for dynamic batch sizes when using bucketing. quadratic_factor: float | None = None + # When False (default), we only consider the input part of the example to determine its length, + # e.g. for a Cut that means its audio duration converted to tokens, for text that means len(context_ids), etc. + # When True, we consider the sum of input and output lengths together (useful mostly for decoder-only models). + measure_total_length: bool = False + _internal = None def __post_init__(self): @@ -637,9 +649,8 @@ def __post_init__(self): ) def add(self, example: Any) -> None: - if isinstance(example, Cut): - num_tokens = self.measure_length(example) - example.num_tokens = num_tokens + num_tokens = self.measure_length(example) + example.num_tokens = num_tokens self._internal.add(example) def exceeded(self) -> bool: @@ -653,16 +664,26 @@ def reset(self) -> None: def measure_length(self, example: Any) -> float: if isinstance(example, Cut): - # "length" of a Cut (audio+text example) is counted as the sum of: - # * num_tokens in each supervision segment ("utterance") in the Cut - # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) - text_tokens = 0 - for s in example.supervisions: - if s.has_custom("tokens"): - text_tokens += len(s.tokens) - return example.duration / self.token_equivalent_duration + text_tokens - if isinstance(example, (TextExample, SourceTargetTextExample, NeMoSFTExample)): - return example.num_tokens + audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) + if self.measure_total_length: + # Total length of a Cut (audio+text example) is counted as the sum of: + # * num_tokens in each supervision segment ("utterance") in the Cut + # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) + text_tokens = 0 + for s in example.supervisions: + if s.has_custom("tokens"): + text_tokens += len(s.tokens) + return audio_len_in_tokens + text_tokens + else: + return audio_len_in_tokens + elif isinstance(example, Formattable): + try: + return example.total_length if self.measure_total_length else example.input_length + except (AttributeError, AssertionError) as e: + raise RuntimeError( + "Couldn't determine the length of a text example; " + "have you provided both prompt_format and tokenizer when instantiating the dataloader?" + ) from e raise RuntimeError(f"Unsupported example type: {type(example)}") @@ -707,38 +728,35 @@ def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = No @dataclass class MultimodalFixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint2D): + # How many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch. + # Generally set this to frame_shift * total_subsampling_factor of your audio encoder. token_equivalent_duration: float | None = None + # When False (default), we only consider the input part of the example to determine its length, + # e.g. for a Cut that means its audio duration converted to tokens, for text that means len(context_ids), etc. + # When True, we consider the sum of input and output lengths together (useful mostly for decoder-only models). + measure_total_length: bool = False + def measure_length(self, example: Any) -> float | tuple[float, float]: - # Case 1: audio if isinstance(example, Cut): - assert ( - self.token_equivalent_duration is not None - ), "Cannot use MultimodalFixedBucketBatchSizeConstraint with speech data when token_equivalent_duration was not specified." - in_tokens = example.duration / self.token_equivalent_duration - if self.bucketing_2d_enabled: - return in_tokens, _measure_tokens(example) + audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) + if self.measure_total_length: + # Total length of a Cut (audio+text example) is counted as the sum of: + # * num_tokens in each supervision segment ("utterance") in the Cut + # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) + text_tokens = 0 + for s in example.supervisions: + if s.has_custom("tokens"): + text_tokens += len(s.tokens) + return audio_len_in_tokens + text_tokens else: - return in_tokens - # Case 2: text - if self.bucketing_2d_enabled: - if hasattr(example, "context_ids") and hasattr(example, "answer_ids"): - return len(example.context_ids), len(example.answer_ids) - else: - if hasattr(example, "num_tokens"): - return example.num_tokens - + return audio_len_in_tokens + elif isinstance(example, Formattable): + return example.total_length if self.measure_total_length else example.input_length raise RuntimeError(f"Unsupported example type: {type(example)}") -def is_text(example) -> bool: - return isinstance(example, (TextExample, SourceTargetTextExample, NeMoSFTExample)) - - -Example = TypeVar("Example", bound=Union[Cut, TextExample, SourceTargetTextExample, NeMoSFTExample]) - - -def tokenize(example: Example, tokenizer) -> Example: +def tokenize(example, tokenizer): if isinstance(example, Cut): for s in example.supervisions: if s.text is not None: @@ -750,23 +768,12 @@ def tokenize(example: Example, tokenizer) -> Example: return example -def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str | PromptFormatter) -> Example: - if isinstance(example, Cut): - prompt_format_fn = get_prompt_format_fn(prompt_format) - ans = prompt_format_fn(CutSet([example]), tokenizer) - example.input_ids = ans["input_ids"][0] - example.context_ids = ans["context_ids"][0] - if "answer_ids" in ans: - example.answer_ids = ans["answer_ids"][0] - example.answer_mask = ans["mask"][0] - elif isinstance(example, NeMoMultimodalConversation): - example = example.tokenize(tokenizer, prompt_format) - else: - # TODO: need an equivalent of get_prompt_format_fn for text modality - # to be able to construct different kinds of turns specific to a given application - if isinstance(prompt_format, str): - prompt_format = PromptFormatter.resolve(prompt_format)(tokenizer) - example = example.tokenize(tokenizer, prompt=prompt_format) +def tokenize_with_prompt(example, tokenizer, prompt_format: str | PromptFormatter): + if isinstance(prompt_format, str): + prompt_format = PromptFormatter.resolve(prompt_format)(tokenizer) + encoded = apply_prompt_format_fn(example, prompt_format) + for key, value in encoded.items(): + setattr(example, key, value) return example @@ -791,18 +798,46 @@ def __call__(self, example) -> bool: class TokenCountFilter: - """Callable, returns ``True`` if an example's number of tokens is in range [t_min, t_max] and ``False`` otherwise.""" + """ + Callable, returns ``True`` if an example's number of tokens is in range [t_min, t_max] and ``False`` otherwise. + + It is only applicable to data types that derive from class ``Formattable`` and lhotse ``Cut`` objects. + Acts as a passthrough for Cuts. + Raises exception if a non-Formattable and non-Cut data are provided. + + The ``measure_total_length`` option allows to select whether we should filter on context_ids length (=False) + or input_ids length (=True). + The difference is that for decoder-only models, we collapse input and output into a single sequence, + so we should measure the example length using input_ids (measure_total_length=True). + However, for models which have separate inputs and outputs such as encoder-decoder models, + we want to measure the input lengths only here (measure_total_length=False), + and enable ``TokenPerTokenFilter`` for additional filtering on the output sequence length. + """ - def __init__(self, t_min: float, t_max: float) -> None: + def __init__(self, t_min: float, t_max: float, measure_total_length: bool) -> None: self.t_min = t_min self.t_max = t_max + self.measure_total_length = measure_total_length def __call__(self, example) -> bool: + if self.t_min is None and self.t_max is None: + return True # disabled if isinstance(example, Cut): return True # does not apply to Cuts - elif hasattr(example, "num_tokens") and example.num_tokens is not None: - return self.t_min <= example.num_tokens <= self.t_max - return True # applies only to non-audio with num_tokens property + assert isinstance(example, Formattable), ( + f"TokenCountFilter can only be applied to data examples that derive Formattable class. " + f"Formattable objects define properties input_length, output_length, and total_length that " + f"allow us to select the right sequence length for filtering. We got: {example}" + ) + try: + length = example.total_length if self.measure_total_length else example.input_length + except (AttributeError, AssertionError) as e: + raise RuntimeError( + f"Cannot measure token count for example: {example} " + f"-- did you forget to apply prompt formatting? If instantiating Lhotse dataloader, " + f"make sure you provided 'prompt_format' option and passed the tokenizer." + ) from e + return self.t_min <= length <= self.t_max class TokenPerSecondFilter: diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 352844bbb518..08b9804d75ad 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import copy +import math import random from dataclasses import dataclass from itertools import groupby @@ -29,20 +28,57 @@ from lhotse.utils import Pathlike from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths +from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn, registered_prompt_format_fn from nemo.collections.common.parts.preprocessing.manifest import get_full_path -from nemo.collections.common.prompts import Llama2PromptFormatter, PromptFormatter -from nemo.collections.common.prompts.t5nmt import T5NMTPromptFormatter from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer, TokenizerWrapper -from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -from nemo.utils import logging """ -Basic text example, adequate for pretraining-style language modeling. +Formattable: mixin class with data fields for prompt formatter outputs and method for +applying prompt formatters to derived data types. +""" + + +class Formattable: + def __init__(self): + self.input_ids: np.ndarray | torch.Tensor | None = None + self.context_ids: np.ndarray | torch.Tensor | None = None + self.answer_ids: np.ndarray | torch.Tensor | None = None + self.mask: np.ndarray | torch.Tensor | None = None + + @property + def input_length(self) -> int | None: + if self.context_ids is None: + return None + return self.context_ids.shape[0] + + @property + def output_length(self) -> int | None: + if self.answer_ids is None: + return None + return self.answer_ids.shape[0] + + @property + def total_length(self) -> int | None: + if self.input_ids is None: + return None + return self.input_ids.shape[0] + + def apply_prompt_format(self, prompt) -> "Formattable": + ans = apply_prompt_format_fn(self, prompt) + self.input_ids = ans["input_ids"] + self.context_ids = ans["context_ids"] + self.answer_ids = ans["answer_ids"] + self.mask = ans["mask"] + return self + + +""" +TextExample: data types, file parser, default prompt formatting logic. """ @dataclass -class TextExample(CustomFieldMixin): +class TextExample(Formattable, CustomFieldMixin): """ Represents a single text example. Useful e.g. for language modeling. """ @@ -50,12 +86,7 @@ class TextExample(CustomFieldMixin): text: str language: str | None = None tokens: Optional[np.ndarray] = None - - @property - def num_tokens(self) -> Optional[int]: - if self.tokens is None: - return None - return len(self.tokens) + custom: dict = None def tokenize(self, tokenizer: TokenizerWrapper) -> "TextExample": self.tokens = np.asarray(tokenizer(self.text, self.language)) @@ -88,13 +119,26 @@ def __iter__(self) -> Iterator[TextExample]: yield TextExample(line, language=self.language) +@registered_prompt_format_fn(TextExample) +def default_text_example_prompt_format_fn(example: TextExample, prompt): + # It doesn't really make sense to prompt format a single line text example, + # but we implement some default logic for the sake of completeness. + # The default logic here is to treat the whole example as an assistant turn, + # so that the mask is all set to true for the training loss. + return prompt.encode_dialog( + [ + {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.text}}, + ] + ) + + """ -Source-target text examples (e.g., machine translation). +SourceTargetTextExample: data types, file parser, default prompt formatting logic. """ @dataclass -class SourceTargetTextExample(CustomFieldMixin): +class SourceTargetTextExample(Formattable, CustomFieldMixin): """ Represents a pair of text examples. Useful e.g. for sequence-to-sequence tasks. Supports a ``question`` field, used as the prompt for LLM. @@ -103,68 +147,13 @@ class SourceTargetTextExample(CustomFieldMixin): source: TextExample target: TextExample question: TextExample | None = None - input_ids: np.ndarray | None = None - context_ids: np.ndarray | None = None - answer_ids: np.ndarray | None = None - mask: np.ndarray | None = None - - @property - def num_tokens(self) -> Optional[int]: - if self.input_ids is not None: - return self.input_ids.shape[0] - return None - - def tokenize(self, tokenizer: TokenizerWrapper, prompt: PromptFormatter = None) -> "SourceTargetTextExample": - - if prompt is not None: - if isinstance(prompt, Llama2PromptFormatter): - ans = prompt.encode_dialog( - [ - { - "role": "system_and_user", - "slots": {"system": self.question.text, "message": self.source.text}, - }, - {"role": prompt.OUTPUT_ROLE, "slots": {"message": self.target.text}}, - ] - ) - elif isinstance(prompt, T5NMTPromptFormatter): - ctx = f"<{self.target.language}>" - if self.has_custom("extra_prompt"): - ctx = f"{ctx} {self.extra_prompt}" - ans = prompt.encode_dialog( - [ - {"role": "user", "slots": {"message": self.source.text, "target_lang": ctx}}, - {"role": prompt.OUTPUT_ROLE, "slots": {"message": self.target.text}}, - ] - ) - else: - raise RuntimeError(f"Unexpected prompt formatter: {prompt}") - self.input_ids = ans["input_ids"] - self.context_ids = ans["context_ids"] - self.answer_ids = ans["answer_ids"] - self.mask = ans["mask"] - return self - - input_ids = [] - context_ids = [] - if self.question: - ans = tokenizer(self.question.text, self.question.language) - input_ids.extend(ans) - context_ids.extend(ans) - ans = tokenizer(self.source.text, self.source.language) - input_ids.extend(ans) - context_ids.extend(ans) - - answer_ids = tokenizer(self.target.text, self.target.language) - input_ids.extend(answer_ids) - - self.input_ids = np.asarray(input_ids) - self.context_ids = np.asarray(context_ids) - self.answer_ids = np.asarray(answer_ids) - mask = np.zeros_like(self.input_ids, dtype=np.bool_) - mask[self.context_ids.shape[0] :] = True - self.mask = mask + custom: dict = None + def tokenize(self, tokenizer: TokenizerWrapper) -> "SourceTargetTextExample": + self.source = self.source.tokenize(tokenizer) + self.target = self.target.tokenize(tokenizer) + if self.question is not None: + self.question = self.question.tokenize(tokenizer) return self @@ -224,75 +213,46 @@ def __iter__(self) -> Iterator[SourceTargetTextExample]: ) +@registered_prompt_format_fn(SourceTargetTextExample) +def default_src_tgt_prompt_format_fn(example: SourceTargetTextExample, prompt): + if example.question is not None: + ctx = f"{example.question.text} {example.source.text}" + else: + ctx = example.source.text + return prompt.encode_dialog( + [ + {"role": "user", "slots": {"message": ctx}}, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}}, + ] + ) + + +""" +NeMoSFTExample: data types, file parser, default prompt formatting logic. +""" + + @dataclass -class NeMoSFTExample(CustomFieldMixin): +class NeMoSFTExample(Formattable, CustomFieldMixin): data: dict language: str | None = None - input_ids: np.ndarray | None = None - context_ids: np.ndarray | None = None - answer_ids: np.ndarray | None = None - mask: np.ndarray | None = None metadata: dict | None = None + custom: dict = None - def tokenize(self, tokenizer: TokenizerWrapper | TokenizerSpec) -> "NeMoSFTExample": - """ - Create a tokenized variant of this example given a tokenizer (i.e. fill the optional fields). - Supports BPE tokenizers and aggregate tokenizers. - - The tokenization is compatible with Megatron's :class:`GPTSFTChatDataset`. - """ - special_tokens = { - "system_turn_start": "", - "turn_start": "", - "label_start": "", - "end_of_turn": "\n", - "end_of_name": "\n", - } - - if isinstance(tokenizer, TokenizerWrapper): - tokenizer = tokenizer._tokenizer - if isinstance(tokenizer, AggregateTokenizer): - assert self.language is not None, ( - f"Error: attempted to use AggregateTokenizer for NeMoSFTExample which did not specify language. " - f"Problematic example: {self}" - ) - assert self.language in tokenizer.tokenizers_dict, ( - f"Error: attempted to use AggregateTokenizer for NeMoSFTExample with unsupported language: {self.language}. " - f"The set of supported languages is: {' '.join(tokenizer.tokenizers_dict.keys())}. " - f"Problematic example: {self}" - ) - tokenizer = tokenizer.tokenizers_dict[self.language] - - label_start_tokens, name_end_token_ids, num_turn_start_tokens = _build_samples_mapping( - tokenizer, special_tokens - ) - tokenized = preprocess( - source=self.data, - tokenizer=tokenizer, - name_end_token_ids=name_end_token_ids, - label_start_ids=label_start_tokens, - special_tokens=special_tokens, - num_turn_start_tokens=num_turn_start_tokens, +@registered_prompt_format_fn(NeMoSFTExample) +def default_sft_prompt_format_fn(example: NeMoSFTExample, prompt): + if "system" in example.data and example.data["system"]: + raise RuntimeError( + f"Default prompt format for NeMoSFTExample doesn't support 'system' prompt. " + f"Please specialize the prompt_format_fn for PromptFormatter of type {prompt}" ) - self.input_ids = tokenized["input_ids"].numpy() - self.context_ids = tokenized["context_ids"].numpy() - self.answer_ids = tokenized["answer_ids"].numpy() - self.mask = tokenized["mask"].numpy() - self.metadata = {k: v for k, v in self.data.items() if k not in ['conversations']} - - return self - - # TODO(pzelasko): for mini-batch sampling purposes, should we consider input_ids or answer_ids - # as representative of the sequence length? Putting input_ids here for now. - - @property - def tokens(self) -> np.ndarray: - return self.input_ids - - @property - def num_tokens(self) -> int: - return self.input_ids.shape[0] + return prompt.encode_dialog( + [ + {"role": "user" if turn["from"] == "User" else prompt.OUTPUT_ROLE, "slots": {"message": turn["value"]}} + for turn in example.data["conversations"] + ] + ) @dataclass @@ -318,11 +278,6 @@ class NeMoSFTJsonlAdapter: "dataset": str, "category": str, } - - Refer to examples of this format here: - - * TODO: links to examples? - * TODO: links to more detailed schema definition? """ paths: Union[Pathlike, list[Pathlike]] @@ -343,6 +298,11 @@ def __iter__(self) -> Iterator[NeMoSFTExample]: yield NeMoSFTExample(data, language=self.language) +""" +NeMoMultimodalConversation: data types, file parser, default prompt formatting logic. +""" + + @dataclass class TextTurn: value: str @@ -357,56 +317,78 @@ class AudioTurn: @dataclass -class NeMoMultimodalConversation(CustomFieldMixin): +class NeMoMultimodalConversation(Formattable, CustomFieldMixin): id: str turns: list[TextTurn | AudioTurn] - input_ids: np.ndarray | None = None - context_ids: np.ndarray | None = None - answer_ids: np.ndarray | None = None - mask: np.ndarray | None = None - - def tokenize( - self, - tokenizer: TokenizerWrapper | TokenizerSpec, - prompt: PromptFormatter = None, - ) -> "NeMoMultimodalConversation": - """ - Create a tokenized variant of this example given a tokenizer (i.e. fill the optional fields). - Supports BPE tokenizers and aggregate tokenizers. - - The tokenization is compatible with Megatron's :class:`GPTSFTChatDataset`. - """ - if isinstance(tokenizer, TokenizerWrapper): - tokenizer = tokenizer._tokenizer - if isinstance(tokenizer, AggregateTokenizer): - raise NotImplementedError("NeMoMultimodalConversation does not support AggregateTokenizer yet.") - if prompt is None: - prompt = PromptFormatter.resolve("plain")(tokenizer) - elif isinstance(prompt, str): - prompt = PromptFormatter.resolve(prompt)(tokenizer) - - # Collapse consecutive same-role turns into single turn for proper prompt formatting. - turns = groupby( - [ - { - "role": turn.role, - "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag}, - } - for turn in self.turns - ], - key=lambda turn: turn["role"], - ) - turns = [ - {"role": role, "slots": {"message": " ".join(t["slots"]["message"] for t in turn_grp)}} - for role, turn_grp in turns + token_equivalent_duration: float = None + custom: dict = None + + @property + def input_length(self) -> int | None: + if self.context_ids is None: + return None + extra = _compute_num_audio_tokens(self, "context") + return self.context_ids.shape[0] + extra + + @property + def output_length(self) -> int | None: + if self.answer_ids is None: + return None + extra = _compute_num_audio_tokens(self, "answer") + return self.answer_ids.shape[0] + extra + + @property + def total_length(self) -> int | None: + if self.input_ids is None: + return None + extra = _compute_num_audio_tokens(self, "all") + return self.input_ids.shape[0] + extra + + +def _compute_num_audio_tokens(example: NeMoMultimodalConversation, mode: Literal["context", "answer", "all"]) -> int: + assert example.token_equivalent_duration is not None, ( + "Cannot compute the length of a NeMoMultimodalConversation: " + "token_equivalent_duration must be set in order to estimate the number of tokens equivalent to audio turns. " + "Did you forget to set token_equivalent_duration option in your dataloading config? " + "Tip: generally it should be set to frame_shift * total_subsampling_factor of your audio encoder model." + ) + match mode: + case "context": + turns = example.turns[:-1] + case "answer": + turns = example.turns[-1] + case "all": + turns = example.turns + case _: + raise RuntimeError(f"invalid mode for number of audio token computation: {mode}") + return sum( + [ + # subtract 1 for each audio locator tag as its token will be replaced + math.ceil(turn.cut.duration / example.token_equivalent_duration) - 1 + for turn in turns + if isinstance(turn, AudioTurn) ] - ans = prompt.encode_dialog(turns) - self.input_ids = ans["input_ids"] - self.context_ids = ans["context_ids"] - self.answer_ids = ans["answer_ids"] - self.mask = ans["mask"] + ) - return self + +@registered_prompt_format_fn(NeMoMultimodalConversation) +def default_multimodal_conversation_prompt_format_fn(example: NeMoMultimodalConversation, prompt): + # Collapse consecutive same-role turns into single turn for proper prompt formatting. + turns = groupby( + [ + { + "role": turn.role, + "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag}, + } + for turn in example.turns + ], + key=lambda turn: turn["role"], + ) + turns = [ + {"role": role, "slots": {"message": " ".join(t["slots"]["message"] for t in turn_grp)}} + for role, turn_grp in turns + ] + return prompt.encode_dialog(turns) @dataclass @@ -434,6 +416,7 @@ class NeMoMultimodalConversationJsonlAdapter: manifest_filepath: str | list[str] audio_locator_tag: str tarred_audio_filepaths: str | list[str] = None + token_equivalent_duration: float = None shuffle_shards: bool = False shard_seed: Union[int, Literal["trng", "randomized"]] = "trng" @@ -473,298 +456,5 @@ def __iter__(self) -> Iterator[NeMoMultimodalConversation]: ) for turn in data["conversations"] ], + token_equivalent_duration=self.token_equivalent_duration, ) - - -""" -The code below is copied from nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py -with minimal modifications in order to avoid importing the NLP collection. - -We require this code for on-the-fly text example tokenization in a compatible way with Megatron, -so that we can determine the mini-batch sizes using the token counts. -""" - - -def preprocess( - source: dict, - tokenizer: TokenizerSpec, - name_end_token_ids: int, - label_start_ids: list, - special_tokens: dict, - num_turn_start_tokens: int, -): - """ - Given a conversation list. This transform: - 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; - 2. Concatenate conversations together; - 3. Tokenize the concatenated conversation; - 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. - """ - header, conversation, data_type, mask_role = _get_header_conversation_type_mask_role(source, special_tokens) - # tokenize conversations - input_ids = tokenizer.text_to_ids(conversation) - target = copy.deepcopy(input_ids) - header_tokens = tokenizer.text_to_ids(header) - header_len = len(header_tokens) - - ids = [] - tokenized_lens = [] - assert torch.equal(torch.tensor(target[:header_len]), torch.tensor(header_tokens)) - for s in source['conversations']: - # hack to remove the extra empty token in front - id1 = tokenizer.text_to_ids(PREFIX_STR + s["value"]) - id2 = tokenizer.text_to_ids(PREFIX_STR) - tokenized_sentence = id1[len(id2) :] - ids.append(torch.tensor(tokenized_sentence)) - tokenized_lens.append(len(tokenized_sentence)) - speakers = [sentence["from"] for sentence in source['conversations']] - assert mask_role in speakers, "mask role not in the conversation" - target = torch.LongTensor(target) - # not going to train on the header - target[:header_len] = IGNORE_INDEX - input_ids = torch.LongTensor(input_ids) - _mask_targets( - target, - tokenized_lens, - speakers, - header_len, - ids, - tokenizer, - mask_role, - data_type, - name_end_token_ids, - special_tokens, - label_start_ids, - num_turn_start_tokens, - ) - mask = (target != IGNORE_INDEX).bool() - assert mask.sum().item() != 0, "mask is empty" - # Choose the last conversation as answer other history are context - last_ignore_index_pos = torch.nonzero(target == IGNORE_INDEX)[-1].item() + 1 - context_ids = input_ids[:last_ignore_index_pos] - answer_ids = input_ids[last_ignore_index_pos:] - return dict(input_ids=input_ids, mask=mask, context_ids=context_ids, answer_ids=answer_ids) - - -def _build_samples_mapping(tokenizer, special_tokens): - # Copied from gpt_sft_chat_dataset.py - LABEL_START = special_tokens['label_start'] - END_NAME_SIGNAL = special_tokens['end_of_name'] - - id1 = tokenizer.text_to_ids(PREFIX_STR) - id2 = tokenizer.text_to_ids(PREFIX_STR + LABEL_START) - label_start_tokens = id2[len(id1) :] - - id1 = tokenizer.text_to_ids(PREFIX_STR + END_NAME_SIGNAL) - id2 = tokenizer.text_to_ids(PREFIX_STR) - name_end_token_ids = id1[len(id2) :] - - id1 = tokenizer.text_to_ids(PREFIX_STR + special_tokens['turn_start']) - id2 = tokenizer.text_to_ids(PREFIX_STR) - num_turn_start_tokens = len(id1) - len(id2) - - return label_start_tokens, name_end_token_ids, num_turn_start_tokens - - -PREFIX_STR = ( - "\x00" # the prefix string used in the tokenizer to deal with the added empty token for some of the tokenizers -) - -IGNORE_INDEX = -100 -SYSTEM_TOKEN = "System" - -TYPE_INSTRUCTION = { - 'TEXT_TO_VALUE': "", - 'VALUE_TO_TEXT': '', -} - - -def _get_header_conversation_type_mask_role(source, special_tokens): - END_SIGNAL = special_tokens['end_of_turn'] - END_NAME_SIGNAL = special_tokens['end_of_name'] - - data_type = None - if 'type' in source: - data_type = source['type'] - if data_type is not None: - assert data_type in TYPE_INSTRUCTION, f"source type {data_type} not supported" - # add end signal and concatenate together - conversation = source['system'] - if data_type is not None: - if TYPE_INSTRUCTION[data_type] != '': - conversation = conversation + '\n' + TYPE_INSTRUCTION[data_type] - mask_role = source.get('mask', 'User') - header = f"{special_tokens['system_turn_start']}{SYSTEM_TOKEN}{END_NAME_SIGNAL}{conversation}{END_SIGNAL}" - conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type, special_tokens) - return header, conversation, data_type, mask_role - - -def identify_start_index_of_subsequence(subsequence, sequence): - """find the location of the small tensor in the large tensor. - e.g. small = [1,3], large = [2,3,1,3], returns 2 - small = [3,2], large = [2,3,1,3], returns -1 - Args: - small (tensor): small tensor - large (tensor): large tensor - """ - for i in range(sequence.size(0) - subsequence.size(0) + 1): - if torch.equal(sequence[i : i + subsequence.size(0)], subsequence): - return i - return -1 - - -def _mask_targets( - target, - tokenized_lens, - speakers, - header_len, - s_ids, - tokenizer, - mask_role, - gtype, - name_end_token_ids, - special_tokens, - label_start_ids, - num_turn_start_tokens, -): - """This function masks the tokens so the loss is computed only on the non-masked role's responses. - For 'TEXT_TO_VALUE' type, the loss is computed on the value attributes. - - Args: - target (Tensor): input ids - tokenized_lens (List[int]): array of lengths of each turns - speakers (List[str]): array of speakers of each turns - header_len (int): the system prompt length - s_ids (List[Tensor]): array of tokenized ids of each turns - tokenizer (TokenizerSpec): tokenizer object - mask_role (str): the speaker id to be masked from loss computation - gtype (str): either 'TEXT_TO_VALUE' or 'VALUE_TO_TEXT' - name_end_token_ids (int): end of name token ids - special_tokens (dict): special tokens used for the chat prompt. It has the keys: system_turn_start, turn_start, label_start, end_of_turn - label_start_ids (list): list of label start token ids, - num_turn_start_tokens (int): number of tokens of the turn_start str - """ - TURN_TOKEN = special_tokens['turn_start'] - END_NAME_SIGNAL = special_tokens['end_of_name'] - label_start_ids = torch.tensor(label_start_ids) - name_end_token_ids = torch.tensor(name_end_token_ids) - - cur_idx = header_len - tgt_len = target.shape[0] - for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)): - # note, sentence piece will add extra empty token in front. has to compute the diff - id1 = tokenizer.text_to_ids(PREFIX_STR) - id2 = tokenizer.text_to_ids(PREFIX_STR + TURN_TOKEN + speaker + END_NAME_SIGNAL) - skip_name_len = len(id2) - len( - id1 - ) # s_ids[:skip_name_len] is the name part of the prompt 'TURN_TOKEN + speaker + END_NAME_SIGNAL' - # get the position of the label start string in this turn - location = identify_start_index_of_subsequence(label_start_ids, s_id) - - if location >= 0: - # if it contains the label start tokens - if gtype == 'VALUE_TO_TEXT': - # handles the case that condition on labels to generate respone - # the next token after the name part of the prompt is the beginning of the label start tokens - assert skip_name_len == location - # find the first new line token after the label part, which indicates the end of the whole label string - # newline_loc = torch.where((s_id[skip_name_len:] == name_end_token_ids))[0] - newline_loc = identify_start_index_of_subsequence(name_end_token_ids, s_id[skip_name_len:]) - if newline_loc < 0: - # cannot find new line token, which means the the whole turn is just a partial label string. Mask the whole turn - target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX - continue - # skip the label part and the new line token - more_skip_len = newline_loc + len(name_end_token_ids) - # skip the name part and the label part - skip_name_len += more_skip_len - elif gtype == 'TEXT_TO_VALUE': - # handles the case that condition on response to generate label - # skip the name part, response and the label start tokens part, the remainder is the label string without label start, e.g. 'quality:9,toxicity:8...' - skip_name_len = location + len(label_start_ids) - if cur_idx >= tgt_len: - break - # elif cur_idx + tokenized_len < tgt_len: - # # Check whether the mask is applied to the correct position, the first token is turn start tokens - # if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[1:]): - # logging.warning("a sentence mismatches the corresponding piece " "in the conversation") - if i == 0 and (gtype == 'VALUE_TO_TEXT' or gtype is None): - # mask the first turn completely to provide at least one turn as context for the rest - target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX - elif speaker == mask_role and i == 1 and gtype == 'TEXT_TO_VALUE': - # leave the first turn start tag unmasked, servers severs as the end of turn signal - target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX - elif speaker == mask_role and (i > 1): - # leave the first turn start tag unmasked, which severs as the end of turn signal - target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX - elif speaker == mask_role and (i <= 1): - # mask out everything in the second turn - target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX - else: - # mask up to name part, label part for VALUE_TO_TEXT, or name part, response and label start tokens for TEXT_TO_VALUE, or just the name part if gtype is None - target[cur_idx : cur_idx + skip_name_len] = IGNORE_INDEX - cur_idx += tokenized_len - - -def _add_speaker_and_signal(header, source, mask_role, gtype, special_tokens): - TURN_TOKEN = special_tokens['turn_start'] - END_SIGNAL = special_tokens['end_of_turn'] - LABEL_START = special_tokens['label_start'] - END_NAME_SIGNAL = special_tokens['end_of_name'] - - """Add speaker and start/end signal on each round.""" - BEGIN_SIGNAL = "" - conversation = header - for i, sentence in enumerate(source): - sentence_from = sentence["from"] - role_token = TURN_TOKEN - if gtype is None: - sentence["value"] = ( - BEGIN_SIGNAL + role_token + sentence_from + END_NAME_SIGNAL + sentence["value"] + END_SIGNAL - ) - elif gtype == "VALUE_TO_TEXT": - sentence["value"] = ( - BEGIN_SIGNAL - + role_token - + sentence_from - + END_NAME_SIGNAL - + ( - response_value_formater(sentence['label'], LABEL_START, END_NAME_SIGNAL) - if 'label' in sentence - else '' - ) - + sentence["value"] - + END_SIGNAL - ) - elif gtype == "TEXT_TO_VALUE": - sentence["value"] = ( - BEGIN_SIGNAL - + role_token - + sentence_from - + END_NAME_SIGNAL - + sentence["value"] - + END_SIGNAL - + ( - response_value_formater(sentence['label'], LABEL_START, END_NAME_SIGNAL) - if 'label' in sentence - else '' - ) - ) - else: - raise ValueError( - f"source type {gtype} not supported, only 'VALUE_TO_TEXT' and 'TEXT_TO_VALUE' are supported" - ) - conversation += sentence["value"] - # if the last turn is not masked, add next token start token to the end, which will be included for loss calculation - if sentence_from != mask_role and i == len(source) - 1: - conversation += TURN_TOKEN - return conversation - - -def response_value_formater(label, label_start, end_signal): - if isinstance(label, str): - return label_start + label + end_signal - elif label is None: - return '' - else: - raise ValueError(f'Unknown label type {type(label)}, only str type is supported') diff --git a/nemo/collections/common/data/prompt_fn.py b/nemo/collections/common/data/prompt_fn.py new file mode 100644 index 000000000000..e55610ee7fa8 --- /dev/null +++ b/nemo/collections/common/data/prompt_fn.py @@ -0,0 +1,77 @@ +from typing import Callable, Type + +import torch + + +PromptFormatFnReturnType = dict[str, list[torch.Tensor]] +PromptFormatSignature = Callable[[object, object], PromptFormatFnReturnType] +PROMPT_FORMAT_FNS: dict[tuple[Type, Type] | Type, PromptFormatSignature] = {} + + +def registered_prompt_format_fn(example_type: Type, formatter_type: Type = None): + """ + Decorator for registering text prompt functions. + It allows to select the right prompt formatting function based on the types of the + example and the prompt formatter, allowing different strategies for formatting different + types of data with different prompt formats. + + When formatter_type is set None, registers a default prompt format function for a given data type. + + Example:: + + >>> @registered_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter) + ... def my_src_tgt_text_prompt(example, formatter): + ... pass + ... + ... @registered_prompt_format_fn(Cut, Llama2PromptFormatter) + ... def my_audio_prompt(example, formatter): + ... pass + ... + ... prompt_fn = get_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter) + """ + + def _decorator(prompt_fn: Callable[[object, object], dict[str, list[torch.Tensor]]]): + global PROMPT_FORMAT_FNS + if formatter_type is None: + PROMPT_FORMAT_FNS[example_type] = prompt_fn + else: + PROMPT_FORMAT_FNS[(example_type, formatter_type)] = prompt_fn + return prompt_fn + + return _decorator + + +def get_prompt_format_fn(example: Type | object, prompt: Type | object = None) -> PromptFormatSignature: + """See the documentation of ``text_prompt_formatter`` above.""" + + # If the user provided objects, resolve their types. + if not isinstance(example, type): + example = type(example) + if not isinstance(prompt, type): + prompt = type(prompt) + + # For the example type, first try to match it directly, then fall back to its parent classes. + for subtype in example.mro(): + + # First check the match for specific example type and a specific prompt format. + if (subtype, prompt) in PROMPT_FORMAT_FNS: + return PROMPT_FORMAT_FNS[(subtype, prompt)] + + # Then for the same specific example type, fall back to its default prompt formatter implementation. + # Note: the data example type takes precedence over the prompt formatter type for this resolution. + if subtype in PROMPT_FORMAT_FNS: + return PROMPT_FORMAT_FNS[subtype] + + raise ValueError( + f"Unknown prompt format function for ({example}, {prompt}). " + f"Available choices are: {list(PROMPT_FORMAT_FNS.keys())}" + ) + + +def apply_prompt_format_fn(example: object | Type, prompt: object | Type) -> PromptFormatFnReturnType: + """ + Utility for resolving the prompt format function and applying it to an example in one go. + See the documentation of ``text_prompt_formatter`` above. + """ + fn = get_prompt_format_fn(example, prompt) + return fn(example, prompt) diff --git a/nemo/collections/common/prompts/__init__.py b/nemo/collections/common/prompts/__init__.py index e4b7785c6243..77e6c346fd4d 100644 --- a/nemo/collections/common/prompts/__init__.py +++ b/nemo/collections/common/prompts/__init__.py @@ -1,5 +1,4 @@ from nemo.collections.common.prompts.canary import CanaryPromptFormatter -from nemo.collections.common.prompts.fn import get_prompt_format_fn, registered_prompt_format_fn from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.collections.common.prompts.gemma import GemmaPromptFormatter from nemo.collections.common.prompts.llama import Llama2PromptFormatter, Llama3PromptFormatter diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py index d06e01b50666..eb7412920576 100644 --- a/nemo/collections/common/prompts/canary.py +++ b/nemo/collections/common/prompts/canary.py @@ -1,14 +1,12 @@ -from collections import defaultdict from typing import Any import torch -from lhotse import CutSet, MonoCut -from lhotse.cut import MixedCut +from lhotse import MonoCut +from lhotse.cut import Cut, MixedCut from lhotse.utils import ifnone -from nemo.collections.common.prompts.fn import registered_prompt_format_fn +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec from nemo.collections.common.tokenizers.canary_tokenizer import ( CANARY_BOS, CANARY_EOS, @@ -95,8 +93,8 @@ def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[s return slot_values -@registered_prompt_format_fn -def canary(cuts: CutSet, tokenizer: TokenizerSpec) -> dict[str, torch.Tensor]: +@registered_prompt_format_fn(Cut, CanaryPromptFormatter) +def canary(cut: Cut, prompt: CanaryPromptFormatter) -> dict[str, torch.Tensor]: """ Prepend and append control tokens to the token sequence as per Canary format. @@ -119,59 +117,51 @@ def canary(cuts: CutSet, tokenizer: TokenizerSpec) -> dict[str, torch.Tensor]: (i.e., spoken language in the recording) and the second occurrence is for the "target" language (i.e., the language in which we are going to output the text). """ - prompt = CanaryPromptFormatter(tokenizer) - - ans = defaultdict(list) - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut._first_non_padding_cut - if not isinstance(cut, MonoCut): - raise TypeError( - f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" - ) - - # first, validate the utterance - expected_slots = set(prompt.get_slots("user")) - missing_keys = expected_slots - set(cut.custom) - if "task" in missing_keys and "taskname" in cut.custom: - # Compatibility with "old" Canary manifest format. - # For compatbility with inference options, this slot is now called "task". - cut.custom["task"] = cut.custom["taskname"] - missing_keys.remove("task") - if missing_keys: - raise RuntimeError( - f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" - f"Please ensure that every utterance in the input manifests contains these keys." - ) - - turns = [ - dict( - role="user", - slots={ - **{slot: cut.custom[slot] for slot in expected_slots}, - prompt.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, - }, - ) - ] - # If data has no transcript, create empty response with only. - text = ' '.join(s.text for s in cut.supervisions if s.text is not None) - turns.append( - dict( - role="assistant", - slots={ - "text": text, - prompt.PROMPT_LANGUAGE_SLOT: ifnone(cut.supervisions[0].language, cut.custom.get("target_lang")), - }, - ), + if isinstance(cut, MixedCut): + cut = cut._first_non_padding_cut + if not isinstance(cut, MonoCut): + raise TypeError( + f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" ) - for k, v in prompt.encode_dialog(turns).items(): - if k == "answer_ids": - assert ( - v[-1].item() == prompt.tokenizer.eos - ), f"Expected the last token in answer_ids to be EOS, but we got {v}" - ans[k].append(v[:-1]) # Strip Canary's EOS - else: - ans[k].append(v) + # first, validate the utterance + expected_slots = set(prompt.get_slots("user")) + missing_keys = expected_slots - set(cut.custom) + if "task" in missing_keys and "taskname" in cut.custom: + # Compatibility with "old" Canary manifest format. + # For compatbility with inference options, this slot is now called "task". + cut.custom["task"] = cut.custom["taskname"] + missing_keys.remove("task") + if missing_keys: + raise RuntimeError( + f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" + f"Please ensure that every utterance in the input manifests contains these keys." + ) + turns = [ + dict( + role="user", + slots={ + **{slot: cut.custom[slot] for slot in expected_slots}, + prompt.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, + }, + ) + ] + # If data has no transcript, create empty response with only. + text = ' '.join(s.text for s in cut.supervisions if s.text is not None) + turns.append( + dict( + role="assistant", + slots={ + "text": text, + prompt.PROMPT_LANGUAGE_SLOT: ifnone(cut.supervisions[0].language, cut.custom.get("target_lang")), + }, + ), + ) + + ans = prompt.encode_dialog(turns) + assert ( + ans["answer_ids"][-1].item() == prompt.tokenizer.eos + ), f"Expected the last token in answer_ids to be EOS, but we got {ans['answer_ids']}" + ans["answer_ids"] = ans["answer_ids"][:-1] # Strip Canary's EOS return ans diff --git a/nemo/collections/common/prompts/fn.py b/nemo/collections/common/prompts/fn.py deleted file mode 100644 index ce7d2fc8a69a..000000000000 --- a/nemo/collections/common/prompts/fn.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Callable, Sequence - -import torch -from lhotse import CutSet - -from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec - -PROMPT_FORMAT_FNS = {} - - -def registered_prompt_format_fn( - prompt_fn: Callable[[CutSet, TokenizerSpec], tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]] -): - """ - Decorator for registering prompt functions under a name. - - Example:: - - >>> @registered_prompt_format_fn - ... def my_prompt(cuts, tokenizer): - ... pass - ... - ... prompt_fn = get_prompt_format_fn("my_prompt") - """ - global PROMPT_FORMAT_FNS - - PROMPT_FORMAT_FNS[prompt_fn.__name__] = prompt_fn - return prompt_fn - - -def get_prompt_format_fn( - name: str, -) -> Callable[[CutSet, TokenizerSpec], tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]]: - if name not in PROMPT_FORMAT_FNS: - raise ValueError( - f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}" - ) - return PROMPT_FORMAT_FNS[name] diff --git a/nemo/collections/common/prompts/gemma.py b/nemo/collections/common/prompts/gemma.py index 128a5689e07f..5dd0e25c306e 100644 --- a/nemo/collections/common/prompts/gemma.py +++ b/nemo/collections/common/prompts/gemma.py @@ -2,14 +2,10 @@ Implemented following the guide at https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format """ -from collections import defaultdict +from lhotse.cut import Cut, MixedCut -from lhotse import CutSet -from lhotse.cut import MixedCut - -from nemo.collections.common.prompts import registered_prompt_format_fn +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec GEMMA_BOS = "" GEMMA_END_OF_TURN = "" @@ -36,29 +32,19 @@ class GemmaPromptFormatter(PromptFormatter): } -@registered_prompt_format_fn -def gemma1(cuts: CutSet, tokenizer: TokenizerSpec): - prompt = GemmaPromptFormatter(tokenizer) - ans = defaultdict(list) - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut.first_non_padding_cut - if cut.has_custom("context"): - context = cut.context - elif cut.has_custom("question"): - context = cut.question - else: - context = cut.default_context - - turns = [] - if cut.has_custom("system_prompt"): - turns.append({"role": "system_and_user", "slots": {"system": cut.system_prompt, "message": context}}) - else: - turns.append({"role": "user", "slots": {"message": context}}) - if (answer := cut.supervisions[0].text) is not None: - turns.append({"role": "assistant", "slots": {"message": answer}}) +@registered_prompt_format_fn(Cut, GemmaPromptFormatter) +def gemma1(cut: Cut, prompt: GemmaPromptFormatter): + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + if cut.has_custom("context"): + context = cut.context + elif cut.has_custom("question"): + context = cut.question + else: + context = cut.default_context - for k, v in prompt.encode_dialog(turns).items(): - ans[k].append(v) + turns = [{"role": "user", "slots": {"message": context}}] + if (answer := cut.supervisions[0].text) is not None: + turns.append({"role": "assistant", "slots": {"message": answer}}) - return ans + return prompt.encode_dialog(turns) diff --git a/nemo/collections/common/prompts/llama.py b/nemo/collections/common/prompts/llama.py index 4364af608497..ce4a131a921b 100644 --- a/nemo/collections/common/prompts/llama.py +++ b/nemo/collections/common/prompts/llama.py @@ -1,13 +1,9 @@ -from collections import defaultdict - import torch -from lhotse import CutSet -from lhotse.cut import MixedCut -from lhotse.utils import ifnone +from lhotse.cut import Cut, MixedCut -from nemo.collections.common.prompts import registered_prompt_format_fn +from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import BOS_SLOT, EOS_SLOT, Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec class Llama2PromptFormatter(PromptFormatter): @@ -41,32 +37,67 @@ class Llama2PromptFormatter(PromptFormatter): } -@registered_prompt_format_fn -def llama2(cuts: CutSet, tokenizer: TokenizerSpec) -> dict[str, torch.Tensor]: - prompt = Llama2PromptFormatter(tokenizer) - ans = defaultdict(list) - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut.first_non_padding_cut - if cut.has_custom("context"): - context = cut.context - elif cut.has_custom("question"): - context = cut.question - else: - context = cut.default_context - - turns = [] - if cut.has_custom("system_prompt"): - turns.append({"role": "system_and_user", "slots": {"system": cut.system_prompt, "message": context}}) - else: - turns.append({"role": "user", "slots": {"message": context}}) - if (answer := cut.supervisions[0].text) is not None: - turns.append({"role": "assistant", "slots": {"message": answer}}) - - for k, v in prompt.encode_dialog(turns).items(): - ans[k].append(v) - - return ans +@registered_prompt_format_fn(Cut, Llama2PromptFormatter) +def llama2(cut: Cut, prompt: Llama2PromptFormatter) -> dict[str, torch.Tensor]: + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + if cut.has_custom("context"): + context = cut.context + elif cut.has_custom("question"): + context = cut.question + else: + context = cut.default_context + + turns = [] + if cut.has_custom("system_prompt"): + turns.append({"role": "system_and_user", "slots": {"system": cut.system_prompt, "message": context}}) + else: + turns.append({"role": "user", "slots": {"message": context}}) + if (answer := cut.supervisions[0].text) is not None: + turns.append({"role": "assistant", "slots": {"message": answer}}) + return prompt.encode_dialog(turns) + + +@registered_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter) +def llama2_src_tgt_text_example(example: SourceTargetTextExample, prompt: Llama2PromptFormatter): + if example.question is not None: + user_turn = { + "role": "system_and_user", + "slots": {"system": example.question.text, "message": example.source.text}, + } + else: + user_turn = { + "role": "user", + "slots": {"message": example.source.text}, + } + return prompt.encode_dialog( + [ + user_turn, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}}, + ] + ) + + +@registered_prompt_format_fn(NeMoSFTExample, Llama2PromptFormatter) +def llama2_sft_text_example(example: NeMoSFTExample, prompt: Llama2PromptFormatter): + first_user_turn = example.data["conversations"][0]["value"] + if "system" in example.data and example.data["system"]: + first_turn = { + "role": "system_and_user", + "slots": {"system": example.data["system"], "message": first_user_turn}, + } + else: + first_turn = { + "role": "user", + "slots": {"message": first_user_turn}, + } + return prompt.encode_dialog( + [first_turn] + + [ + {"role": "user" if turn["from"] == "User" else prompt.OUTPUT_ROLE, "slots": {"message": turn["value"]}} + for turn in example.data["conversations"][1:] + ] + ) LLAMA3_BOS = "<|begin_of_text|>" diff --git a/nemo/collections/common/prompts/plain.py b/nemo/collections/common/prompts/plain.py index efd7d989a9e2..de7fbe5a1830 100644 --- a/nemo/collections/common/prompts/plain.py +++ b/nemo/collections/common/prompts/plain.py @@ -1,11 +1,7 @@ -from collections import defaultdict +from lhotse.cut import Cut, MixedCut -from lhotse import CutSet -from lhotse.cut import MixedCut - -from nemo.collections.common.prompts import registered_prompt_format_fn +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec class PlainPromptFormatter(PromptFormatter): @@ -31,20 +27,17 @@ class PlainPromptFormatter(PromptFormatter): } -@registered_prompt_format_fn -def plain(cuts: CutSet, tokenizer: TokenizerSpec): - prompt = PlainPromptFormatter(tokenizer) - ans = defaultdict(list) - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut.first_non_padding_cut - assert cut.has_custom("context"), f"Missing mandatory metadata key 'context' in {cut=}" - - turns = [{"role": "user", "slots": {"message": cut.context}}] - if (answer := cut.supervisions[0].text) is not None: - turns.append({"role": "assistant", "slots": {"message": answer}}) +@registered_prompt_format_fn(Cut, PlainPromptFormatter) +def plain(cut: Cut, prompt: PlainPromptFormatter): + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + if cut.has_custom("context"): + ctx = cut.context + else: + ctx = "" - for k, v in prompt.encode_dialog(turns).items(): - ans[k].append(v) + turns = [{"role": "user", "slots": {"message": ctx}}] + if (answer := cut.supervisions[0].text) is not None: + turns.append({"role": "assistant", "slots": {"message": answer}}) - return ans + return prompt.encode_dialog(turns) diff --git a/nemo/collections/common/prompts/t5nmt.py b/nemo/collections/common/prompts/t5nmt.py index 6ec69862dc4a..0d89adcdb55a 100644 --- a/nemo/collections/common/prompts/t5nmt.py +++ b/nemo/collections/common/prompts/t5nmt.py @@ -1,11 +1,12 @@ from collections import defaultdict + import torch -from lhotse import CutSet, MonoCut -from lhotse.cut import MixedCut +from lhotse import MonoCut +from lhotse.cut import Cut, MixedCut -from nemo.collections.common.prompts import registered_prompt_format_fn +from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec class T5NMTPromptFormatter(PromptFormatter): @@ -43,42 +44,48 @@ def encode_turn(self, prompt_template: str, expected_slots: dict, slot_values: d ) -@registered_prompt_format_fn -def t5nmt(cuts: CutSet, tokenizer: TokenizerSpec) -> dict[str, torch.Tensor]: - prompt = T5NMTPromptFormatter(tokenizer) - +@registered_prompt_format_fn(Cut, T5NMTPromptFormatter) +def t5nmt(cut: Cut, prompt: T5NMTPromptFormatter) -> dict[str, torch.Tensor]: ans = defaultdict(list) - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut._first_non_padding_cut - if not isinstance(cut, MonoCut): - raise TypeError( - f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" - ) + if isinstance(cut, MixedCut): + cut = cut._first_non_padding_cut + if not isinstance(cut, MonoCut): + raise TypeError( + f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" + ) - if hasattr(cut, "context"): - context = cut.context - elif hasattr(cut, "default_context"): - context = cut.default_context - else: - raise RuntimeError("Missing context/default_context custom field in cut: {cut}") + if hasattr(cut, "context"): + context = cut.context + elif hasattr(cut, "default_context"): + context = cut.default_context + else: + raise RuntimeError("Missing context/default_context custom field in cut: {cut}") - turns = [ + turns = [ + dict( + role="user", + # "message" slot is the audio portion of the cut; currently it is populated inside model's forward. + slots={"target_lang": context, "message": ""}, + ), + ] + if len(cut.supervisions) > 0 and cut.supervisions[0].text is not None: + turns.append( dict( - role="user", - # "message" slot is the audio portion of the cut; currently it is populated inside model's forward. - slots={"target_lang": context, "message": ""}, - ), - ] - if len(cut.supervisions) > 0 and cut.supervisions[0].text is not None: - turns.append( - dict( - role=prompt.OUTPUT_ROLE, - slots={"message": cut.supervisions[0].text}, - ) + role=prompt.OUTPUT_ROLE, + slots={"message": cut.supervisions[0].text}, ) - enc = prompt.encode_dialog(turns) - for k, v in enc.items(): - ans[k].append(v) + ) + return prompt.encode_dialog(turns) - return ans + +@registered_prompt_format_fn(SourceTargetTextExample, T5NMTPromptFormatter) +def t5nmt_src_tgt_text_example(example: SourceTargetTextExample, prompt: T5NMTPromptFormatter): + ctx = f"<{example.target.language}>" + if example.has_custom("extra_prompt"): + ctx = f"{ctx} {example.extra_prompt}" + return prompt.encode_dialog( + [ + {"role": "user", "slots": {"message": example.source.text, "target_lang": ctx}}, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}}, + ] + ) diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py index 3b3268423812..df91ad4f5e2f 100644 --- a/tests/collections/asr/test_asr_multitask_model_bpe.py +++ b/tests/collections/asr/test_asr_multitask_model_bpe.py @@ -404,7 +404,9 @@ def test_predict_step(self, asr_model, test_data_dir): c.target_lang = "en" c.task = "asr" c.pnc = "no" - dataset = PromptedAudioToTextLhotseDataset(tokenizer=asr_model.tokenizer, prompt_format_fn=canary) + dataset = PromptedAudioToTextLhotseDataset( + tokenizer=asr_model.tokenizer, prompt=CanaryPromptFormatter(asr_model.tokenizer) + ) batch = dataset[cuts] # Numpy array test @@ -434,7 +436,9 @@ def test_FrameBatchMultiTaskAED(self, asr_model, test_data_dir): @pytest.mark.unit def test_prompted_dataset(asr_model): - dataset = PromptedAudioToTextLhotseDataset(tokenizer=asr_model.tokenizer, prompt_format_fn=canary) + dataset = PromptedAudioToTextLhotseDataset( + tokenizer=asr_model.tokenizer, prompt=CanaryPromptFormatter(asr_model.tokenizer) + ) cuts = DummyManifest(CutSet, begin_id=0, end_id=3, with_data=True) diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index 40a587099410..0e2179bb2d89 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -1352,18 +1352,18 @@ def test_text_file_pairs_shards_input(txt_pair_paths_shards: tuple[str, str], qu @pytest.fixture(scope="session") -def en_es_tokenizer(tmp_path_factory, txt_en_path, txt_es_path) -> TokenizerWrapper: +def en_es_tokenizer(tmp_path_factory, txt_en_path, txt_es_path) -> SentencePieceTokenizer: tmpdir = tmp_path_factory.mktemp("en_es_tokenizer") text_path = tmpdir / "text.txt" text_path.write_text(txt_en_path.read_text() + "\n" + txt_es_path.read_text()) create_spt_model(text_path, vocab_size=128, sample_size=-1, do_lower_case=False, output_dir=str(tmpdir)) - return TokenizerWrapper(SentencePieceTokenizer(str(tmpdir / "tokenizer.model"))) + return SentencePieceTokenizer(str(tmpdir / "tokenizer.model")) def test_multimodal_text_audio_dataloading( txt_pair_paths_shards: tuple[str, str], nemo_tarred_manifest_path_multi: tuple[str, str], - en_es_tokenizer: TokenizerWrapper, + en_es_tokenizer: SentencePieceTokenizer, questions_path: str, ): en_paths, es_paths = txt_pair_paths_shards @@ -1396,6 +1396,7 @@ def test_multimodal_text_audio_dataloading( "shuffle": True, "num_workers": 0, "use_multimodal_sampling": True, + "prompt_format": "plain", "batch_tokens": BT, # How to set token equivalent duration in actual training? # assuming fbank frames: 0.01 is the base due to frame shift; @@ -1437,16 +1438,16 @@ def test_multimodal_text_audio_dataloading( assert isinstance(ex.source.text, str) assert isinstance(ex.target.text, str) assert isinstance(ex.question.text, str) - assert isinstance(ex.input_ids, np.ndarray) - assert isinstance(ex.context_ids, np.ndarray) - assert isinstance(ex.answer_ids, np.ndarray) - assert isinstance(ex.mask, np.ndarray) + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) def test_multimodal_text_audio_dataloading_zip_strategy( txt_pair_paths_shards: tuple[str, str], nemo_tarred_manifest_path_multi: tuple[str, str], - en_es_tokenizer: TokenizerWrapper, + en_es_tokenizer: SentencePieceTokenizer, questions_path: str, ): en_paths, es_paths = txt_pair_paths_shards @@ -1471,6 +1472,7 @@ def test_multimodal_text_audio_dataloading_zip_strategy( ], "shuffle": True, "num_workers": 0, + "prompt_format": "plain", "use_multimodal_sampling": True, "batch_tokens": BT, # How to set token equivalent duration in actual training? @@ -1500,6 +1502,7 @@ def test_multimodal_text_audio_dataloading_zip_strategy( "shuffle": True, "num_workers": 0, "use_multimodal_sampling": True, + "prompt_format": "plain", "batch_tokens": 64, # How to set token equivalent duration in actual training? # assuming fbank frames: 0.01 is the base due to frame shift; @@ -1543,10 +1546,10 @@ def test_multimodal_text_audio_dataloading_zip_strategy( assert ex.modality == "text" assert ex.source.language == "en" assert ex.target.language == "es" - assert isinstance(ex.input_ids, np.ndarray) - assert isinstance(ex.context_ids, np.ndarray) - assert isinstance(ex.answer_ids, np.ndarray) - assert isinstance(ex.mask, np.ndarray) + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) b = batches[1] assert isinstance(b, lhotse.CutSet) @@ -1565,16 +1568,16 @@ def test_multimodal_text_audio_dataloading_zip_strategy( assert ex.modality == "text" assert ex.source.language == "en" assert ex.target.language == "es" - assert isinstance(ex.input_ids, np.ndarray) - assert isinstance(ex.context_ids, np.ndarray) - assert isinstance(ex.answer_ids, np.ndarray) - assert isinstance(ex.mask, np.ndarray) + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) def test_multimodal_text_audio_dataloading_round_robin_strategy( txt_pair_paths_shards: tuple[str, str], nemo_tarred_manifest_path_multi: tuple[str, str], - en_es_tokenizer: TokenizerWrapper, + en_es_tokenizer: SentencePieceTokenizer, questions_path: str, ): en_paths, es_paths = txt_pair_paths_shards @@ -1600,6 +1603,7 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( "shuffle": True, "num_workers": 0, "use_multimodal_sampling": True, + "prompt_format": "plain", "batch_tokens": BT, # How to set token equivalent duration in actual training? # assuming fbank frames: 0.01 is the base due to frame shift; @@ -1626,6 +1630,7 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( }, ], "shuffle": True, + "prompt_format": "plain", "num_workers": 0, "use_multimodal_sampling": True, "batch_tokens": BT, @@ -1677,10 +1682,10 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( assert ex.modality == "text" assert ex.source.language == "en" assert ex.target.language == "es" - assert isinstance(ex.input_ids, np.ndarray) - assert isinstance(ex.context_ids, np.ndarray) - assert isinstance(ex.answer_ids, np.ndarray) - assert isinstance(ex.mask, np.ndarray) + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): diff --git a/tests/collections/common/test_lhotse_multimodal_dataloading.py b/tests/collections/common/test_lhotse_multimodal_dataloading.py index 4ded7c25d12a..51b3085a8fc8 100644 --- a/tests/collections/common/test_lhotse_multimodal_dataloading.py +++ b/tests/collections/common/test_lhotse_multimodal_dataloading.py @@ -127,8 +127,6 @@ def test_multimodal_conversation_input(multimodal_conversations_path): assert isinstance(t, TextTurn) assert t.role == "assistant" assert t.value == "Of course!" - for key in ("input_ids", "context_ids", "answer_ids", "mask"): - assert getattr(ex, key) is None # not tokenized/prompted @pytest.fixture diff --git a/tests/collections/common/test_lhotse_prompt_format_data_types.py b/tests/collections/common/test_lhotse_prompt_format_data_types.py new file mode 100644 index 000000000000..4347c467a4ae --- /dev/null +++ b/tests/collections/common/test_lhotse_prompt_format_data_types.py @@ -0,0 +1,283 @@ +import lhotse.serialization +import pytest +from lhotse import CutSet, SupervisionSegment +from lhotse.cut import Cut +from lhotse.testing.dummies import dummy_cut + +from nemo.collections.common.data import ( + NeMoSFTExample, + SourceTargetTextExample, + TextExample, + get_lhotse_dataloader_from_config, +) +from nemo.collections.common.tokenizers import SentencePieceTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import create_spt_model + + +@pytest.fixture +def tokenizer(tmp_path_factory): + tmpdir = tmp_path_factory.mktemp("tok") + text_path = tmpdir / "text.txt" + text_path.write_text("\n".join(chr(i) for i in range(256))) + create_spt_model( + text_path, + vocab_size=512, + sample_size=-1, + do_lower_case=False, + output_dir=str(tmpdir), + bos=True, + eos=True, + user_defined_symbols=[ + "[INST]", + "[/INST]", + "<>", + "<>", + "[audio]", + "", + "", + ], + ) + return SentencePieceTokenizer(str(tmpdir / "tokenizer.model")) + + +@pytest.fixture +def cuts_path(tmp_path_factory): + tmp_path = tmp_path_factory.getbasetemp() / "cuts.jsonl" + c = dummy_cut(0, duration=1.0, supervisions=[SupervisionSegment("", "", 0, 1.0, text="dummy text")]) + c.context = "dummy context" + CutSet([c]).to_file(tmp_path) + return tmp_path + + +@pytest.fixture +def src_tgt_example(tmp_path_factory): + d = tmp_path_factory.mktemp("src_tgt") + (d / "src.txt").write_text("an example") + (d / "tgt.txt").write_text("elpmaxe na") + return (d / "src.txt"), (d / "tgt.txt") + + +@pytest.fixture +def nemo_sft_example(tmp_path_factory): + tmp_path = tmp_path_factory.getbasetemp() / "nemo_sft.jsonl" + lhotse.serialization.save_to_jsonl( + [ + { + "system": "", + "mask": "User", + "dataset": "", + "conversations": [ + { + "from": "User", + "value": "Hi, how are you?", + }, + { + "from": "Assistant", + "value": "Good day, I'm a useful assistant.", + }, + ], + } + ], + tmp_path, + ) + return tmp_path + + +class Identity: + def __getitem__(self, item): + return item + + +def test_prompt_format_cut(cuts_path, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "batch_size": 1, + "prompt_format": "llama2", + "min_duration": 0, + "max_duration": 10, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, Cut) + assert tokenizer.ids_to_text(ex.input_ids) == "[INST] dummy context [/INST] dummy text" + assert tokenizer.ids_to_text(ex.context_ids) == "[INST] dummy context [/INST]" + assert tokenizer.ids_to_text(ex.answer_ids) == "dummy text" + + +def test_prompt_format_cut_filtered_out(cuts_path, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "batch_size": 1, + "prompt_format": "llama2", + "min_duration": 0, + "max_duration": 0.5, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + with pytest.raises(StopIteration): + next(iter(dl)) + + +def test_prompt_format_cut_max_tokens_has_no_filtering_effect(cuts_path, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "batch_size": 1, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "token_equivalent_duration": 0.1, + "min_tokens": 1, + "max_tokens": 2, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, Cut) + + +def test_prompt_format_src_tgt(src_tgt_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [ + {"type": "txt_pair", "source_paths": src_tgt_example[0], "target_paths": src_tgt_example[1]} + ], + "batch_size": 1, + "force_finite": True, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 50, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, SourceTargetTextExample) + assert tokenizer.ids_to_text(ex.input_ids) == "[INST] an example [/INST] elpmaxe na" + assert tokenizer.ids_to_text(ex.context_ids) == "[INST] an example [/INST]" + assert tokenizer.ids_to_text(ex.answer_ids) == "elpmaxe na" + + +def test_prompt_format_src_tgt_filtered_out(src_tgt_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [ + {"type": "txt_pair", "source_paths": src_tgt_example[0], "target_paths": src_tgt_example[1]} + ], + "batch_size": 1, + "force_finite": True, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 10, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + with pytest.raises(StopIteration): + batch = next(iter(dl)) + + +def test_prompt_format_src_tgt_2d(src_tgt_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [ + { + "type": "txt_pair", + "source_paths": src_tgt_example[0], + "target_paths": src_tgt_example[1], + "target_language": "reversed", + } + ], + "batch_size": 1, + "force_finite": True, + "prompt_format": "t5nmt", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 50, + "use_total_length": False, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, SourceTargetTextExample) + assert tokenizer.ids_to_text(ex.input_ids) == " an example elpmaxe na" + assert tokenizer.ids_to_text(ex.context_ids) == " an example" + assert tokenizer.ids_to_text(ex.answer_ids) == "elpmaxe na" + + +def test_prompt_format_nemo_sft(nemo_sft_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [{"type": "nemo_sft_jsonl", "paths": nemo_sft_example}], + "batch_size": 1, + "force_finite": True, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 100, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, NeMoSFTExample) + assert tokenizer.ids_to_text(ex.input_ids) == "[INST] Hi, how are you? [/INST] Good day, I'm a useful assistant." + assert tokenizer.ids_to_text(ex.context_ids) == "[INST] Hi, how are you? [/INST]" + assert tokenizer.ids_to_text(ex.answer_ids) == "Good day, I'm a useful assistant." + + +def test_prompt_format_nemo_sft_filtered_out(nemo_sft_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [{"type": "nemo_sft_jsonl", "paths": nemo_sft_example}], + "batch_size": 1, + "force_finite": True, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 5, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + with pytest.raises(StopIteration): + batch = next(iter(dl)) diff --git a/tests/collections/common/test_lhotse_seqlen_filters.py b/tests/collections/common/test_lhotse_seqlen_filters.py new file mode 100644 index 000000000000..f9dbc49fa20b --- /dev/null +++ b/tests/collections/common/test_lhotse_seqlen_filters.py @@ -0,0 +1,171 @@ +from copy import deepcopy + +import numpy as np +import pytest +from lhotse import SupervisionSegment +from lhotse.testing.dummies import dummy_cut + +from nemo.collections.common.data.lhotse.dataloader import ( + DurationFilter, + TokenCountFilter, + TokenPerSecondFilter, + TokenPerTokenFilter, +) +from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample, TextExample + + +@pytest.fixture +def cut(): + c = dummy_cut(0, duration=1.0, supervisions=[SupervisionSegment("", "", 0, 1.0, text="dummy")]) + c.supervisions[0].tokens = [1, 37, 12, 2] + return c + + +def test_cut_duration_filter(cut): + f = DurationFilter(0, 10) + assert f(cut) == True + + f = DurationFilter(0, 0.5) + assert f(cut) == False + + f = DurationFilter(1.5, 2.0) + assert f(cut) == False + + +def test_cut_token_per_second_filter(cut): + f = TokenPerSecondFilter(tps_min=0.0, tps_max=5.0) + assert f(cut) == True + + f = TokenPerSecondFilter(tps_min=0.0, tps_max=1.0) + assert f(cut) == False + + f = TokenPerSecondFilter(tps_min=10.0, tps_max=12.0) + assert f(cut) == False + + +def test_cut_passes_by_token_count_and_tpt_filter(cut): + assert TokenCountFilter(1, 10, use_total_length=True)(cut) == True + assert TokenPerTokenFilter(1, 10)(cut) == True + + +def test_cut_passes_by_token_count_and_tpt_filter(cut): + assert TokenCountFilter(1, 10, use_total_length=True)(cut) == True + assert TokenPerTokenFilter(1, 10)(cut) == True + + +@pytest.fixture +def src_tgt_example(): + return SourceTargetTextExample( + source=TextExample("", tokens=np.array([1, 37, 12, 2])), + target=TextExample("", tokens=np.array([1, 1823, 1245, 2446, 1038, 2])), + ) + + +def test_src_tgt_token_filter_requires_prompt_formatting(src_tgt_example): + with pytest.raises(RuntimeError): + TokenCountFilter(0, 1, True)(src_tgt_example) + + +def test_src_tgt_passes_by_duration_filter(src_tgt_example): + assert DurationFilter(1, 10)(src_tgt_example) == True + assert TokenPerSecondFilter(1, 10)(src_tgt_example) == True + + +def test_src_tgt_token_filter(src_tgt_example): + example = deepcopy(src_tgt_example) + example.input_ids = np.concatenate((example.source.tokens, example.target.tokens)) + example.context_ids = example.source.tokens + example.answer_ids = example.target.tokens + + """ + Input length measurement / encoder-decoder models / 2D bucketing + """ + f = TokenCountFilter(1, 5, use_total_length=False) + assert f(example) == True + + f = TokenCountFilter(1, 3, use_total_length=False) + assert f(example) == False + + f = TokenCountFilter(10, 30, use_total_length=False) + assert f(example) == False + + """ + Total length measurement / decoder-only models / 1D bucketing + """ + f = TokenCountFilter(1, 5, use_total_length=True) + assert f(example) == False + + f = TokenCountFilter(1, 20, use_total_length=True) + assert f(example) == True + + f = TokenCountFilter(1, 3, use_total_length=True) + assert f(example) == False + + f = TokenCountFilter(20, 30, use_total_length=True) + assert f(example) == False + + +@pytest.fixture +def nemo_sft_example(): + example = NeMoSFTExample( + data={ + "system": "", + "mask": "User", + "dataset": "", + "conversations": [ + { + "from": "User", + "value": "Hi, how are you?", + }, + { + "from": "Assistant", + "value": "Good day, I'm a useful assistant.", + }, + ], + }, + ) + return example + + +def test_nemo_sft_token_filter_requires_prompt_formatting(nemo_sft_example): + with pytest.raises(RuntimeError): + TokenCountFilter(0, 1, True)(nemo_sft_example) + + +def test_nemo_sft_passes_by_duration_filter(nemo_sft_example): + assert DurationFilter(1, 10)(nemo_sft_example) == True + assert TokenPerSecondFilter(1, 10)(nemo_sft_example) == True + + +def test_nemo_sft_token_filter(nemo_sft_example): + example = deepcopy(nemo_sft_example) + example.input_ids = np.array([1, 123, 3425, 123, 2345, 324, 54, 2]) + example.context_ids = np.array([1, 123, 3425]) + example.answer_ids = np.array([123, 2345, 324, 54, 2]) + + """ + Input length measurement / encoder-decoder models / 2D bucketing + """ + f = TokenCountFilter(1, 5, use_total_length=False) + assert f(example) == True + + f = TokenCountFilter(1, 2, use_total_length=False) + assert f(example) == False + + f = TokenCountFilter(10, 30, use_total_length=False) + assert f(example) == False + + """ + Total length measurement / decoder-only models / 1D bucketing + """ + f = TokenCountFilter(1, 5, use_total_length=True) + assert f(example) == False + + f = TokenCountFilter(1, 20, use_total_length=True) + assert f(example) == True + + f = TokenCountFilter(1, 3, use_total_length=True) + assert f(example) == False + + f = TokenCountFilter(10, 30, use_total_length=True) + assert f(example) == False From e2245c70a3619aa75da8deb853e950ddfd0eb437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 17 Oct 2024 09:33:31 -0400 Subject: [PATCH 46/63] Refactor sampler constraints / filters into sampling.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 358 ++++-------------- .../common/data/lhotse/sampling.py | 284 ++++++++++++++ .../estimate_duration_bins_2d.py | 5 +- .../speech_recognition/estimate_token_bins.py | 6 +- .../common/test_2d_bucketing_constraint.py | 2 +- .../common/test_lhotse_seqlen_filters.py | 2 +- 6 files changed, 354 insertions(+), 303 deletions(-) create mode 100644 nemo/collections/common/data/lhotse/sampling.py diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 2881ffd2e24b..e9039b049023 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import bisect -import math import os import random import warnings @@ -35,18 +33,20 @@ make_worker_init_fn, ) from lhotse.dataset.dataloading import resolve_seed -from lhotse.dataset.sampling.base import CutSampler, SamplingConstraint, TimeConstraint, TokenConstraint -from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint +from lhotse.dataset.sampling.base import CutSampler, TimeConstraint from lhotse.lazy import LazyFlattener from lhotse.utils import fastcopy, fix_random_seed from omegaconf import DictConfig, OmegaConf from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset, read_cutset_from_config -from nemo.collections.common.data.lhotse.text_adapters import ( - Formattable, - NeMoSFTExample, - SourceTargetTextExample, - TextExample, +from nemo.collections.common.data.lhotse.sampling import ( + DurationFilter, + FixedBucketBatchSizeConstraint2D, + MultimodalFixedBucketBatchSizeConstraint2D, + MultimodalSamplingConstraint, + TokenCountFilter, + TokenPerSecondFilter, + TokenPerTokenFilter, ) from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn from nemo.collections.common.prompts import PromptFormatter @@ -452,39 +452,10 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No TokenCountFilter(config.min_tokens, config.max_tokens, measure_total_length=config.measure_total_length) ) + # Select the strategy customizing Lhotse sampler behaviour. + # Provides support for dynamic batch sizes, multimodal dataloading, 2D bucketing, etc. bucket_duration_bins = determine_bucket_duration_bins(config) - if config.use_multimodal_sampling: - if config.bucket_batch_size is not None: - assert ( - bucket_duration_bins is not None - ), "Cannot use bucket_batch_size option if bucket_duration_bins are not provided." - constraint = MultimodalFixedBucketBatchSizeConstraint2D( - max_seq_len_buckets=bucket_duration_bins, - batch_sizes=config.bucket_batch_size, - token_equivalent_duration=config.token_equivalent_duration, - ) - else: - constraint = MultimodalSamplingConstraint( - token_equivalent_duration=config.token_equivalent_duration, - batch_size=config.batch_size, - batch_tokens=config.batch_tokens, - quadratic_factor=config.quadratic_factor, - ) - else: - if config.bucket_batch_size is not None: - assert ( - bucket_duration_bins is not None - ), "Cannot use bucket_batch_size option if bucket_duration_bins are not provided." - constraint = FixedBucketBatchSizeConstraint2D( - max_seq_len_buckets=bucket_duration_bins, - batch_sizes=config.bucket_batch_size, - ) - else: - constraint = TimeConstraint( - max_cuts=config.batch_size, - max_duration=config.batch_duration, - quadratic_duration=config.quadratic_duration, - ) + constraint = determine_sampling_constraint(bucket_duration_bins, config) # 3. The sampler. if config.use_bucketing: @@ -562,7 +533,59 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No return sampler, is_tarred +def determine_sampling_constraint(bucket_duration_bins, config): + """ + Select an appropriate sampling strategy (constraint) for Lhotse samplers based on the configuration. + Sampling constraint affects the batch size (static/dynamic) and bucketing behaviour (1D/2D). + It is the appropriate customization point to introduce support of other modalities, + as it defines a method for example sequence length measurement (audio duration, text tokens, etc.). + + Lhotse's default is :class:`TimeConstraint` for regular audio data, other available options are + multimodal constraints (joint text + audio) and their 2D bucketing extensions. + """ + if config.use_multimodal_sampling: + if config.bucket_batch_size is not None: + assert ( + bucket_duration_bins is not None + ), "Cannot use bucket_batch_size option if bucket_duration_bins are not provided." + constraint = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=bucket_duration_bins, + batch_sizes=config.bucket_batch_size, + token_equivalent_duration=config.token_equivalent_duration, + ) + else: + constraint = MultimodalSamplingConstraint( + token_equivalent_duration=config.token_equivalent_duration, + batch_size=config.batch_size, + batch_tokens=config.batch_tokens, + quadratic_factor=config.quadratic_factor, + ) + else: + if config.bucket_batch_size is not None: + assert ( + bucket_duration_bins is not None + ), "Cannot use bucket_batch_size option if bucket_duration_bins are not provided." + constraint = FixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=bucket_duration_bins, + batch_sizes=config.bucket_batch_size, + ) + else: + constraint = TimeConstraint( + max_cuts=config.batch_size, + max_duration=config.batch_duration, + quadratic_duration=config.quadratic_duration, + ) + return constraint + + def determine_bucket_duration_bins(config): + """ + Returns appropriate bucket bins based on configuration. + If user provided them explicitly, we just pass them along; + otherwise, we try to create provisional bins when min/max duration is available. + We might return None if it's impossible to determine the bins without computing data statistics, + in which case it will be automatically done at the start of training (but may take a few minutes). + """ if config.bucket_duration_bins is not None: # Bucket duration bins are provided: just use them. ans = OmegaConf.to_container(config.bucket_duration_bins) @@ -614,148 +637,6 @@ def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfi return OmegaConf.merge(default, config) -@dataclass -class MultimodalSamplingConstraint(SamplingConstraint): - # How many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch. - # Generally set this to frame_shift * total_subsampling_factor of your audio encoder. - token_equivalent_duration: float | None = None - - # Defines maximum batch size (may be lower than that if batch_length is also specified). - batch_size: int | None = None - - # Defines the total number of tokens in a mini-batch. - # Setting this enables dynamic batch sizes. - # We will use ``token_equivalent_duration`` to convert audio examples to token sizes. - batch_tokens: int | None = None - - # When specified, this value is inversely proportional to the penalty we assign - # to longer examples when measuring their length/duration; - # i.e. large quadratic factor is a small penalty, small quadratic factor is a large penalty. - # Tweaking this helps equalize the GPU memory usage for dynamic batch sizes when using bucketing. - quadratic_factor: float | None = None - - # When False (default), we only consider the input part of the example to determine its length, - # e.g. for a Cut that means its audio duration converted to tokens, for text that means len(context_ids), etc. - # When True, we consider the sum of input and output lengths together (useful mostly for decoder-only models). - measure_total_length: bool = False - - _internal = None - - def __post_init__(self): - self._internal = TokenConstraint( - max_tokens=self.batch_tokens, - max_examples=self.batch_size, - quadratic_length=self.quadratic_factor, - ) - - def add(self, example: Any) -> None: - num_tokens = self.measure_length(example) - example.num_tokens = num_tokens - self._internal.add(example) - - def exceeded(self) -> bool: - return self._internal.exceeded() - - def close_to_exceeding(self) -> bool: - return self._internal.close_to_exceeding() - - def reset(self) -> None: - self._internal.reset() - - def measure_length(self, example: Any) -> float: - if isinstance(example, Cut): - audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) - if self.measure_total_length: - # Total length of a Cut (audio+text example) is counted as the sum of: - # * num_tokens in each supervision segment ("utterance") in the Cut - # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) - text_tokens = 0 - for s in example.supervisions: - if s.has_custom("tokens"): - text_tokens += len(s.tokens) - return audio_len_in_tokens + text_tokens - else: - return audio_len_in_tokens - elif isinstance(example, Formattable): - try: - return example.total_length if self.measure_total_length else example.input_length - except (AttributeError, AssertionError) as e: - raise RuntimeError( - "Couldn't determine the length of a text example; " - "have you provided both prompt_format and tokenizer when instantiating the dataloader?" - ) from e - raise RuntimeError(f"Unsupported example type: {type(example)}") - - -@dataclass -class FixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint): - @property - def bucketing_2d_enabled(self) -> bool: - return isinstance(self.max_seq_len_buckets[0], Sequence) and len(self.max_seq_len_buckets[0]) == 2 - - def measure_length(self, example: Any) -> tuple[float, float]: - if self.bucketing_2d_enabled: - return example.duration, _measure_tokens(example) - else: - return example.duration - - def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = None) -> int: - if not self.bucketing_2d_enabled: - return super().select_bucket(buckets=buckets, example=example, example_len=example_len) - if example_len is None: - example_len = self.measure_length(example) - bucket_idx = bisect.bisect_left(buckets, example_len) - # For 2D bucketing we have to refine the initially found bucket_idx, as bisect - # looks primarily at the first index of a tuple (i.e. duration). - # For example, with buckets [(1, 1), (1, 2), (2, 2), (2, 4)] and example (1.5, 3) - # bisect would allocate it to bucket_idx=2 instead of bucket_idx=3. - # To refine, we'll try to push the example to as many buckets to the right as possible, - # as long as they have the same dim0 length (e.g. audio duration) and the example's dim1 - # is smaller than the bin's dim1 (e.g., output token sequence length). - bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] - num_buckets = len(self.max_seq_len_buckets) - while ( - (next_idx := bucket_idx + 1) < num_buckets # There is a next bucket - and (bin := self.max_seq_len_buckets[next_idx])[0] == bin_dim0 # The next bucket has the same 1st dim. - # The example's 2nd dim is between that of the current and the next bucket; or, - # the next bucket's 2nd dim is still smaller than example. - and (bin_dim1 < example_len[1] <= bin[1] or bin[1] < example_len[1]) - ): - bucket_idx = next_idx - bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] - return bucket_idx - - -@dataclass -class MultimodalFixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint2D): - # How many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch. - # Generally set this to frame_shift * total_subsampling_factor of your audio encoder. - token_equivalent_duration: float | None = None - - # When False (default), we only consider the input part of the example to determine its length, - # e.g. for a Cut that means its audio duration converted to tokens, for text that means len(context_ids), etc. - # When True, we consider the sum of input and output lengths together (useful mostly for decoder-only models). - measure_total_length: bool = False - - def measure_length(self, example: Any) -> float | tuple[float, float]: - if isinstance(example, Cut): - audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) - if self.measure_total_length: - # Total length of a Cut (audio+text example) is counted as the sum of: - # * num_tokens in each supervision segment ("utterance") in the Cut - # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) - text_tokens = 0 - for s in example.supervisions: - if s.has_custom("tokens"): - text_tokens += len(s.tokens) - return audio_len_in_tokens + text_tokens - else: - return audio_len_in_tokens - elif isinstance(example, Formattable): - return example.total_length if self.measure_total_length else example.input_length - raise RuntimeError(f"Unsupported example type: {type(example)}") - - def tokenize(example, tokenizer): if isinstance(example, Cut): for s in example.supervisions: @@ -783,117 +664,6 @@ def tokenize_with_prompt(example, tokenizer, prompt_format: str | PromptFormatte # to support pickling lambdas if its ever truly necessary. -class DurationFilter: - """Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise.""" - - def __init__(self, d_min: float, d_max: float) -> None: - self.d_min = d_min - self.d_max = d_max - - def __call__(self, example) -> bool: - if isinstance(example, Cut): - return self.d_min <= example.duration <= self.d_max - else: - return True # does not apply to text etc. - - -class TokenCountFilter: - """ - Callable, returns ``True`` if an example's number of tokens is in range [t_min, t_max] and ``False`` otherwise. - - It is only applicable to data types that derive from class ``Formattable`` and lhotse ``Cut`` objects. - Acts as a passthrough for Cuts. - Raises exception if a non-Formattable and non-Cut data are provided. - - The ``measure_total_length`` option allows to select whether we should filter on context_ids length (=False) - or input_ids length (=True). - The difference is that for decoder-only models, we collapse input and output into a single sequence, - so we should measure the example length using input_ids (measure_total_length=True). - However, for models which have separate inputs and outputs such as encoder-decoder models, - we want to measure the input lengths only here (measure_total_length=False), - and enable ``TokenPerTokenFilter`` for additional filtering on the output sequence length. - """ - - def __init__(self, t_min: float, t_max: float, measure_total_length: bool) -> None: - self.t_min = t_min - self.t_max = t_max - self.measure_total_length = measure_total_length - - def __call__(self, example) -> bool: - if self.t_min is None and self.t_max is None: - return True # disabled - if isinstance(example, Cut): - return True # does not apply to Cuts - assert isinstance(example, Formattable), ( - f"TokenCountFilter can only be applied to data examples that derive Formattable class. " - f"Formattable objects define properties input_length, output_length, and total_length that " - f"allow us to select the right sequence length for filtering. We got: {example}" - ) - try: - length = example.total_length if self.measure_total_length else example.input_length - except (AttributeError, AssertionError) as e: - raise RuntimeError( - f"Cannot measure token count for example: {example} " - f"-- did you forget to apply prompt formatting? If instantiating Lhotse dataloader, " - f"make sure you provided 'prompt_format' option and passed the tokenizer." - ) from e - return self.t_min <= length <= self.t_max - - -class TokenPerSecondFilter: - """ - Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) - is in range [tps_min, tps_max] and ``False`` otherwise. - """ - - def __init__(self, tps_min: float, tps_max: float) -> None: - assert tps_min <= tps_max - self.tps_min = tps_min - self.tps_max = tps_max - self.enabled = tps_min > 0 or tps_max < float("inf") - - def __call__(self, example) -> bool: - if not isinstance(example, Cut) or not self.enabled: - return True # pass-through for non-audio examples. - tps = _measure_tps(example) - return self.tps_min <= tps <= self.tps_max - - -class TokenPerTokenFilter: - """ - Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) - is in range [tps_min, tps_max] and ``False`` otherwise. - """ - - def __init__(self, tpt_min: float, tpt_max: float) -> None: - assert tpt_min <= tpt_max - self.tpt_min = tpt_min - self.tpt_max = tpt_max - self.enabled = tpt_min > 0 or tpt_max < float("inf") - - def __call__(self, example) -> bool: - if isinstance(example, Cut) or not self.enabled: - return True # pass-through for non-text examples. - tpt = example.answer_ids.shape[0] / example.context_ids.shape[0] - return self.tpt_min <= tpt <= self.tpt_max - - -def _measure_tokens(cut: Cut) -> int: - if hasattr(cut, "input_ids"): - return len(cut.input_ids) # tokenized with prompt formatter - supervisions_with_tokens = [s for s in cut.supervisions if hasattr(s, "tokens")] - assert len(supervisions_with_tokens) > 0, ( - "Cannot measure tokens-per-second with untokenized supervisions. " - "Did you forget to provide the tokenizer argument to get_lhotse_dataloader_from_config() method?" - ) - return sum(len(s.tokens) for s in supervisions_with_tokens) - - -def _measure_tps(cut: Cut) -> float: - num_tokens = _measure_tokens(cut) - return num_tokens / cut.duration - - def _normalize_loudness(cuts: CutSet, db_norm: float) -> CutSet: return cuts.normalize_loudness(target=db_norm, mix_first=False) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py new file mode 100644 index 000000000000..992ec05e8474 --- /dev/null +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -0,0 +1,284 @@ +import bisect +import math +from dataclasses import dataclass +from typing import Any, Sequence + +from lhotse.cut import Cut +from lhotse.dataset import SamplingConstraint, TokenConstraint +from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint + +from nemo.collections.common.data.lhotse.text_adapters import Formattable + + +@dataclass +class MultimodalSamplingConstraint(SamplingConstraint): + """ + Sampling strategy that customizes Lhotse samplers to measure sequence lengths as token counts. + It provides a unified interface for audio and text examples - audio duration is converted to + an equivalent token count. + """ + + # How many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch. + # Generally set this to frame_shift * total_subsampling_factor of your audio encoder. + token_equivalent_duration: float | None = None + + # Defines maximum batch size (may be lower than that if batch_length is also specified). + batch_size: int | None = None + + # Defines the total number of tokens in a mini-batch. + # Setting this enables dynamic batch sizes. + # We will use ``token_equivalent_duration`` to convert audio examples to token sizes. + batch_tokens: int | None = None + + # When specified, this value is inversely proportional to the penalty we assign + # to longer examples when measuring their length/duration; + # i.e. large quadratic factor is a small penalty, small quadratic factor is a large penalty. + # Tweaking this helps equalize the GPU memory usage for dynamic batch sizes when using bucketing. + quadratic_factor: float | None = None + + # When False (default), we only consider the input part of the example to determine its length, + # e.g. for a Cut that means its audio duration converted to tokens, for text that means len(context_ids), etc. + # When True, we consider the sum of input and output lengths together (useful mostly for decoder-only models). + measure_total_length: bool = False + + _internal = None + + def __post_init__(self): + self._internal = TokenConstraint( + max_tokens=self.batch_tokens, + max_examples=self.batch_size, + quadratic_length=self.quadratic_factor, + ) + + def add(self, example: Any) -> None: + num_tokens = self.measure_length(example) + example.num_tokens = num_tokens + self._internal.add(example) + + def exceeded(self) -> bool: + return self._internal.exceeded() + + def close_to_exceeding(self) -> bool: + return self._internal.close_to_exceeding() + + def reset(self) -> None: + self._internal.reset() + + def measure_length(self, example: Any) -> float: + if isinstance(example, Cut): + audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) + if self.measure_total_length: + # Total length of a Cut (audio+text example) is counted as the sum of: + # * num_tokens in each supervision segment ("utterance") in the Cut + # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) + text_tokens = 0 + for s in example.supervisions: + if s.has_custom("tokens"): + text_tokens += len(s.tokens) + return audio_len_in_tokens + text_tokens + else: + return audio_len_in_tokens + elif isinstance(example, Formattable): + try: + return example.total_length if self.measure_total_length else example.input_length + except (AttributeError, AssertionError) as e: + raise RuntimeError( + "Couldn't determine the length of a text example; " + "have you provided both prompt_format and tokenizer when instantiating the dataloader?" + ) from e + raise RuntimeError(f"Unsupported example type: {type(example)}") + + +@dataclass +class FixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint): + """ + Sampling strategy that customizes Lhotse samplers to support 2D bucket selection (it also supports 1D). + It is intended only for audio examples (i.e., Lhotse Cut objects). + """ + + @property + def bucketing_2d_enabled(self) -> bool: + return isinstance(self.max_seq_len_buckets[0], Sequence) and len(self.max_seq_len_buckets[0]) == 2 + + def measure_length(self, example: Cut) -> tuple[float, float] | float: + if self.bucketing_2d_enabled: + return example.duration, _measure_tokens(example) + else: + return example.duration + + def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = None) -> int: + if not self.bucketing_2d_enabled: + return super().select_bucket(buckets=buckets, example=example, example_len=example_len) + if example_len is None: + example_len = self.measure_length(example) + bucket_idx = bisect.bisect_left(buckets, example_len) + # For 2D bucketing we have to refine the initially found bucket_idx, as bisect + # looks primarily at the first index of a tuple (i.e. duration). + # For example, with buckets [(1, 1), (1, 2), (2, 2), (2, 4)] and example (1.5, 3) + # bisect would allocate it to bucket_idx=2 instead of bucket_idx=3. + # To refine, we'll try to push the example to as many buckets to the right as possible, + # as long as they have the same dim0 length (e.g. audio duration) and the example's dim1 + # is smaller than the bin's dim1 (e.g., output token sequence length). + bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] + num_buckets = len(self.max_seq_len_buckets) + while ( + (next_idx := bucket_idx + 1) < num_buckets # There is a next bucket + and (bin := self.max_seq_len_buckets[next_idx])[0] == bin_dim0 # The next bucket has the same 1st dim. + # The example's 2nd dim is between that of the current and the next bucket; or, + # the next bucket's 2nd dim is still smaller than example. + and (bin_dim1 < example_len[1] <= bin[1] or bin[1] < example_len[1]) + ): + bucket_idx = next_idx + bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] + return bucket_idx + + +@dataclass +class MultimodalFixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint2D): + """ + Sampling strategy that customizes Lhotse samplers to support both multimodal sampling and 2D bucket selection. + It combines the capabilities of :class:`FixedBucketBatchSizeConstraint2D` and :class:`MultimodalSamplingConstraint`. + """ + + # How many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch. + # Generally set this to frame_shift * total_subsampling_factor of your audio encoder. + token_equivalent_duration: float | None = None + + # When False (default), we only consider the input part of the example to determine its length, + # e.g. for a Cut that means its audio duration converted to tokens, for text that means len(context_ids), etc. + # When True, we consider the sum of input and output lengths together (useful mostly for decoder-only models). + measure_total_length: bool = False + + def measure_length(self, example: Any) -> float | tuple[float, float]: + if isinstance(example, Cut): + audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) + if self.measure_total_length: + # Total length of a Cut (audio+text example) is counted as the sum of: + # * num_tokens in each supervision segment ("utterance") in the Cut + # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) + text_tokens = 0 + for s in example.supervisions: + if s.has_custom("tokens"): + text_tokens += len(s.tokens) + return audio_len_in_tokens + text_tokens + else: + return audio_len_in_tokens + elif isinstance(example, Formattable): + return example.total_length if self.measure_total_length else example.input_length + raise RuntimeError(f"Unsupported example type: {type(example)}") + + +class DurationFilter: + """ + Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__(self, d_min: float, d_max: float) -> None: + self.d_min = d_min + self.d_max = d_max + + def __call__(self, example) -> bool: + if isinstance(example, Cut): + return self.d_min <= example.duration <= self.d_max + else: + return True # does not apply to text etc. + + +class TokenCountFilter: + """ + Callable, returns ``True`` if an example's number of tokens is in range [t_min, t_max] and ``False`` otherwise. + + It is only applicable to data types that derive from class ``Formattable`` and lhotse ``Cut`` objects. + Acts as a passthrough for Cuts. + Raises exception if a non-Formattable and non-Cut data are provided. + + The ``measure_total_length`` option allows to select whether we should filter on context_ids length (=False) + or input_ids length (=True). + The difference is that for decoder-only models, we collapse input and output into a single sequence, + so we should measure the example length using input_ids (measure_total_length=True). + However, for models which have separate inputs and outputs such as encoder-decoder models, + we want to measure the input lengths only here (measure_total_length=False), + and enable ``TokenPerTokenFilter`` for additional filtering on the output sequence length. + """ + + def __init__(self, t_min: float, t_max: float, measure_total_length: bool) -> None: + self.t_min = t_min + self.t_max = t_max + self.measure_total_length = measure_total_length + + def __call__(self, example) -> bool: + if self.t_min is None and self.t_max is None: + return True # disabled + if isinstance(example, Cut): + return True # does not apply to Cuts + assert isinstance(example, Formattable), ( + f"TokenCountFilter can only be applied to data examples that derive Formattable class. " + f"Formattable objects define properties input_length, output_length, and total_length that " + f"allow us to select the right sequence length for filtering. We got: {example}" + ) + try: + length = example.total_length if self.measure_total_length else example.input_length + except (AttributeError, AssertionError) as e: + raise RuntimeError( + f"Cannot measure token count for example: {example} " + f"-- did you forget to apply prompt formatting? If instantiating Lhotse dataloader, " + f"make sure you provided 'prompt_format' option and passed the tokenizer." + ) from e + return self.t_min <= length <= self.t_max + + +class TokenPerSecondFilter: + """ + Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) + is in range [tps_min, tps_max] and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__(self, tps_min: float, tps_max: float) -> None: + assert tps_min <= tps_max + self.tps_min = tps_min + self.tps_max = tps_max + self.enabled = tps_min > 0 or tps_max < float("inf") + + def __call__(self, example) -> bool: + if not isinstance(example, Cut) or not self.enabled: + return True # pass-through for non-audio examples. + tps = _measure_tps(example) + return self.tps_min <= tps <= self.tps_max + + +class TokenPerTokenFilter: + """ + Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) + is in range [tps_min, tps_max] and ``False`` otherwise. + Acts as a pass-through for audio examples (Cuts). + """ + + def __init__(self, tpt_min: float, tpt_max: float) -> None: + assert tpt_min <= tpt_max + self.tpt_min = tpt_min + self.tpt_max = tpt_max + self.enabled = tpt_min > 0 or tpt_max < float("inf") + + def __call__(self, example) -> bool: + if isinstance(example, Cut) or not self.enabled: + return True # pass-through for non-text examples. + tpt = example.answer_ids.shape[0] / example.context_ids.shape[0] + return self.tpt_min <= tpt <= self.tpt_max + + +def _measure_tokens(cut: Cut) -> int: + if hasattr(cut, "input_ids"): + return len(cut.input_ids) # tokenized with prompt formatter + supervisions_with_tokens = [s for s in cut.supervisions if hasattr(s, "tokens")] + assert len(supervisions_with_tokens) > 0, ( + "Cannot measure tokens-per-second with untokenized supervisions. " + "Did you forget to provide the tokenizer argument to get_lhotse_dataloader_from_config() method?" + ) + return sum(len(s.tokens) for s in supervisions_with_tokens) + + +def _measure_tps(cut: Cut) -> float: + num_tokens = _measure_tokens(cut) + return num_tokens / cut.duration diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 52d5b3620a2a..0f4a021e09cc 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -27,12 +27,11 @@ from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config -from nemo.collections.common.data.lhotse.dataloader import ( +from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize +from nemo.collections.common.data.lhotse.sampling import ( DurationFilter, FixedBucketBatchSizeConstraint2D, - LhotseDataLoadingConfig, TokenPerSecondFilter, - tokenize, ) from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index 5a6d8ea23272..50820920a44f 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -27,17 +27,15 @@ from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config -from nemo.collections.common.data.lhotse.dataloader import ( +from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize, tokenize_with_prompt +from nemo.collections.common.data.lhotse.sampling import ( DurationFilter, FixedBucketBatchSizeConstraint2D, - LhotseDataLoadingConfig, MultimodalFixedBucketBatchSizeConstraint2D, MultimodalSamplingConstraint, TokenCountFilter, TokenPerSecondFilter, TokenPerTokenFilter, - tokenize, - tokenize_with_prompt, ) from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer diff --git a/tests/collections/common/test_2d_bucketing_constraint.py b/tests/collections/common/test_2d_bucketing_constraint.py index ba67d2e1fabb..fa771eb75f85 100644 --- a/tests/collections/common/test_2d_bucketing_constraint.py +++ b/tests/collections/common/test_2d_bucketing_constraint.py @@ -3,7 +3,7 @@ from lhotse import CutSet, Seconds, SupervisionSegment from lhotse.dataset import DynamicBucketingSampler from lhotse.testing.dummies import DummyManifest, dummy_cut -from nemo.collections.common.data.lhotse.dataloader import FixedBucketBatchSizeConstraint2D +from nemo.collections.common.data.lhotse.sampling import FixedBucketBatchSizeConstraint2D @pytest.fixture diff --git a/tests/collections/common/test_lhotse_seqlen_filters.py b/tests/collections/common/test_lhotse_seqlen_filters.py index f9dbc49fa20b..04ded9f3186a 100644 --- a/tests/collections/common/test_lhotse_seqlen_filters.py +++ b/tests/collections/common/test_lhotse_seqlen_filters.py @@ -5,7 +5,7 @@ from lhotse import SupervisionSegment from lhotse.testing.dummies import dummy_cut -from nemo.collections.common.data.lhotse.dataloader import ( +from nemo.collections.common.data.lhotse.sampling import ( DurationFilter, TokenCountFilter, TokenPerSecondFilter, From 27d03861f7790be1906b311d646abd0c17654a4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 17 Oct 2024 10:51:45 -0400 Subject: [PATCH 47/63] Tests and support for sampler length measurement of multimodal conversations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/sampling.py | 47 +++-- .../common/data/lhotse/text_adapters.py | 12 +- .../test_lhotse_multimodal_dataloading.py | 173 +++++++++++++++++- 3 files changed, 218 insertions(+), 14 deletions(-) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 992ec05e8474..1ea2232f80fe 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -1,4 +1,5 @@ import bisect +import logging import math from dataclasses import dataclass from typing import Any, Sequence @@ -130,6 +131,17 @@ def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = No ): bucket_idx = next_idx bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] + + if example_len[0] > bin_dim0 or example_len[1] > bin_dim1: + logging.warning( + f"Data sample exceeds 2D bucket specification: lengths={example_len} bucket=({bin_dim0}, {bin_dim1}) " + f"(there is no larger bucket that would fit this example). " + f"We will keep it but expect OutOfMemoryError to happen during the training. " + f"You can fix this by stricter filtering with max_duration, max_tokens, max_tps, max_tpt; " + f"or re-estimating your bucket bins to match the actual data length distribution. " + f"Details: {example=}" + ) + return bucket_idx @@ -152,19 +164,32 @@ class MultimodalFixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint2 def measure_length(self, example: Any) -> float | tuple[float, float]: if isinstance(example, Cut): audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) - if self.measure_total_length: - # Total length of a Cut (audio+text example) is counted as the sum of: - # * num_tokens in each supervision segment ("utterance") in the Cut - # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) - text_tokens = 0 - for s in example.supervisions: - if s.has_custom("tokens"): - text_tokens += len(s.tokens) - return audio_len_in_tokens + text_tokens + # Total length of a Cut (audio+text example) is counted as the sum of: + # * num_tokens in each supervision segment ("utterance") in the Cut + # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) + text_tokens = 0 + for s in example.supervisions: + if s.has_custom("tokens"): + text_tokens += len(s.tokens) + + if self.bucketing_2d_enabled: + return audio_len_in_tokens, text_tokens + else: - return audio_len_in_tokens + if self.measure_total_length: + return audio_len_in_tokens + text_tokens + else: + return audio_len_in_tokens + elif isinstance(example, Formattable): - return example.total_length if self.measure_total_length else example.input_length + if self.bucketing_2d_enabled: + assert ( + not self.measure_total_length + ), "2D bucketing requires measure_total_length=False, but it was set to True." + return example.input_length, example.output_length + else: + return example.total_length if self.measure_total_length else example.input_length + raise RuntimeError(f"Unsupported example type: {type(example)}") diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 08b9804d75ad..6a906861a219 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -344,8 +344,18 @@ def total_length(self) -> int | None: extra = _compute_num_audio_tokens(self, "all") return self.input_ids.shape[0] + extra + @property + def has_audio_turns(self) -> bool: + return any(isinstance(t, AudioTurn) for t in self.turns) + + @property + def has_text_turns(self) -> bool: + return any(isinstance(t, TextTurn) for t in self.turns) + def _compute_num_audio_tokens(example: NeMoMultimodalConversation, mode: Literal["context", "answer", "all"]) -> int: + if not example.has_audio_turns: + return 0 assert example.token_equivalent_duration is not None, ( "Cannot compute the length of a NeMoMultimodalConversation: " "token_equivalent_duration must be set in order to estimate the number of tokens equivalent to audio turns. " @@ -356,7 +366,7 @@ def _compute_num_audio_tokens(example: NeMoMultimodalConversation, mode: Literal case "context": turns = example.turns[:-1] case "answer": - turns = example.turns[-1] + turns = example.turns[-1:] case "all": turns = example.turns case _: diff --git a/tests/collections/common/test_lhotse_multimodal_dataloading.py b/tests/collections/common/test_lhotse_multimodal_dataloading.py index 51b3085a8fc8..c7cc96e79ea0 100644 --- a/tests/collections/common/test_lhotse_multimodal_dataloading.py +++ b/tests/collections/common/test_lhotse_multimodal_dataloading.py @@ -4,12 +4,16 @@ import lhotse import pytest import torch -from lhotse.testing.dummies import dummy_cut, dummy_recording +from lhotse.testing.dummies import dummy_recording from omegaconf import OmegaConf from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.data.lhotse.sampling import ( + MultimodalFixedBucketBatchSizeConstraint2D, + MultimodalSamplingConstraint, +) from nemo.collections.common.data.lhotse.text_adapters import AudioTurn, NeMoMultimodalConversation, TextTurn -from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper +from nemo.collections.common.prompts import Llama2PromptFormatter from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model @@ -210,3 +214,168 @@ def test_multimodal_conversation_input_with_prompt(multimodal_conversations_path assert (ex.mask[30:72] == True).all() # assistant turn assert (ex.mask[72:95] == False).all() # user turn assert (ex.mask[95:] == True).all() # assistant turn + + +def test_text_only_conversation_length_measurement(tokenizer): + convo = NeMoMultimodalConversation( + id="textonly-1", + turns=[ + TextTurn("hello", "user"), + TextTurn("hi", "assistant"), + ], + ) + convo = convo.apply_prompt_format(Llama2PromptFormatter(tokenizer)) + assert tokenizer.ids_to_text(convo.input_ids) == "[INST] hello [/INST] hi" + assert tokenizer.ids_to_text(convo.context_ids) == "[INST] hello [/INST]" + assert tokenizer.ids_to_text(convo.answer_ids) == "hi" + + assert convo.input_length == len(convo.context_ids) == 10 + assert convo.output_length == len(convo.answer_ids) == 4 + assert convo.total_length == len(convo.input_ids) == 14 + + constr = MultimodalSamplingConstraint(measure_total_length=False) + assert constr.measure_length(convo) == 10 + + constr = MultimodalSamplingConstraint(measure_total_length=True) + assert constr.measure_length(convo) == 14 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[5, 10, 15], batch_sizes=[3, 2, 1], measure_total_length=True + ) + assert constr.measure_length(convo) == 14 + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 2 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[(5, 2), (5, 5), (15, 3), (15, 6), (15, 10)], + batch_sizes=[5, 4, 3, 2, 1], + measure_total_length=False, + ) + assert constr.measure_length(convo) == (10, 4) + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 3 + + +def test_audio_only_conversation_length_measurement(tokenizer, tmp_path_factory): + audio_dir = tmp_path_factory.mktemp("audio") + c1 = dummy_recording(0, duration=7.16, with_data=True).to_cut().save_audio(audio_dir / "1.wav") + c2 = dummy_recording(1, duration=15.96, with_data=True).to_cut().save_audio(audio_dir / "2.wav") + convo = NeMoMultimodalConversation( + id="audioonly-1", + turns=[ + AudioTurn(c1, "user", "[audio]"), + AudioTurn(c2, "assistant", "[audio]"), + ], + token_equivalent_duration=0.1, # 10ms frame_shift * 10x subsampling for easy testing + ) + convo = convo.apply_prompt_format(Llama2PromptFormatter(tokenizer)) + assert tokenizer.ids_to_text(convo.input_ids) == "[INST] [audio] [/INST] [audio]" + assert tokenizer.ids_to_text(convo.context_ids) == "[INST] [audio] [/INST]" + assert tokenizer.ids_to_text(convo.answer_ids) == "[audio]" + + # NOTE: Unlike text-only, len(context_ids) != convo.input_length! The same is true for answer and input ids. + # 7.16s with 100ms frame is 72 tokens, we have 7 context tokens, but replace 1 audio locator tag. + assert len(convo.context_ids) == 7 + assert convo.input_length == 78 + + # 15.96s with 100ms frame is 160 tokens, we have 3 answer tokens, but replace 1 audio locator tag. + assert len(convo.answer_ids) == 3 + assert convo.output_length == 162 + + assert len(convo.input_ids) == 10 + assert convo.total_length == 162 + 78 + + constr = MultimodalSamplingConstraint(measure_total_length=False) + assert constr.measure_length(convo) == 78 + + constr = MultimodalSamplingConstraint(measure_total_length=True) + assert constr.measure_length(convo) == 162 + 78 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[100, 200, 300, 400], batch_sizes=[3, 2, 1, 1], measure_total_length=True + ) + assert constr.measure_length(convo) == 162 + 78 + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 2 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[ + (50, 50), + (50, 100), + (50, 200), + (100, 50), + (100, 150), + (100, 200), + (100, 300), + (400, 400), + ], + batch_sizes=[8, 7, 6, 5, 4, 3, 2, 1], + measure_total_length=False, + ) + assert constr.measure_length(convo) == (78, 162) + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 5 + + +def test_multimodal_conversation_length_measurement(tokenizer, tmp_path_factory): + audio_dir = tmp_path_factory.mktemp("audio") + c1 = dummy_recording(0, duration=7.16, with_data=True).to_cut().save_audio(audio_dir / "1.wav") + c2 = dummy_recording(1, duration=15.96, with_data=True).to_cut().save_audio(audio_dir / "2.wav") + convo = NeMoMultimodalConversation( + id="multimodal-1", + turns=[ + TextTurn("listen to this and tell me your opinion", "user"), + AudioTurn(c1, "user", "[audio]"), + TextTurn("its fine", "assistant"), + TextTurn("remove the noise", "user"), + TextTurn("sure", "assistant"), + AudioTurn(c2, "assistant", "[audio]"), + ], + token_equivalent_duration=0.1, # 10ms frame_shift * 10x subsampling for easy testing + ) + convo = convo.apply_prompt_format(Llama2PromptFormatter(tokenizer)) + print(convo) + assert ( + tokenizer.ids_to_text(convo.input_ids) + == "[INST] listen to this and tell me your opinion [audio] [/INST] its fine [INST] remove the noise [/INST] sure [audio]" + ) + assert ( + tokenizer.ids_to_text(convo.context_ids) + == "[INST] listen to this and tell me your opinion [audio] [/INST] its fine [INST] remove the noise [/INST]" + ) + assert tokenizer.ids_to_text(convo.answer_ids) == "sure [audio]" + + assert len(convo.context_ids) == 66 + assert convo.input_length == 66 + 72 - 1 == 137 + + # 15.96s with 100ms frame is 160 tokens, we have 3 answer tokens, but replace 1 audio locator tag. + assert len(convo.answer_ids) == 7 + assert convo.output_length == 7 + 160 - 1 == 166 + + assert len(convo.input_ids) == 73 + assert convo.total_length == 137 + 166 == 303 + + constr = MultimodalSamplingConstraint(measure_total_length=False) + assert constr.measure_length(convo) == 137 + + constr = MultimodalSamplingConstraint(measure_total_length=True) + assert constr.measure_length(convo) == 303 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[100, 200, 300, 400], batch_sizes=[3, 2, 1, 1], measure_total_length=True + ) + assert constr.measure_length(convo) == 303 + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 3 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[ + (50, 50), + (50, 100), + (50, 200), + (100, 50), + (100, 150), + (100, 200), + (100, 300), + (400, 400), + ], + batch_sizes=[8, 7, 6, 5, 4, 3, 2, 1], + measure_total_length=False, + ) + assert constr.measure_length(convo) == (137, 166) + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 7 From 30bda107d0c6fa4130194d46adca68e493ec39c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 17 Oct 2024 10:56:36 -0400 Subject: [PATCH 48/63] Update estimate_token_bins.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/sampling.py | 2 +- scripts/speech_recognition/estimate_token_bins.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 1ea2232f80fe..0e6a0a0b79ee 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -298,7 +298,7 @@ def _measure_tokens(cut: Cut) -> int: return len(cut.input_ids) # tokenized with prompt formatter supervisions_with_tokens = [s for s in cut.supervisions if hasattr(s, "tokens")] assert len(supervisions_with_tokens) > 0, ( - "Cannot measure tokens-per-second with untokenized supervisions. " + "Cannot measure the number of tokens with untokenized supervisions. " "Did you forget to provide the tokenizer argument to get_lhotse_dataloader_from_config() method?" ) return sum(len(s.tokens) for s in supervisions_with_tokens) diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_recognition/estimate_token_bins.py index 50820920a44f..b198158498c1 100644 --- a/scripts/speech_recognition/estimate_token_bins.py +++ b/scripts/speech_recognition/estimate_token_bins.py @@ -17,7 +17,6 @@ import math from functools import partial from itertools import islice -from pathlib import Path from typing import Callable, Iterable import numpy as np @@ -29,12 +28,9 @@ from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize, tokenize_with_prompt from nemo.collections.common.data.lhotse.sampling import ( - DurationFilter, - FixedBucketBatchSizeConstraint2D, MultimodalFixedBucketBatchSizeConstraint2D, MultimodalSamplingConstraint, TokenCountFilter, - TokenPerSecondFilter, TokenPerTokenFilter, ) from nemo.collections.common.prompts.formatter import PromptFormatter @@ -142,9 +138,9 @@ def estimate_token_buckets( is_2d = num_subbuckets is not None if is_2d: - constraint = MultimodalFixedBucketBatchSizeConstraint2D([(0.0, 0.0)], [0]) + constraint = MultimodalFixedBucketBatchSizeConstraint2D([(0.0, 0.0)], [0], measure_total_length=False) else: - constraint = MultimodalSamplingConstraint() + constraint = MultimodalSamplingConstraint(measure_total_length=True) # Gather the duration and token count statistics for the dataset. num_input_tokens = [] From 968238c12f050cbe99fd5bba65c8652cd7fb6ce5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 17 Oct 2024 10:57:17 -0400 Subject: [PATCH 49/63] Move estimate_token_bins.py to speech_llm scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- scripts/{speech_recognition => speech_llm}/estimate_token_bins.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{speech_recognition => speech_llm}/estimate_token_bins.py (100%) diff --git a/scripts/speech_recognition/estimate_token_bins.py b/scripts/speech_llm/estimate_token_bins.py similarity index 100% rename from scripts/speech_recognition/estimate_token_bins.py rename to scripts/speech_llm/estimate_token_bins.py From f7d7453df781624f1af0c39f4bd2668ef6305e88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 17 Oct 2024 11:03:21 -0400 Subject: [PATCH 50/63] Minor tweaks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/dataloader.py | 2 ++ nemo/collections/common/data/lhotse/sampling.py | 10 ++-------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index e9039b049023..7793ec19ea2c 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -112,6 +112,8 @@ class LhotseDataLoadingConfig: # * Text input min_tokens: int | None = None max_tokens: int | None = None + # When true, combine context+answer lengths into a total length; otherwise report context length. + # For 2D bucketing it's always false, as we report a tuple of (context_len, answer_len). measure_total_length: bool = True min_tpt: int = -1 # allowed tokens per token (text-only) max_tpt: float = float("inf") diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 0e6a0a0b79ee..8de49bbaf500 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -163,14 +163,11 @@ class MultimodalFixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint2 def measure_length(self, example: Any) -> float | tuple[float, float]: if isinstance(example, Cut): - audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) # Total length of a Cut (audio+text example) is counted as the sum of: # * num_tokens in each supervision segment ("utterance") in the Cut # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) - text_tokens = 0 - for s in example.supervisions: - if s.has_custom("tokens"): - text_tokens += len(s.tokens) + audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) + text_tokens = _measure_tokens(example) if self.bucketing_2d_enabled: return audio_len_in_tokens, text_tokens @@ -183,9 +180,6 @@ def measure_length(self, example: Any) -> float | tuple[float, float]: elif isinstance(example, Formattable): if self.bucketing_2d_enabled: - assert ( - not self.measure_total_length - ), "2D bucketing requires measure_total_length=False, but it was set to True." return example.input_length, example.output_length else: return example.total_length if self.measure_total_length else example.input_length From 25192496dfe06e0af4b65d0114dcd863c334f24b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 17 Oct 2024 15:48:34 +0000 Subject: [PATCH 51/63] Fixes for SpeechLLM dataset MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/cutset.py | 2 +- .../speech_llm/parts/utils/data_utils.py | 9 +++--- .../common/test_lhotse_seqlen_filters.py | 32 +++++++++---------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 6c314d4b9de8..80382c6dcb14 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -332,7 +332,7 @@ def parse_and_combine_datasets( @data_type_parser(["lhotse", "lhotse_shar"]) def read_lhotse_manifest(config) -> tuple[CutSet, bool]: - is_tarred = config.shar_path is not None + is_tarred = config.get("shar_path") is not None if is_tarred: # Lhotse Shar is the equivalent of NeMo's native "tarred" dataset. # The combination of shuffle_shards, and repeat causes this to diff --git a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py index 43f08afea4c9..494667c5bfb1 100644 --- a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py +++ b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py @@ -18,7 +18,8 @@ import torch from lhotse.cut import Cut -from nemo.collections.common.prompts import PromptFormatter, get_prompt_format_fn +from nemo.collections.common.data.prompt_fn import get_prompt_format_fn +from nemo.collections.common.prompts import PromptFormatter from nemo.utils import logging, logging_mode @@ -403,7 +404,8 @@ def __init__( audio_locator: Optional[str] = None, max_seq_length: Optional[int] = 8192, ): - self.prompt_format_fn = get_prompt_format_fn(prompt_format) + self.prompt = PromptFormatter.resolve(prompt_format)(tokenizer) + self.prompt_format_fn = get_prompt_format_fn(Cut, self.prompt) self.tokenizer = tokenizer self.audio_locator = audio_locator self.max_seq_length = max_seq_length @@ -418,8 +420,7 @@ def __init__( ) def _process_example(self, cut: Cut): - ans = self.prompt_format_fn([cut], self.tokenizer) - ans = {k: v[0] for k, v in ans.items()} + ans = self.prompt_format_fn(cut, self.prompt) context_start_idx = [0] if self.audio_locator_id is not None: if len(self.audio_locator_id) == 1: # fast case, special "insert audio" token diff --git a/tests/collections/common/test_lhotse_seqlen_filters.py b/tests/collections/common/test_lhotse_seqlen_filters.py index 04ded9f3186a..ba77b235c6e5 100644 --- a/tests/collections/common/test_lhotse_seqlen_filters.py +++ b/tests/collections/common/test_lhotse_seqlen_filters.py @@ -44,12 +44,12 @@ def test_cut_token_per_second_filter(cut): def test_cut_passes_by_token_count_and_tpt_filter(cut): - assert TokenCountFilter(1, 10, use_total_length=True)(cut) == True + assert TokenCountFilter(1, 10, measure_total_length=True)(cut) == True assert TokenPerTokenFilter(1, 10)(cut) == True def test_cut_passes_by_token_count_and_tpt_filter(cut): - assert TokenCountFilter(1, 10, use_total_length=True)(cut) == True + assert TokenCountFilter(1, 10, measure_total_length=True)(cut) == True assert TokenPerTokenFilter(1, 10)(cut) == True @@ -80,28 +80,28 @@ def test_src_tgt_token_filter(src_tgt_example): """ Input length measurement / encoder-decoder models / 2D bucketing """ - f = TokenCountFilter(1, 5, use_total_length=False) + f = TokenCountFilter(1, 5, measure_total_length=False) assert f(example) == True - f = TokenCountFilter(1, 3, use_total_length=False) + f = TokenCountFilter(1, 3, measure_total_length=False) assert f(example) == False - f = TokenCountFilter(10, 30, use_total_length=False) + f = TokenCountFilter(10, 30, measure_total_length=False) assert f(example) == False """ Total length measurement / decoder-only models / 1D bucketing """ - f = TokenCountFilter(1, 5, use_total_length=True) + f = TokenCountFilter(1, 5, measure_total_length=True) assert f(example) == False - f = TokenCountFilter(1, 20, use_total_length=True) + f = TokenCountFilter(1, 20, measure_total_length=True) assert f(example) == True - f = TokenCountFilter(1, 3, use_total_length=True) + f = TokenCountFilter(1, 3, measure_total_length=True) assert f(example) == False - f = TokenCountFilter(20, 30, use_total_length=True) + f = TokenCountFilter(20, 30, measure_total_length=True) assert f(example) == False @@ -146,26 +146,26 @@ def test_nemo_sft_token_filter(nemo_sft_example): """ Input length measurement / encoder-decoder models / 2D bucketing """ - f = TokenCountFilter(1, 5, use_total_length=False) + f = TokenCountFilter(1, 5, measure_total_length=False) assert f(example) == True - f = TokenCountFilter(1, 2, use_total_length=False) + f = TokenCountFilter(1, 2, measure_total_length=False) assert f(example) == False - f = TokenCountFilter(10, 30, use_total_length=False) + f = TokenCountFilter(10, 30, measure_total_length=False) assert f(example) == False """ Total length measurement / decoder-only models / 1D bucketing """ - f = TokenCountFilter(1, 5, use_total_length=True) + f = TokenCountFilter(1, 5, measure_total_length=True) assert f(example) == False - f = TokenCountFilter(1, 20, use_total_length=True) + f = TokenCountFilter(1, 20, measure_total_length=True) assert f(example) == True - f = TokenCountFilter(1, 3, use_total_length=True) + f = TokenCountFilter(1, 3, measure_total_length=True) assert f(example) == False - f = TokenCountFilter(10, 30, use_total_length=True) + f = TokenCountFilter(10, 30, measure_total_length=True) assert f(example) == False From 69bb7e18655a537e03ea465a0da7b14035194c67 Mon Sep 17 00:00:00 2001 From: pzelasko Date: Thu, 17 Oct 2024 15:52:49 +0000 Subject: [PATCH 52/63] Apply isort and black reformatting Signed-off-by: pzelasko --- scripts/speech_llm/oomptimizer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/speech_llm/oomptimizer.py b/scripts/speech_llm/oomptimizer.py index 3761099b614d..63afbe743364 100755 --- a/scripts/speech_llm/oomptimizer.py +++ b/scripts/speech_llm/oomptimizer.py @@ -391,7 +391,10 @@ def oomptimizer( from megatron.core.parallel_state import initialize_model_parallel from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo - initialize_model_parallel_for_nemo(world_size=1, global_rank=0, local_rank=0, micro_batch_size=16, global_batch_size=16) + + initialize_model_parallel_for_nemo( + world_size=1, global_rank=0, local_rank=0, micro_batch_size=16, global_batch_size=16 + ) torch.distributed.init_process_group("nccl", world_size=1, rank=0) initialize_model_parallel() @@ -517,7 +520,7 @@ def step(): # but we have found out empirically that this causes a mismatched condition # between OOMptimizer and the actual training. During training, there is some # degree of memory fragmentation and it's better to simulate that in OOMptimizer. - #torch.cuda.memory.empty_cache() + # torch.cuda.memory.empty_cache() torch.cuda.reset_max_memory_allocated() return oom From bc7dcde6560f7cff3a191aca8eb4143940d28ce0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 17 Oct 2024 10:17:00 -0700 Subject: [PATCH 53/63] Add missing emmett tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- tests/collections/multimodal/test_emmett.py | 239 ++++++++++++++++++++ 1 file changed, 239 insertions(+) create mode 100644 tests/collections/multimodal/test_emmett.py diff --git a/tests/collections/multimodal/test_emmett.py b/tests/collections/multimodal/test_emmett.py new file mode 100644 index 000000000000..553343b8a711 --- /dev/null +++ b/tests/collections/multimodal/test_emmett.py @@ -0,0 +1,239 @@ +import pytest +import torch +from lhotse import CutSet, MonoCut, SupervisionSegment +from lhotse.testing.dummies import dummy_recording +from omegaconf import OmegaConf + +from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config +from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample, TextExample +from nemo.collections.common.tokenizers import SentencePieceTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import create_spt_model +from nemo.collections.multimodal.speech_llm.data.lhotse_dataset import LhotseAudioQuestionAnswerDataset +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import PromptFormatterTextProcessing + + +class Identity(torch.utils.data.Dataset): + def __getitem__(self, cuts): + return cuts + + +@pytest.fixture +def tokenizer(capsys, tmp_path_factory): + TOKENIZER_TRAIN_TEXT = """ + Example system message. + Example user message. + Example assistant message. + TEST + [INST] + [/INST] + + + <> + <> + User: Assistant: + user model + Instruct Output + \n\n + + <| + |> + <|en|> <|de|> <|fr|> <|es|> <|transcribe|> <|translate|> <|pnc|> <|nopnc|> <|startoftranscript|> <|endoftext|> + Feel free to add new tokens for your own tests!? + But know that if you do so, you may need to update the token IDs in the existing tests! + So, it might be a good idea to create a new tokenizer instead when adding new prompt formats. + """ + tmpdir = tmp_path_factory.mktemp("bpe_tokenizer") + text_path = tmpdir / "text.txt" + text_path.write_text(TOKENIZER_TRAIN_TEXT) + with capsys.disabled(): + create_spt_model(str(text_path), vocab_size=512, sample_size=-1, do_lower_case=False, output_dir=str(tmpdir)) + return SentencePieceTokenizer(str(tmpdir / "tokenizer.model")) + + +""" +TEST FOR AUDIO DATALOADING WITH EMMETT +""" + + +@pytest.fixture +def cuts(): + return CutSet( + [ + MonoCut( + id="ex0", + start=0, + duration=5.0, + channel=0, + supervisions=[ + SupervisionSegment( + id="ex0", + recording_id="dummy-recording-0000", + start=0, + duration=5.0, + text="some transcription", + language="en", + ) + ], + recording=dummy_recording(0, duration=5.0, with_data=True), + custom={ + "context": "", + "answer": "some desired answer", + }, + ), + ] + ) + + +@pytest.fixture +def cuts_path(tmp_path_factory, cuts): + tmp_path = tmp_path_factory.mktemp("data") + p = tmp_path / "cuts.jsonl.gz" + pa = tmp_path / "audio" + cuts.save_audios(pa).to_file(p) + return p + + +def test_audio_example_with_prompt_emmett_t5(cuts_path, tokenizer): + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "lhotse", + "cuts_path": cuts_path, + }, + ], + "prompt_format": "t5nmt", + "force_finite": True, + "shuffle": True, + "num_workers": 0, + "batch_size": 1, + "seed": 0, + "shard_seed": 0, + } + ) + + # First test that sampling is correct and tokenizer + prompt formatter is applied there + + dl = get_lhotse_dataloader_from_config( + config=config, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer + ) + batches = [batch for batch in dl] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, CutSet) + assert len(b) == 1 + ex = b[0] + assert isinstance(ex, MonoCut) + + assert ex.has_custom("context_ids") + assert torch.is_tensor(ex.context_ids) + assert tokenizer.ids_to_text(ex.context_ids) == "" + + assert ex.has_custom("answer_ids") + assert torch.is_tensor(ex.answer_ids) + assert tokenizer.ids_to_text(ex.answer_ids) == "some transcription" + + assert ex.has_custom("input_ids") + assert torch.is_tensor(ex.input_ids) + assert tokenizer.ids_to_text(ex.input_ids) == " some transcription" + + # Test that speechlm dataset processes the example correctly + + text_processor = PromptFormatterTextProcessing(tokenizer=tokenizer, prompt_format="t5nmt") + dataset = LhotseAudioQuestionAnswerDataset( + text_processor=text_processor, + default_context="", + tokens_to_generate=0, + pad_to_max_length=False, + max_seq_length=64, + ) + + batch = dataset[batches[0]] + assert tokenizer.ids_to_text(batch["tokens"][0]) == " some transcriptio" + assert tokenizer.ids_to_text(batch["labels"][0]) == "en> some transcription" + assert tokenizer.ids_to_text(batch["contexts"][0]) == "" + assert tokenizer.ids_to_text(batch["answers"][0]) == "some transcription" + + +""" +TEST FOR TEXT DATALOADING WITH EMMETT +""" + + +@pytest.fixture +def nmt_paths(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("nmtdata") + src = tmp_path / "src.txt" + tgt = tmp_path / "tgt.txt" + q = tmp_path / "q.txt" + src.write_text("fake german") + tgt.write_text("real english") + q.write_text("") + return src, tgt, q + + +def test_text_example_with_prompt_emmett_t5(nmt_paths, tokenizer): + src, tgt, q = nmt_paths + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "txt_pair", + "source_paths": src, + "target_paths": tgt, + "source_language": "de", + "target_language": "en", + "questions_path": q, + "questions_language": "en", + }, + ], + "prompt_format": "t5nmt", + "force_finite": True, + "shuffle": True, + "num_workers": 0, + "batch_size": 1, + "seed": 0, + "shard_seed": 0, + } + ) + + # First test that sampling is correct and tokenizer + prompt formatter is applied there + + dl = get_lhotse_dataloader_from_config( + config=config, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer + ) + batches = [batch for batch in dl] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, CutSet) + assert len(b) == 1 + ex = b[0] + assert isinstance(ex, SourceTargetTextExample) + + assert torch.is_tensor(ex.context_ids) + assert tokenizer.ids_to_text(ex.context_ids) == " fake german" + + assert torch.is_tensor(ex.answer_ids) + assert tokenizer.ids_to_text(ex.answer_ids) == "real english" + + assert torch.is_tensor(ex.input_ids) + assert tokenizer.ids_to_text(ex.input_ids) == " fake german real english" + + # Test that speechlm dataset processes the example correctly + + text_processor = PromptFormatterTextProcessing(tokenizer=tokenizer, prompt_format="t5nmt") + dataset = LhotseAudioQuestionAnswerDataset( + text_processor=text_processor, + default_context="", + tokens_to_generate=0, + pad_to_max_length=False, + max_seq_length=64, + ) + + batch = dataset[batches[0]] + + assert tokenizer.ids_to_text(batch["text_input_ids"][0]) == " fake german real english" + assert tokenizer.ids_to_text(batch["text_context_ids"][0]) == " fake german" + assert tokenizer.ids_to_text(batch["text_answer_ids"][0]) == "real english" From fb29842fc8c094589f141ab5832d9d35e8a66ad7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 18 Oct 2024 14:13:21 -0400 Subject: [PATCH 54/63] Add tutorial about multimodal lhotse dataloading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/cutset.py | 4 +- .../common/data/lhotse/dataloader.py | 2 - nemo/collections/common/data/prompt_fn.py | 15 +- .../Multimodal Lhotse Dataloading.ipynb | 1015 +++++++++++++++++ 4 files changed, 1025 insertions(+), 11 deletions(-) create mode 100644 tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 80382c6dcb14..65027e366fbe 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -175,8 +175,8 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: tgt_lang: en """ propagate_attrs = { - "shuffle": config.shuffle, - "shard_seed": config.shard_seed, + "shuffle": config.get("shuffle", False), + "shard_seed": config.get("shard_seed", "trng"), "text_field": config.get("text_field", "text"), "lang_field": config.get("lang_field", "lang"), "metadata_only": config.get("metadata_only", False), diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index c74cec62c269..61451ef09e1f 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -313,8 +313,6 @@ def get_lhotse_dataloader_from_multi_config( The first config is treated as a "main" config that determines the RNG, CUDA allocator, and sampler fusion settings. """ - logging.info(f"We will be using a multi config Lhotse DataLoader with groups: {list(configs.keys())}.") - configs = [make_structured_with_schema_warnings(c) for c in configs.values() if isinstance(c, DictConfig)] main_config = configs[0] maybe_set_cuda_expandable_segments(enabled=main_config.cuda_expandable_segments) diff --git a/nemo/collections/common/data/prompt_fn.py b/nemo/collections/common/data/prompt_fn.py index e55610ee7fa8..bd1e45ea92e2 100644 --- a/nemo/collections/common/data/prompt_fn.py +++ b/nemo/collections/common/data/prompt_fn.py @@ -51,16 +51,17 @@ def get_prompt_format_fn(example: Type | object, prompt: Type | object = None) - prompt = type(prompt) # For the example type, first try to match it directly, then fall back to its parent classes. - for subtype in example.mro(): + for example_subtype in example.mro(): - # First check the match for specific example type and a specific prompt format. - if (subtype, prompt) in PROMPT_FORMAT_FNS: - return PROMPT_FORMAT_FNS[(subtype, prompt)] + # First check the match for specific example type and a specific prompt format, + # and all parent types of that specific prompt formatter type. + for prompt_subtype in prompt.mro(): + if (example_subtype, prompt_subtype) in PROMPT_FORMAT_FNS: + return PROMPT_FORMAT_FNS[(example_subtype, prompt_subtype)] # Then for the same specific example type, fall back to its default prompt formatter implementation. - # Note: the data example type takes precedence over the prompt formatter type for this resolution. - if subtype in PROMPT_FORMAT_FNS: - return PROMPT_FORMAT_FNS[subtype] + if example_subtype in PROMPT_FORMAT_FNS: + return PROMPT_FORMAT_FNS[example_subtype] raise ValueError( f"Unknown prompt format function for ({example}, {prompt}). " diff --git a/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb b/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb new file mode 100644 index 000000000000..79104f21a3ba --- /dev/null +++ b/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb @@ -0,0 +1,1015 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e930b0c5f0cffbce", + "metadata": {}, + "source": [ + "# Multimodal Lhotse Dataloading\n", + "\n", + "This tutorial explains how NeMo uses Lhotse for multimodal dataloading.\n", + "The modalities supported as of the time of writing are audio and text.\n", + "The intended audience of this tutorial are NeMo developers and persons who build/modify NeMo models.\n", + "After finishing this tutorial, you should have an understanding how to use various Lhotse building blocks in NeMo for designing the kind of model you want.\n", + "\n", + "We cover the following topics:\n", + "* What are data types?\n", + "* What data types are availabe in NeMo?\n", + "* How do we read them from files?\n", + "* How to apply prompt formatting to various data types?\n", + "* How to create tensors for training with these examples?\n", + "* How to optimize the training by stratifying data sampling on sequence lengths, and how these lengths are measured for different examples and models. \n", + "* How to train on multiple data types together?" + ] + }, + { + "cell_type": "markdown", + "id": "72bd180c65992eba", + "metadata": {}, + "source": [ + "## Data types\n", + "\n", + "A data type represents examples of your training data: speech recordings, text sentences, text sentence pairs, conversations, etc.\n", + "\n", + "A data type consists of:\n", + "* a class that represents a single sample\n", + " * includes properties allowing sequence length measurement for sampling purposes\n", + "* a parser class that's initialized with a config (e.g. paths to data) and acts as an iterator of examples\n", + "* extension functions that define how to apply prompt formatting to a given data type\n", + "\n", + "NeMo uses Lhotse Cuts as a basic data type for audio, and defines several data types for text. We'll go over them below.\n", + "\n", + "External references:\n", + "* [Lhotse documentation](https://lhotse.readthedocs.io/en/latest/getting-started.html)\n", + "* [Lhotse in NeMo documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/datasets.html#lhotse-dataloading)" + ] + }, + { + "cell_type": "markdown", + "id": "cf32bf3ea5a9cb17", + "metadata": {}, + "source": [ + "### Audio examples (Lhotse cuts)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d2d747f6b32d5942", + "metadata": { + "jupyter": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from lhotse import MonoCut, Recording, SupervisionSegment, AudioSource\n", + "from lhotse.testing.dummies import dummy_cut\n", + "\n", + "\n", + "# A basic audio example: recording with transcription\n", + "cut = MonoCut(\n", + " id=\"utt-0\",\n", + " start=0.0,\n", + " duration=10.0,\n", + " channel=0,\n", + " supervisions=[SupervisionSegment(id=\"utt-0\", recording_id=\"rec-0\", start=0.0, duration=10.0, text=\"Welcome to Lhotse!\")],\n", + " recording=Recording(\n", + " id=\"rec-0\",\n", + " sources=[AudioSource(type=\"file\", channels=[0], source=\"/path/to/recording.wav\")],\n", + " sampling_rate=16000,\n", + " duration=10.0,\n", + " num_samples=160000,\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9b121afd920bdab2", + "metadata": {}, + "source": [ + "## Single text examples " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "41b0c148e0d7ac1c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TextExample(text='This is a single sentence, which may be used in language modeling.', language='en', tokens=None, custom=None)\n" + ] + } + ], + "source": [ + "from nemo.collections.common.data.lhotse.text_adapters import TextExample\n", + "\n", + "# A basic text example: single line of text.\n", + "text = TextExample(\n", + " text=\"This is a single sentence, which may be used in language modeling.\",\n", + " language=\"en\"\n", + ")\n", + "print(text)" + ] + }, + { + "cell_type": "markdown", + "id": "2abb821b69f71a91", + "metadata": {}, + "source": [ + "## Pairs of text examples" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "282560cc3df9174a", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample\n", + "\n", + "# A pair of text examples, usable e.g. in machine translation.\n", + "text_pair = SourceTargetTextExample(\n", + " source=TextExample(\n", + " text=\"Some machine translation example.\",\n", + " language=\"en\",\n", + " ),\n", + " target=TextExample(\n", + " text=\"Algunos ejemplos de traducción automática.\",\n", + " language=\"es\",\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "858d6cb6abb1ccd6", + "metadata": {}, + "source": [ + "## Conversations: text, audio, and multimodal" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e5bd8caee40100b1", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.common.data.lhotse.text_adapters import NeMoMultimodalConversation, TextTurn, AudioTurn\n", + "\n", + "# A text-only conversation, useful for chat LLM training.\n", + "text_conversation = NeMoMultimodalConversation(\n", + " id=\"convo-text-0\",\n", + " turns=[\n", + " TextTurn(value=\"Is this a text-only conversation?\", role=\"user\"),\n", + " TextTurn(value=\"Yes, but we can do more than that.\", role=\"assistant\"),\n", + " TextTurn(value=\"Tell me more.\", role=\"user\"),\n", + " TextTurn(value=\"Of course! Let's move on to the next example.\", role=\"assistant\"),\n", + " ]\n", + ")\n", + "\n", + "# An audio-only conversation, useful for chat speech LLM training.\n", + "# We'll explain [audio] tag and token_equivalent_duration later in this tutorial.\n", + "audio_conversation = NeMoMultimodalConversation(\n", + " id=\"convo-audio-0\",\n", + " turns=[\n", + " AudioTurn(cut=dummy_cut(0, duration=7.18, with_data=True), role=\"user\", audio_locator_tag=\"[audio]\"),\n", + " AudioTurn(cut=dummy_cut(0, duration=21.64, with_data=True), role=\"assistant\", audio_locator_tag=\"[audio]\"),\n", + " ],\n", + " token_equivalent_duration=0.08,\n", + ")\n", + "\n", + "# A multimodal conversation.\n", + "multimodal_conversation = NeMoMultimodalConversation(\n", + " id=\"convo-multimodal-0\",\n", + " turns=[\n", + " TextTurn(value=\"Is this a text-only conversation?\", role=\"user\"),\n", + " TextTurn(value=\"No, feel free to speak to me.\", role=\"assistant\"),\n", + " AudioTurn(cut=dummy_cut(0, duration=5.87, with_data=True), role=\"user\", audio_locator_tag=\"[audio]\"),\n", + " TextTurn(value=\"Should I respond in voice too?\", role=\"assistant\"),\n", + " TextTurn(value=\"Yes\", role=\"user\"),\n", + " TextTurn(value=\"Certainly!\", role=\"assistant\"),\n", + " AudioTurn(cut=dummy_cut(0, duration=14.62, with_data=True), role=\"assistant\", audio_locator_tag=\"[audio]\"),\n", + " ],\n", + " token_equivalent_duration=0.08,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b21e0e5e84904d89", + "metadata": {}, + "source": [ + "As you can see, these data structures serve as a complete description of training examples of different types, \n", + "as they contain both the data (audio) and various metadata." + ] + }, + { + "cell_type": "markdown", + "id": "9198210580be10bf", + "metadata": {}, + "source": [ + "## Parsing data types from files\n", + "\n", + "Related: for an overview of NeMo data configuration format, please see these docs: \n", + "* [Extended multi-dataset configuration format](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/datasets.html#extended-multi-dataset-configuration-format)\n", + "* [Configuring multi-modal dataloading](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/datasets.html#configuring-multi-modal-dataloading)\n", + "\n", + "The goal of data type parser is to read a configuration specifying where the data is located / how to read it,\n", + "create an iterable over the corresponding data type, and wrap it into a Lhotse CutSet.\n", + "\n", + "Adding support for a new data type parser requires two components:\n", + "* An adapter/iterator class dedicated to your data type.\n", + "* A function that instantiates this adapter/iterator, registered with a `@data_type_parser(\"name\")` decorator to make it auto-detectable by NeMo.\n", + "\n", + "We'll take a deeper look at how source-target text example pairs parsing is implemented. We'll implement a custom parser for `SourceTargetTextExample` that reads them from JSON files." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f0e35b53c7ac77b4", + "metadata": {}, + "outputs": [], + "source": [ + "from lhotse.serialization import load_jsonl\n", + "import random\n", + "from typing import Literal, Iterator\n", + "from dataclasses import dataclass\n", + "\n", + "from lhotse import CutSet\n", + "from lhotse.dataset.dataloading import resolve_seed\n", + "from omegaconf import DictConfig\n", + "from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths\n", + "from nemo.collections.common.data.lhotse.cutset import data_type_parser\n", + "\n", + "\n", + "@dataclass\n", + "class LhotseTextPairAdapterFromJsonl:\n", + " manifest_path: str | list[str]\n", + " shuffle_shards: bool = False\n", + " shard_seed: int | Literal[\"trng\", \"randomized\"] = \"trng\"\n", + "\n", + " def __post_init__(self):\n", + " self.manifest_path = expand_sharded_filepaths(self.manifest_path)\n", + "\n", + " def __iter__(self) -> Iterator[SourceTargetTextExample]:\n", + " seed = resolve_seed(self.shard_seed)\n", + " rng = random.Random(seed)\n", + " paths = self.manifest_path\n", + " if self.shuffle_shards:\n", + " rng.shuffle(paths)\n", + " for p in paths:\n", + " for item in load_jsonl(p):\n", + " yield SourceTargetTextExample(\n", + " source=TextExample(item[\"source\"], item.get(\"source_lang\")),\n", + " target=TextExample(item[\"target\"], item.get(\"target_lang\")),\n", + " question=(\n", + " TextExample(item[\"prompt\"], language=item(\"prompt_lang\"))\n", + " if \"prompt\" in item\n", + " else None\n", + " ),\n", + " )\n", + "\n", + "\n", + "@data_type_parser(\"txt_pair_jsonl\")\n", + "def read_txt_pair_paths(config: DictConfig) -> tuple[CutSet, bool]:\n", + " cuts = CutSet(\n", + " LhotseTextPairAdapterFromJsonl(\n", + " manifest_path=config.manifest_path,\n", + " shuffle_shards=config.shuffle,\n", + " shard_seed=config.shard_seed,\n", + " )\n", + " )\n", + " if not config.get(\"force_finite\", False):\n", + " cuts = cuts.repeat()\n", + " return cuts, True" + ] + }, + { + "cell_type": "markdown", + "id": "64367e6596754ee6", + "metadata": {}, + "source": [ + "Note that there is a bit of boilerplate (`expand_sharded_filepaths`, `force_finite`, `shuffle_shards`, `shard_seed`) - we might reduce the amount of necessary boilerplate in the future, but for now it is required.\n", + "\n", + "Let's test that it works. We'll first create two JSONL files (shards) with one entry each, and later use NeMo's path expansion mechanism to provide them as the input configuration.\n", + "\n", + "Then, we'll read it using the high-level API `read_cutset_from_config` that's actually used by NeMo+Lhotse dataloader to show that the auto-registration mechanism works as expected." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7987fce8db39b008", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-10-18 14:12:16 nemo_logging:349] /Users/pzelasko/miniforge3/envs/nemo/lib/python3.10/site-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n", + " warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom=None)\n" + ] + } + ], + "source": [ + "!echo '{\"source\": \"A\", \"target\": \"B\"}' >> _tutorial_nmt_0.jsonl\n", + "!echo '{\"source\": \"C\", \"target\": \"D\"}' >> _tutorial_nmt_1.jsonl\n", + "\n", + "from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config\n", + "\n", + "data, use_iterable_dataset = read_cutset_from_config(\n", + " {\n", + " \"input_cfg\": [\n", + " {\n", + " \"type\": \"txt_pair_jsonl\", \n", + " \"manifest_path\": \"_tutorial_nmt__OP_0..1_CL_.jsonl\", \n", + " }\n", + " ]\n", + " }\n", + ")\n", + "\n", + "example = next(iter(data))\n", + "assert isinstance(example, SourceTargetTextExample)\n", + "assert example.source.text == \"A\"\n", + "assert example.target.text == \"B\"\n", + "print(example)" + ] + }, + { + "cell_type": "markdown", + "id": "be48872625d1a2e0", + "metadata": {}, + "source": [ + "## Prompt formatting and conversion of data types to tensors\n", + "\n", + "Since we now understand how data types are read, let's see how to convert them to actual training examples.\n", + "Because this tutorial is focused on multimodal LLM / speech LLM training, we'll be using prompt templates adequate for various LLMs to prepare the training data. In this example, we'll use Llama2 prompt template to format each data type.\n", + "\n", + " We'll need to initialize a prompt formatter and a tokenizer; we'll just train a dummy BPE tokenizer for the purpose of the tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6e1d296be0d363d", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-10-18 14:12:19 sentencepiece_tokenizer:333] tokenizer model _tutorial_spt/tokenizer.model already exists\n" + ] + } + ], + "source": [ + "import string\n", + "import shlex\n", + "from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model\n", + "from nemo.collections.common.prompts.formatter import PromptFormatter\n", + "\n", + "!echo {shlex.quote(' '.join(string.printable))} > _tutorial_train_text.txt\n", + "\n", + "tok_path, vocab_path = create_spt_model(\n", + " data_file=\"_tutorial_train_text.txt\", \n", + " output_dir=\"_tutorial_spt\",\n", + " vocab_size=512, \n", + " sample_size=-1, \n", + " do_lower_case=False, \n", + " bos=True, \n", + " eos=True, \n", + " pad=True, \n", + " user_defined_symbols=[\"[INST]\", \"[/INST]\", \"<>\", \"<>\", \"[audio]\"]\n", + ")\n", + "\n", + "tokenizer = SentencePieceTokenizer(tok_path)\n", + "prompt = PromptFormatter.resolve(\"llama2\")(tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "6988777c9dc1653b", + "metadata": {}, + "source": [ + "Now, we'll convert the data types to a training/inference friendly format. Specifically, we want to have 4 tensors:\n", + "* `context_ids`: token IDs that serve as the input for LLM (e.g. user query, conversation history, etc.)\n", + "* `answer_ids`: token IDs that serve as the answer for LLM (assistant response)\n", + "* `input_ids`: concatenated `context_ids` and `answer_ids`\n", + "* `mask`: loss mask that's only set to `True` for each token belonging to each of assistant's turns. Same length as `input_ids`.\n", + "\n", + "Let's first go through Cut, SourceTargetTextExample, and NeMoMultimodalConversation to see what happens with them." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5f8c0a54189e443d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cut:\n", + "\t* input_ids [INST] Repeat after me: [/INST] Welcome to Lhotse!\n", + "\t* context_ids [INST] Repeat after me: [/INST]\n", + "\t* answer_ids Welcome to Lhotse!\n", + "loss mask tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True])\n", + "\n", + "SourceTargetTextExample:\n", + "\t* input_ids [INST] Some machine translation example. [/INST] Algunos ejemplos de traducci ⁇ n autom ⁇ tica.\n", + "\t* context_ids [INST] Some machine translation example. [/INST]\n", + "\t* answer_ids Algunos ejemplos de traducci ⁇ n autom ⁇ tica.\n", + "loss mask tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True])\n", + "\n", + "NeMoMultimodalConversation:\n", + "\t* input_ids [INST] Is this a text-only conversation? [/INST] No, feel free to speak to me. [INST] [audio] [/INST] Should I respond in voice too? [INST] Yes [/INST] Certainly! [audio]\n", + "\t* context_ids [INST] Is this a text-only conversation? [/INST] No, feel free to speak to me. [INST] [audio] [/INST] Should I respond in voice too? [INST] Yes [/INST]\n", + "\t* answer_ids Certainly! [audio]\n", + "loss mask tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " False, False, False, False, False, False, False, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, False,\n", + " False, False, False, False, False, False, False, False, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True])\n", + "\n" + ] + } + ], + "source": [ + "from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn\n", + "\n", + "cut.context = \"Repeat after me:\"\n", + "print(\"Cut:\")\n", + "formatted = apply_prompt_format_fn(cut, prompt)\n", + "for name in [\"input_ids\", \"context_ids\", \"answer_ids\"]:\n", + " print(\"\\t*\", name, tokenizer.ids_to_text(formatted[name]))\n", + "print(\"loss mask\", formatted[\"mask\"])\n", + "print()\n", + "\n", + "print(\"SourceTargetTextExample:\")\n", + "formatted = apply_prompt_format_fn(text_pair, prompt)\n", + "for name in [\"input_ids\", \"context_ids\", \"answer_ids\"]:\n", + " print(\"\\t*\", name, tokenizer.ids_to_text(formatted[name]))\n", + "print(\"loss mask\", formatted[\"mask\"])\n", + "print()\n", + "\n", + "print(\"NeMoMultimodalConversation:\")\n", + "formatted = apply_prompt_format_fn(multimodal_conversation, prompt)\n", + "for name in [\"input_ids\", \"context_ids\", \"answer_ids\"]:\n", + " print(\"\\t*\", name, tokenizer.ids_to_text(formatted[name]))\n", + "print(\"loss mask\", formatted[\"mask\"])\n", + "print()" + ] + }, + { + "cell_type": "markdown", + "id": "e1b50937e5f75d10", + "metadata": {}, + "source": [ + "Note how each example got converted into the same prompt format. \n", + "\n", + "For multimodal conversation we have a special mechanism that replaces audio turns with an `audio_locator_tag`. \n", + "We expect that the tokenizer contains this tag as a special token.\n", + "The user will later replace these special tokens with audio representations (tokenized, or not) in the training step of the model. \n", + "\n", + "If you create a new prompt format, or a new data type, or want to specialize how a given data type is formatted with a given prompt, it is easily customizable by defining a single function with `@registered_prompt_format_fn(DataType, PromptFormatterType)` decorator. For example, if we created a new data type called `TextTriplet`, and added a default prompt format function, and another one specialized for Llama2:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "108b3593a5f16444", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids tensor([ 1, 9, 4, 9, 6, 9, 42, 9, 7, 9, 43, 9, 5, 9, 44, 2])\n", + "context_ids tensor([ 1, 9, 4, 9, 6, 9, 42, 9, 7, 9, 43, 9, 5])\n", + "answer_ids tensor([ 9, 44, 2])\n", + "mask tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, True, True, True])\n" + ] + } + ], + "source": [ + "from nemo.collections.common.prompts import Llama2PromptFormatter\n", + "from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn\n", + "from nemo.collections.common.data.lhotse.text_adapters import Formattable, CustomFieldMixin\n", + "\n", + "\n", + "@dataclass\n", + "class TextTriplet(Formattable, CustomFieldMixin):\n", + " # Note: we will explain Formattable and CustomFieldMixin in the next sections.\n", + " text1: str\n", + " text2: str\n", + " text3: str\n", + "\n", + "\n", + "@registered_prompt_format_fn(TextTriplet)\n", + "def text_triplets_generic(example: TextTriplet, prompt: PromptFormatter):\n", + " return prompt.encode_dialog(turns=[\n", + " {\"role\": \"user\", \"slots\": {\"message\": f\"{example.text1} {example.text2}\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": f\"{example.text3}\"}},\n", + " ])\n", + "\n", + " \n", + "@registered_prompt_format_fn(TextTriplet, Llama2PromptFormatter)\n", + "def text_triplets_llama2(example: TextTriplet, prompt: Llama2PromptFormatter):\n", + " return prompt.encode_dialog(turns=[\n", + " {\"role\": \"system_and_user\", \"slots\": {\"system\": example.text1 , \"message\": example.text2}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": example.text3}},\n", + " ])\n", + "\n", + "\n", + "formatted = apply_prompt_format_fn(TextTriplet(\"A\", \"B\", \"C\"), prompt)\n", + "for k, v in formatted.items():\n", + " print(k, v)" + ] + }, + { + "cell_type": "markdown", + "id": "9565bef14a863465", + "metadata": {}, + "source": [ + "If we also created a data type parser for `TextTriplet` like we did for `SourceTargetTextExample` in the section before, we have a complete new data type support for dataloading. " + ] + }, + { + "cell_type": "markdown", + "id": "6ac39c8fcbcf5860", + "metadata": {}, + "source": [ + "## Support for sequence length stratification / dynamic bucketing\n", + "\n", + "References: \n", + "* [EMMeTT: Efficient Multimodal Machine Translation Training](https://arxiv.org/abs/2409.13523) \n", + "\n", + "We found that by using dynamic bucketing with [OOMptimizer](https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#pushing-gpu-utilization-to-the-limits-with-bucketing-and-oomptimizer) can significantly accelerate multimodal LLM training. \n", + "In order to ensure that all data types can benefit from this acceleration, we introduced the `Formattable` concept.\n", + "It indicates that a given data type supports prompt formatting and provides properties to measure input and output sequence length.\n", + "\n", + "Let's see this in action with the previously formatted data types:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f5ca38ea137f8210", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SourceTargetTextPair:\n", + "\t* input_length 39\n", + "\t* output_length 44\n", + "\t* total_length 83\n", + "\t* len(context_ids) 39\n", + "\t* len(answer_ids) 44\n", + "\t* len(input_ids) 83\n", + "NeMoMultimodalConversation\n", + "\t* input_length 191\n", + "\t* output_length 196\n", + "\t* total_length 387\n", + "\t* len(context_ids) 118\n", + "\t* len(answer_ids) 14\n", + "\t* len(input_ids) 132\n" + ] + } + ], + "source": [ + "print(\"SourceTargetTextPair:\")\n", + "text_pair = text_pair.apply_prompt_format(prompt)\n", + "print(\"\\t*\", \"input_length\", text_pair.input_length)\n", + "print(\"\\t*\", \"output_length\", text_pair.output_length)\n", + "print(\"\\t*\", \"total_length\", text_pair.total_length)\n", + "print(\"\\t*\", \"len(context_ids)\", len(text_pair.context_ids))\n", + "print(\"\\t*\", \"len(answer_ids)\", len(text_pair.answer_ids))\n", + "print(\"\\t*\", \"len(input_ids)\", len(text_pair.input_ids))\n", + "\n", + "print(\"NeMoMultimodalConversation\")\n", + "text_pair = multimodal_conversation.apply_prompt_format(prompt)\n", + "print(\"\\t*\", \"input_length\", multimodal_conversation.input_length)\n", + "print(\"\\t*\", \"output_length\", multimodal_conversation.output_length)\n", + "print(\"\\t*\", \"total_length\", multimodal_conversation.total_length)\n", + "print(\"\\t*\", \"len(context_ids)\", len(multimodal_conversation.context_ids))\n", + "print(\"\\t*\", \"len(answer_ids)\", len(multimodal_conversation.answer_ids))\n", + "print(\"\\t*\", \"len(input_ids)\", len(multimodal_conversation.input_ids))\n" + ] + }, + { + "cell_type": "markdown", + "id": "ecca372c2a0cad6e", + "metadata": {}, + "source": [ + "Note that for `NeMoMultimodalConversation` the length is much greater that the number of text tokens. \n", + "This is where `token_equivalent_duration` comes in: we want to factor in the audio turns into sequence lengths.\n", + "Since we know what is the duration of audio, we only need to know how much duration should be covered by each audio \"token\" or \"frame\".\n", + "A typical setup would be with NeMo FastConformer as an audio encoder, which uses 10ms frames at the input and subsamples them by a factor of 8 in the output. \n", + "The resulting `token_equivalent_duration` is therefore `0.08`, i.e., a single token created from audio is worth 80ms of duration. \n", + "For length computation, we sum the number of text tokens and the equivalent number of audio tokens.\n", + "\n", + "We can see that lhotse's `DynamicBucketingSampler` is able to process this data using NeMo multimodal sampling strategies:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6e295cfbfe8ff69b", + "metadata": {}, + "outputs": [], + "source": [ + "from lhotse.dataset import DynamicBucketingSampler\n", + "from nemo.collections.common.data.lhotse.sampling import MultimodalFixedBucketBatchSizeConstraint2D\n", + "\n", + "cuts = CutSet([multimodal_conversation]).repeat() # repeat makes iterable infinite\n", + "sampler = DynamicBucketingSampler(\n", + " cuts, \n", + " constraint=MultimodalFixedBucketBatchSizeConstraint2D(\n", + " max_seq_len_buckets=[32, 64, 128, 256, 512, 1024, 1536, 2048],\n", + " batch_sizes=[8, 7, 6, 5, 4, 3, 2, 1],\n", + " token_equivalent_duration=0.08, \n", + " measure_total_length=True,\n", + " ),\n", + " buffer_size=10,\n", + ")\n", + "\n", + "batch = next(iter(sampler))\n", + "assert len(batch) == 4 \n", + "# Our conversation example fell into bucket number 4 (min: 256, max: 512) with an assigned batch size of 4" + ] + }, + { + "cell_type": "markdown", + "id": "4ff5baae-0771-4ac9-aa68-c3faee5aa261", + "metadata": {}, + "source": [ + "## Putting it all together to configure joint audio, text, and conversation dataloading\n", + "\n", + "We'll showcase some higher level APIs here. First, we'll create data examples on disk for three distinct types: audio to text, text to text, and multimodal conversations." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5a0e5433-3e63-4ab2-9290-001159a9b8e0", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from lhotse.serialization import save_to_jsonl\n", + "from lhotse.testing.dummies import dummy_recording\n", + "\n", + "# Prepare dummy ASR data\n", + "d = Path(\"_tutorial_data\")\n", + "!mkdir -p {d}/asr_shar\n", + "cut = dummy_recording(0, duration=17.11, with_data=True).to_cut()\n", + "cut.supervisions = [SupervisionSegment(id=cut.id, recording_id=cut.id, start=0.0, duration=cut.duration, text=\"Welcome to Lhotse!\")]\n", + "cut.context = \"Repeat after me\"\n", + "CutSet([cut.save_audio(d / \"rec.flac\")]).to_shar(d / \"asr_shar\", fields={\"recording\": \"flac\"})\n", + "\n", + "# Prepare dummy translation data\n", + "(d / \"src.txt\").write_text(\"A\")\n", + "(d / \"tgt.txt\").write_text(\"B\")\n", + "\n", + "# Prepare dummy multimodal conversation\n", + "save_to_jsonl(\n", + " [\n", + " {\n", + " \"id\": \"convo-1\",\n", + " \"conversations\": [\n", + " {\"from\": \"user\", \"value\": \"tell me what you hear\", \"type\": \"text\"},\n", + " {\"from\": \"user\", \"value\": str(d / \"rec.flac\"), \"duration\": cut.duration, \"type\": \"audio\"},\n", + " {\"from\": \"assistant\", \"value\": \"somebody just welcomed me to a himalayan mountain\", \"type\": \"text\"},\n", + " ]\n", + " }\n", + " ],\n", + " d / \"conv.jsonl\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3a4d669b-f816-4522-a491-ba31bfbf689c", + "metadata": {}, + "source": [ + "Now we'll configure a Lhotse dataloader to yield mini-batches with different data types in a round-robin fashion." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c4a7364e-c00f-4f60-9d72-9e7d228121cb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-10-18 14:12:19 dataloader:481] Creating a Lhotse DynamicBucketingSampler (max_batch_duration=None max_batch_size=None)\n", + "[NeMo I 2024-10-18 14:12:19 dataloader:481] Creating a Lhotse DynamicBucketingSampler (max_batch_duration=None max_batch_size=None)\n", + "[NeMo I 2024-10-18 14:12:19 dataloader:481] Creating a Lhotse DynamicBucketingSampler (max_batch_duration=None max_batch_size=None)\n" + ] + } + ], + "source": [ + "import torch\n", + "from omegaconf import OmegaConf\n", + "from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config\n", + "\n", + "# This configuration is typically present in NeMo training configs under `model.train_ds` key.\n", + "cfg = OmegaConf.create({\n", + " # Note that we have several sampler groups under keys: \"asr\", \"nmt\", and \"chat\".\n", + " # Each group has its own data source and sampling settings, i.e., you can define\n", + " # completely different batch sizes, sequence length filters, etc. for each type of data.\n", + " # To enable this behaviour, set multi_config to True.\n", + " \"multi_config\": True,\n", + " \n", + " \n", + " \"asr\": {\n", + " \"input_cfg\": [\n", + " {\n", + " \"type\": \"lhotse_shar\", \n", + " \"shar_path\": d / \"asr_shar\"\n", + " }\n", + " ],\n", + " \"min_duration\": 0.5,\n", + " \"max_duration\": 40,\n", + " \"use_bucketing\": True,\n", + " \"bucket_duration_bins\": [5, 10, 20, 40],\n", + " \"bucket_batch_size\": [4, 3, 2, 1],\n", + " \"prompt_format\": \"llama2\",\n", + "\n", + " # Simplified settings for quick tutorial running (don't use those in real applciations).\n", + " \"concurrent_bucketing\": False,\n", + " \"bucket_buffer_size\": 50,\n", + " \"shuffle_buffer_size\": 50,\n", + "\n", + " # The first group defines a number of fields that will be later shared by all groups.\n", + " # sampler_fusion key determines how to yield batches from different samplers:\n", + " # * \"round_robin\" will just yield one type at a time\n", + " # * \"zip\" will sample a batch for each type and concatenate them, yielding a larger multimodal batch\n", + " # * \"randomized_round_robin\" expects an extra \"sampler_weights\" option which will define sampling probs for each group.:\n", + " \"sampler_fusion\": \"round_robin\",\n", + " \"shuffle\": True,\n", + " \"num_workers\": 0,\n", + " \"seed\": 0,\n", + " \"shard_seed\": \"trng\",\n", + " },\n", + "\n", + " \"nmt\": {\n", + " \"input_cfg\": [\n", + " {\n", + " \"type\": \"txt_pair\", \n", + " \"source_paths\": d / \"src.txt\", \n", + " \"target_paths\": d / \"tgt.txt\"\n", + " }\n", + " ],\n", + " \"use_multimodal_sampling\": True, # will count tokens instead of seconds\n", + " \"min_tokens\": 1,\n", + " \"max_tokens\": 32,\n", + " \"measure_total_length\": False, # filters by input length instead of total length\n", + " \"use_bucketing\": True,\n", + " \"bucket_duration_bins\": [[16, 16], [16, 32], [32, 16], [32, 32]], # 2D buckets\n", + " \"bucket_batch_size\": [4, 3, 2, 1],\n", + " \"prompt_format\": \"llama2\",\n", + " \n", + " # Simplified settings for quick tutorial running (don't use those in real applciations).\n", + " \"concurrent_bucketing\": False,\n", + " \"bucket_buffer_size\": 50,\n", + " \"shuffle_buffer_size\": 50,\n", + " },\n", + "\n", + " \"chat\": {\n", + " \"input_cfg\": [\n", + " {\n", + " \"type\": \"multimodal_conversation\", \n", + " \"manifest_filepath\": d / \"conv.jsonl\", \n", + " \"audio_locator_tag\": \"[audio]\"\n", + " }\n", + " ],\n", + " \"use_multimodal_sampling\": True, # will count tokens instead of seconds\n", + " \"min_tokens\": 1,\n", + " \"max_tokens\": 1024,\n", + " \"measure_total_length\": True,\n", + " \"token_equivalent_duration\": 0.08,\n", + " \"use_bucketing\": True,\n", + " \"bucket_duration_bins\": [128, 256, 512, 1024],\n", + " \"bucket_batch_size\": [4, 3, 2, 1],\n", + " \"prompt_format\": \"llama2\",\n", + "\n", + " # Simplified settings for quick tutorial running (don't use those in real applciations).\n", + " \"concurrent_bucketing\": False,\n", + " \"bucket_buffer_size\": 50,\n", + " \"shuffle_buffer_size\": 50,\n", + " },\n", + "})\n", + "\n", + "\n", + "# A no-op PyTorch Dataset class that will just return the data structures.\n", + "# In a real training setup, you'll want to implement conversion of a list of examples to a tensor mini-batch\n", + "# that is adequate for your model. \n", + "# Note that you can handle multiple types of examples to create appropriate mini-batch schema for each.\n", + "class Identity(torch.utils.data.Dataset):\n", + " def __getitem__(self, examples: CutSet):\n", + " return examples\n", + "\n", + "dloader = get_lhotse_dataloader_from_config(cfg, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e8768e28-663b-4d69-bb31-fbd6b80c0389", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 0. Examples:\n", + "\t* MonoCut(id='dummy-recording-0000_repeat10', start=0, duration=17.11, channel=0, supervisions=[SupervisionSegment(id='dummy-recording-0000', recording_id='dummy-recording-0000', start=0.0, duration=17.11, channel=0, text='Welcome to Lhotse!', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='rec', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom={'context': 'Repeat after me', 'shard_origin': PosixPath('_tutorial_data/asr_shar/cuts.000000.jsonl.gz'), 'shar_epoch': 10, 'input_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5, 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88,\n", + " 93, 92, 78, 10, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5]), 'answer_ids': tensor([ 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88, 93, 92, 78,\n", + " 10, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* MonoCut(id='dummy-recording-0000_repeat41', start=0, duration=17.11, channel=0, supervisions=[SupervisionSegment(id='dummy-recording-0000', recording_id='dummy-recording-0000', start=0.0, duration=17.11, channel=0, text='Welcome to Lhotse!', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='rec', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom={'context': 'Repeat after me', 'shard_origin': PosixPath('_tutorial_data/asr_shar/cuts.000000.jsonl.gz'), 'shar_epoch': 41, 'input_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5, 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88,\n", + " 93, 92, 78, 10, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5]), 'answer_ids': tensor([ 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88, 93, 92, 78,\n", + " 10, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n", + "Step 1. Examples:\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n", + "Step 2. Examples:\n", + "\t* NeMoMultimodalConversation(id='convo-1_repeat0', turns=[TextTurn(value='tell me what you hear', role='user'), AudioTurn(cut=MonoCut(id='rec', start=0.0, duration=17.11, channel=0, supervisions=[], features=None, recording=Recording(id='rec', sources=[AudioSource(type='file', channels=[0], source='_tutorial_data/rec.flac')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom=None), role='user', audio_locator_tag='[audio]'), TextTurn(value='somebody just welcomed me to a himalayan mountain', role='assistant')], token_equivalent_duration=0.08, custom={'input_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5, 9, 92, 88, 86, 78, 75, 88,\n", + " 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85, 76, 88, 86, 78, 77, 9, 86,\n", + " 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74, 85, 74, 98, 74, 87, 9, 86,\n", + " 88, 94, 87, 93, 74, 82, 87, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5]), 'answer_ids': tensor([ 9, 92, 88, 86, 78, 75, 88, 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85,\n", + " 76, 88, 86, 78, 77, 9, 86, 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74,\n", + " 85, 74, 98, 74, 87, 9, 86, 88, 94, 87, 93, 74, 82, 87, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* NeMoMultimodalConversation(id='convo-1_repeat1', turns=[TextTurn(value='tell me what you hear', role='user'), AudioTurn(cut=MonoCut(id='rec', start=0.0, duration=17.11, channel=0, supervisions=[], features=None, recording=Recording(id='rec', sources=[AudioSource(type='file', channels=[0], source='_tutorial_data/rec.flac')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom=None), role='user', audio_locator_tag='[audio]'), TextTurn(value='somebody just welcomed me to a himalayan mountain', role='assistant')], token_equivalent_duration=0.08, custom={'input_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5, 9, 92, 88, 86, 78, 75, 88,\n", + " 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85, 76, 88, 86, 78, 77, 9, 86,\n", + " 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74, 85, 74, 98, 74, 87, 9, 86,\n", + " 88, 94, 87, 93, 74, 82, 87, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5]), 'answer_ids': tensor([ 9, 92, 88, 86, 78, 75, 88, 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85,\n", + " 76, 88, 86, 78, 77, 9, 86, 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74,\n", + " 85, 74, 98, 74, 87, 9, 86, 88, 94, 87, 93, 74, 82, 87, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* NeMoMultimodalConversation(id='convo-1_repeat2', turns=[TextTurn(value='tell me what you hear', role='user'), AudioTurn(cut=MonoCut(id='rec', start=0.0, duration=17.11, channel=0, supervisions=[], features=None, recording=Recording(id='rec', sources=[AudioSource(type='file', channels=[0], source='_tutorial_data/rec.flac')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom=None), role='user', audio_locator_tag='[audio]'), TextTurn(value='somebody just welcomed me to a himalayan mountain', role='assistant')], token_equivalent_duration=0.08, custom={'input_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5, 9, 92, 88, 86, 78, 75, 88,\n", + " 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85, 76, 88, 86, 78, 77, 9, 86,\n", + " 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74, 85, 74, 98, 74, 87, 9, 86,\n", + " 88, 94, 87, 93, 74, 82, 87, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5]), 'answer_ids': tensor([ 9, 92, 88, 86, 78, 75, 88, 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85,\n", + " 76, 88, 86, 78, 77, 9, 86, 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74,\n", + " 85, 74, 98, 74, 87, 9, 86, 88, 94, 87, 93, 74, 82, 87, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n", + "Step 3. Examples:\n", + "\t* MonoCut(id='dummy-recording-0000_repeat67', start=0, duration=17.11, channel=0, supervisions=[SupervisionSegment(id='dummy-recording-0000', recording_id='dummy-recording-0000', start=0.0, duration=17.11, channel=0, text='Welcome to Lhotse!', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='rec', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom={'context': 'Repeat after me', 'shard_origin': PosixPath('_tutorial_data/asr_shar/cuts.000000.jsonl.gz'), 'shar_epoch': 67, 'input_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5, 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88,\n", + " 93, 92, 78, 10, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5]), 'answer_ids': tensor([ 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88, 93, 92, 78,\n", + " 10, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* MonoCut(id='dummy-recording-0000_repeat16', start=0, duration=17.11, channel=0, supervisions=[SupervisionSegment(id='dummy-recording-0000', recording_id='dummy-recording-0000', start=0.0, duration=17.11, channel=0, text='Welcome to Lhotse!', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='rec', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom={'context': 'Repeat after me', 'shard_origin': PosixPath('_tutorial_data/asr_shar/cuts.000000.jsonl.gz'), 'shar_epoch': 16, 'input_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5, 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88,\n", + " 93, 92, 78, 10, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5]), 'answer_ids': tensor([ 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88, 93, 92, 78,\n", + " 10, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n", + "Step 4. Examples:\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n" + ] + } + ], + "source": [ + "for idx, batch in enumerate(dloader):\n", + " if idx == 5:\n", + " break\n", + " print(f\"Step {idx}. Examples:\")\n", + " for item in batch:\n", + " print(\"\\t*\", item)\n", + " print()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "704c44f5-bcce-4b4f-828b-fa1e18de8d71", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From e37ffbeb0ddf44a6a7b4bf17f8d4ced879cee0ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 23 Oct 2024 10:21:27 -0400 Subject: [PATCH 55/63] Updated documentation for multimodal dataloading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- docs/source/asr/datasets.rst | 278 ++++++++++++++++++++++++++++++----- 1 file changed, 245 insertions(+), 33 deletions(-) diff --git a/docs/source/asr/datasets.rst b/docs/source/asr/datasets.rst index 2c0657d1c6ce..b8fc1cf56441 100644 --- a/docs/source/asr/datasets.rst +++ b/docs/source/asr/datasets.rst @@ -744,53 +744,265 @@ The final weight is the product of outer and inner weight: source_lang: pl target_lang: en -Configuring multi-modal dataloading +Configuring multimodal dataloading ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Our configuration format supports specifying data sources from other modalities than just audio. -At this time, this support is extended to text-only data. We provide the following parser types: +At this time, this support is extended to audio and text modalities. We provide the following parser types: -* ``txt`` for raw text files, sharded or unsharded. This can represent, for example, language modeling data. -* ``txt_pair`` for pairs of raw text files, sharded or unsharded. This can represent, for example, machine translation data. +**Raw text files.** Simple text files where each line is an individual text example. This can represent standard language modeling data. +This parser is registered under ``type: txt``. -The key strength of this approach is that we can easily combine audio datasets and text datasets, -and benefit from every other technique we described above such as dynamic data mixing, data weighting, dynamic bucketing, and so on. -To enable multimodal dataloading, we provide several configuration options: +Data format examples:: -* ``use_multimodal_sampling`` when set to True, we'll discard the settings of ``batch_duration`` and ``quadratic_duration`` and consider the settings below instead. + # file: document_0.txt + This is a language modeling example. + Wall Street is expecting major news tomorrow. -* ``batch_tokens`` is the maximum number of tokens we want to find inside a mini-batch. Similarly to ``batch_duration``, this number does consider padding tokens too, therefore enabling bucketing is recommended to maximize the ratio of real vs padding tokens. + # file: document_1.txt + Invisible bats have stormed the city. + What an incredible event! -* ``token_equivalent_duration`` is used to be able to measure audio examples in the number of "tokens". For example, if we're using fbank with 0.01s frame shift and an acoustic model that has a subsampling factor of 0.08, then a reasonable setting for this could be 0.08 (which means every subsampled frame counts as one token). Calibrate this value to fit your needs. Note that this value acts as a "balancer" between how much audio data vs text data gets sampled into a mini-batch. +Dataloading configuration example:: -* ``quadratic_factor`` works the same way as ``quadratic_duration``, but is defined in the number of tokens. + input_cfg: + - type: txt + paths: /path/to/document_{0..1}.txt + language: en # optional -Example 3. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so that mini-batches have some examples from both datasets. Provide a custom prompt field for both datasets (to be leveraged by a relevant dataset class): +Python object example:: -.. code-block:: yaml + from nemo.collections.common.data.lhotse.text_adapters import TextExample + + example = TextExample( + text="This is a language modeling example.", + language="en", # optional + ) + +Python dataloader instantiation example:: + + from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config + + dl = get_lhotse_dataloader_from_config({ + "input_cfg": [ + {"type": "txt", "paths": "/path/to/document_{0..1}.txt", "language": "en"}, + ], + "use_multimodal_dataloading": True, + "batch_size": 4, + }, + global_rank=0, + world_size=1, + dataset=MyDatasetClass(), # converts CutSet -> dict[str, Tensor] + tokenizer=my_tokenizer, + ) + +**Raw text file pairs.** Pairs of raw text files with corresponding lines. This can represent machine translation data. +This parser is registered under ``type: txt_pair``. + +Data format examples:: + + # file: document_en_0.txt + This is a machine translation example. + Wall Street is expecting major news tomorrow. + + # file: document_pl_0.txt + To jest przykład tłumaczenia maszynowego. + Wall Street spodziewa się jutro ważnych wiadomości. + +Dataloading configuration example:: - use_multimodal_sampling: true - batch_tokens: 1024 - token_equivalent_duration: 0.08 # 0.01 frame shift * 8 subsampling factor - quadratic_factor: 50 - num_buckets: 30 - use_bucketing: true input_cfg: - - type: nemo_tarred - manifest_filepath: /path/to/manifest__OP_0..512_CL_.json - tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar - weight: 0.5 - tags: - lang: en - prompt: "Given the following recording, transcribe what the person is saying:" - type: txt_pair - source_path: /path/to/en__OP_0..512_CL_.txt - target_path: /path/to/pl__OP_0..512_CL_.txt - source_language: en - target_language: pl - weight: 0.5 - tags: - prompt: "Translate the following text to Polish:" + source_path: /path/to/document_en_{0..N}.txt + target_path: /path/to/document_pl_{0..N}.txt + source_language: en # optional + target_language: pl # optional + +Python object example:: + + from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample + + example = SourceTargetTextExample( + source=TextExample( + text="This is a language modeling example.", + language="en", # optional + ), + target=TextExample( + text="To jest przykład tłumaczenia maszynowego.", + language="pl", # optional + ), + ) + +Python dataloader instantiation example:: + + from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config + + dl = get_lhotse_dataloader_from_config({ + "input_cfg": [ + { + "type": "txt_pair", + "source_path": "/path/to/document_en_{0..N}.txt", + "target_path": "/path/to/document_pl_{0..N}.txt", + "source_language": "en" + "target_language": "en" + }, + ], + "use_multimodal_dataloading": True, + "prompt_format": "t5nmt", + "batch_size": 4, + }, + global_rank=0, + world_size=1, + dataset=MyDatasetClass(), # converts CutSet -> dict[str, Tensor] + tokenizer=my_tokenizer, + ) + +**NeMo multimodal conversations.** A JSON-Lines (JSONL) file that defines multi-turn conversations with mixed text and audio turns. +This parser is registered under ``type: multimodal_conversation``. + +Data format examples:: + + # file: chat_0.jsonl + {"id": "conv-0", "conversations": [{"from": "user", "value": "speak to me", "type": "text"}, {"from": "assistant": "value": "/path/to/audio.wav", "duration": 17.1, "type": "audio"}]} + +Dataloading configuration example:: + + token_equivalent_duration: 0.08 + input_cfg: + - type: multimodal_conversation + manifest_filepath: /path/to/chat_{0..N}.jsonl + audio_locator_tag: [audio] + +Python object example:: + + from lhotse import Recording + from nemo.collections.common.data.lhotse.text_adapters import MultimodalConversation, TextTurn, AudioTurn + + conversation = NeMoMultimodalConversation( + id="conv-0", + turns=[ + TextTurn(value="speak to me", role="user"), + AudioTurn(cut=Recording.from_file("/path/to/audio.wav").to_cut(), role="assistant", audio_locator_tag="[audio]"), + ], + token_equivalent_duration=0.08, # this value will be auto-inserted by the dataloader + ) + +Python dataloader instantiation example:: + + from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config + + dl = get_lhotse_dataloader_from_config({ + "input_cfg": [ + { + "type": "multimodal_conversation", + "manifest_filepath": "/path/to/chat_{0..N}.jsonl", + "audio_locator_tag": "[audio]", + }, + ], + "use_multimodal_dataloading": True, + "token_equivalent_duration": 0.08, + "prompt_format": "llama2", + "batch_size": 4, + }, + global_rank=0, + world_size=1, + dataset=MyDatasetClass(), # converts CutSet -> dict[str, Tensor] + tokenizer=my_tokenizer, + ) + +**Dataloading and bucketing of text and multimodal data.** When dataloading text or multimodal data, pay attention to the following config options (we provide example values for convenience): + +* ``use_multimodal_sampling: true`` tells Lhotse to switch from measuring audio duration to measuring token counts; required for text. + +* ``prompt_format: "prompt-name"`` will apply a specified PromptFormatter during data sampling to accurately reflect its token counts. + +* ``measure_total_length: true`` customizes length measurement for decoder-only and encoder-decoder models. Decoder-only models consume a linear sequence of context + answer, so we should measure the total length (``true``). On the other hand, encoder-decoder models deal with two different sequence lengths: input (context) sequence length for the encoder, and output (answer) sequence length for the decoder. For such models set this to ``false``. + +* ``min_tokens: 1``/``max_tokens: 4096`` filters examples based on their token count (after applying the prompt format). + +* ``min_tpt: 0.1``/``max_tpt: 10`` filter examples based on their output-token-per-input-token-ratio. For example, a ``max_tpt: 10`` means we'll filter every example that has more than 10 output tokens per 1 input token. Very useful for removing sequence length outliers that lead to OOM. Use ``estimate_token_bins.py`` to view token count distributions for calbirating this value. + +* (multimodal-only) ``token_equivalent_duration: 0.08`` is used to be able to measure audio examples in the number of "tokens". For example, if we're using fbank with 0.01s frame shift and an acoustic model that has a subsampling factor of 0.08, then a reasonable setting for this could be 0.08 (which means every subsampled frame counts as one token). Calibrate this value to fit your needs. + +**Text/multimodal bucketing and OOMptimizer.** Analogous to bucketing for audio data, we provide two scripts to support efficient bucketing: + +* ``scripts/speech_llm/estimate_token_bins.py`` which estimates 1D or 2D buckets based on the input config, tokenizer, and prompt format. It also estimates input/output token count distribution and suggested ``max_tpt`` (token-per-token) filtering values. + +* (experimental) ``scripts/speech_llm/oomptimizer.py`` which works with SALM/BESTOW GPT/T5 models and estimates the optimal ``bucket_batch_size`` for a given model config and bucket bins value. Given the complexity of Speech LLM some configurations may not be supported yet at the time of writing (e.g., model parallelism). + +To enable bucketing, set ``batch_size: null`` and use the following options: + +* ``use_bucketing: true`` + +* ``bucket_duration_bins`` - the output of ``estimate_token_bins.py``. If ``null``, it will be estimated at the start of training at the cost of some run time (not recommended). + +* (oomptimizer-only) ``bucket_batch_size`` - the output of OOMptimizer. + +* (non-oomptimizer-only) ``batch_tokens`` is the maximum number of tokens we want to find inside a mini-batch. Similarly to ``batch_duration``, this number does consider padding tokens too, therefore enabling bucketing is recommended to maximize the ratio of real vs padding tokens. Note that it's just a heuristic for determining the optimal batch sizes for different buckets, and may be less efficient than using OOMptimizer. + +* (non-oomptimizer-only) ``quadratic_factor`` is a quadratic penalty to equalize the GPU memory usage between buckets of short and long sequence lengths for models with quadratic memory usage. It is only a heuristic and may not be as efficient as using OOMptimizer. + +**Joint dataloading of text/audio/multimodal data.** The key strength of this approach is that we can easily combine audio datasets and text datasets, +and benefit from every other technique we described in this doc, such as: dynamic data mixing, data weighting, dynamic bucketing, and so on. + +This approach is described in the `EMMeTT`_ paper. There's also a notebook tutorial called Multimodal Lhotse Dataloading. We construct a separate sampler (with its own batching settings) for each modality, +and specify how the samplers should be fused together via the option ``sampler_fusion``: + +* ``sampler_fusion: "round_robin"`` will iterate single sampler per step, taking turns. For example: step 0 - audio batch, step 1 - text batch, step 2 - audio batch, etc. + +* ``sampler_fusion: "randomized_round_robin"`` is similar, but at each chooses a sampler randomly using ``sampler_weights: [w0, w1]`` (weights can be unnormalized). + +* ``sampler_fusion: "zip"`` will draw a mini-batch from each sampler at every step, and merge them into a single ``CutSet``. This approach combines well with multimodal gradient accumulation (run forward+backward for one modality, then the other, then the update step). + +.. _EMMeTT: https://arxiv.org/abs/2409.13523 + +Example. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so that mini-batches have some examples from both datasets: + +.. code-block:: yaml + + model: + ... + train_ds: + multi_config: True, + audio: + sampler_fusion: zip + shuffle: true + num_workers: 4 + + prompt_format: t5nmt + use_bucketing: true + min_duration: 0.5 + max_duration: 30.0 + max_tps: 12.0 + bucket_duration_bins: [[3.16, 10], [3.16, 22], [5.18, 15], ...] + bucket_batch_size: [1024, 768, 832, ...] + input_cfg: + - type: nemo_tarred + manifest_filepath: /path/to/manifest__OP_0..512_CL_.json + tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar + weight: 0.5 + tags: + context: "Translate the following to English" + text: + prompt_format: t5nmt + use_multimodal_sampling: true + min_tokens: 1 + max_tokens: 256 + min_tpt: 0.333 + max_tpt: 3.0 + measure_total_length: false + use_bucketing: true + bucket_duration_bins: [[10, 4], [10, 26], [15, 10], ...] + bucket_batch_size: [512, 128, 192, ...] + input_cfg: + - type: txt_pair + source_path: /path/to/en__OP_0..512_CL_.txt + target_path: /path/to/pl__OP_0..512_CL_.txt + source_language: en + target_language: pl + weight: 0.5 + tags: + question: "Translate the following to Polish" .. caution:: We strongly recommend to use multiple shards for text files as well so that different nodes and dataloading workers are able to randomize the order of text iteration. Otherwise, multi-GPU training has a high risk of duplication of text examples. From c6ffb40a927cc7742367599ecc45bbbbec83c777 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 23 Oct 2024 11:36:24 -0400 Subject: [PATCH 56/63] Prompt Formatter tutorial MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../Prompt Formatter Tutorial.ipynb | 458 ++++++++++++++++++ 1 file changed, 458 insertions(+) create mode 100644 tutorials/multimodal/Prompt Formatter Tutorial.ipynb diff --git a/tutorials/multimodal/Prompt Formatter Tutorial.ipynb b/tutorials/multimodal/Prompt Formatter Tutorial.ipynb new file mode 100644 index 000000000000..85f220115e13 --- /dev/null +++ b/tutorials/multimodal/Prompt Formatter Tutorial.ipynb @@ -0,0 +1,458 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cd408a7a-d4b6-4f33-83d3-c607dbc5f580", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "source": [ + "# Prompt Formatter Tutorial\n", + "\n", + "This tutorial introduces NeMo's PromptFormatter API available in module `nemo.collections.common.prompts`.\n", + "After finishing this tutorial you will be familiar with the existing prompt formatters, how to use them, and how to build your own.\n", + "\n", + "We cover the following topics:\n", + "\n", + "* Using existing prompt formatters with Llama2 as an example.\n", + "\n", + "* Defining your own prompt formatter.\n", + "\n", + "We also support applying prompt formatters for multimodal data and Lhotse-compatible data types. To learn more, see our other tutorial: [Multimodal Lhotse Dataloading](./Multimodal Lhotse Dataloading.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "3f87f30c-79c0-41e8-b126-283ff5436465", + "metadata": {}, + "source": [ + "### Pre-requsite: building a dummy tokenizer\n", + "\n", + "We're going to need a tokenizer to work with prompt formatters - we'll just build a dummy one for the purpose of this tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e91ebef5-9a25-4eb1-8211-d0f5990f7c37", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/pzelasko/miniforge3/envs/nemo/lib/python3.10/site-packages/transformers/utils/generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n", + " _torch_pytree._register_pytree_node(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-10-23 11:26:41 sentencepiece_tokenizer:333] tokenizer model _tutorial_spt/tokenizer.model already exists\n" + ] + } + ], + "source": [ + "import string\n", + "import shlex\n", + "from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model\n", + "\n", + "!echo {shlex.quote(' '.join(string.printable))} > _tutorial_train_text.txt\n", + "\n", + "tok_path, vocab_path = create_spt_model(\n", + " data_file=\"_tutorial_train_text.txt\", \n", + " output_dir=\"_tutorial_spt\",\n", + " vocab_size=512, \n", + " sample_size=-1, \n", + " do_lower_case=False, \n", + " bos=True, \n", + " eos=True, \n", + " pad=True, \n", + " user_defined_symbols=[\"[INST]\", \"[/INST]\", \"<>\", \"<>\", \"[audio]\"]\n", + ")\n", + "\n", + "tokenizer = SentencePieceTokenizer(tok_path)\n", + "\n", + "def display(encoded_chat, with_mask=False):\n", + " \"\"\"Utility for printing prompt formatted chats.\"\"\"\n", + " for key, val in encoded_chat.items():\n", + " if key.endswith(\"_ids\"):\n", + " print(key, '--', tokenizer.ids_to_text(val), '\\n')\n", + " if key == \"mask\" and with_mask:\n", + " print(key, '--', val)" + ] + }, + { + "cell_type": "markdown", + "id": "4c5c6c88-c882-4305-8757-585fec3eab46", + "metadata": {}, + "source": [ + "## Using an existing PromptFormatter: Llama2\n", + "\n", + "\n", + "**Instanting the prompt formatter.** Let's start with a simple example of Llama2 prompt format use." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c77a993e-453f-474e-8912-fd35c7fc39ba", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.common.prompts.llama import Llama2PromptFormatter\n", + "from pprint import pprint\n", + "\n", + "prompt = Llama2PromptFormatter(tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "92054a0f-5b97-4178-94b8-a27e62acf97b", + "metadata": {}, + "source": [ + "**Chat example.** We'll define a multi-turn conversation between the user and assistant below:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c5eabe5e-4160-41d7-ad85-a4df596de38b", + "metadata": {}, + "outputs": [], + "source": [ + "chat = [\n", + " {\"role\": \"user\", \"slots\": {\"message\": \"Do you know something about electronics?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", + " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "eff61b98-c7be-4345-ac97-15573d1a9533", + "metadata": {}, + "source": [ + "**Prompt formatter outputs.** Now, we apply prompt formatter to that conversation to obtain four tensors useful for training:\n", + "* `context_ids` encode the whole dialog history up to the last response of the assistant;\n", + "* `answer_ids` encode the last response of the assistant;\n", + "* `input_ids` encode the full conversation;\n", + "* `mask` is a boolean training loss mask that's set to `True` for every token belonging to assistant's turns.\n", + "\n", + "Since the token IDs are meaningless, we'll apply reverse tokenizer for displaying the prompt formatted example." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a10216b3-2bbe-4a2f-8ca8-557c3b9056be", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n", + "\n", + "context_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", + "\n", + "answer_ids -- In order to build your own audio amplifier, start with ... \n", + "\n", + "mask -- tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True])\n" + ] + } + ], + "source": [ + "encoded = prompt.encode_dialog(chat)\n", + "display(encoded, with_mask=True)" + ] + }, + { + "cell_type": "markdown", + "id": "e181618e-6df8-44b2-b986-15660133e486", + "metadata": {}, + "source": [ + "**System prompt.** We also support the system prompt. Since it affects the prompt format in a non-trivial way, it is defined as a separate role `\"system_and_user\"`, which has two slots `\"system\"` and `\"message\"`. We'll omit printing the mask for brevity." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2c3476a4-b301-4f35-9520-90d4b919363d", + "metadata": {}, + "outputs": [], + "source": [ + "chat_with_system = [\n", + " {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", + " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5c8c329d-f8b3-48cb-b664-baed0fcd90ab", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n", + "\n", + "context_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", + "\n", + "answer_ids -- In order to build your own audio amplifier, start with ... \n", + "\n" + ] + } + ], + "source": [ + "encoded = prompt.encode_dialog(chat_with_system)\n", + "display(encoded)" + ] + }, + { + "cell_type": "markdown", + "id": "a453345a-6456-43ed-a663-0554c459fddb", + "metadata": {}, + "source": [ + "**Constructing inference-time prompts.** During inference, we don't know what's the last turn of the assistant - we only want to construct the ``context_ids`` tensor. In those cases, just omit the last assistant's turn. The prompt formatter will return the ``context_ids`` tensor (with ``input_ids`` alias for it too)." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4ede7100-9d28-4cf0-ab75-bfede9936218", + "metadata": {}, + "outputs": [], + "source": [ + "inference_chat = [\n", + " {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", + " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "61bf8e77-0630-4a84-bd30-ca4c27f8d898", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", + "\n", + "context_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", + "\n" + ] + } + ], + "source": [ + "encoded = prompt.encode_dialog(inference_chat)\n", + "display(encoded)" + ] + }, + { + "cell_type": "markdown", + "id": "a334e00a-9530-4333-98de-5cb8fb08eb47", + "metadata": {}, + "source": [ + "### How is Llama2 PromptFormatter built\n", + "\n", + "`Llama2PromptFormatter` is a small class with prompt definition that inherits `PromptFormatter`, which implements the logic for applying prompt format and tokenization to multi-turn conversations. \n", + "\n", + "Let's take a look at `Llama2PromptFormatter` definition:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f29fbf2f-3caa-4b27-86ca-5012d9fc6ba5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class Llama2PromptFormatter(PromptFormatter):\n", + " \"\"\"\n", + " This template has been validated to provide identical tokenized results to the official code\n", + " in https://github.com/meta-llama/llama/blob/main/llama/generation.py\n", + " \"\"\"\n", + "\n", + " NAME = \"llama2\"\n", + " OUTPUT_ROLE = \"assistant\"\n", + " TEMPLATE = {\n", + " \"system_and_user\": {\n", + " \"template\": f\"{BOS_SLOT}[INST] <>\\n|system|\\n<>\\n\\n|message| [/INST]\",\n", + " \"slots\": {\n", + " \"system\": Modality.Text,\n", + " \"message\": Modality.Text,\n", + " },\n", + " },\n", + " \"user\": {\n", + " \"template\": f\"{BOS_SLOT}[INST] |message| [/INST]\",\n", + " \"slots\": {\n", + " \"message\": Modality.Text,\n", + " },\n", + " },\n", + " OUTPUT_ROLE: {\n", + " \"template\": f\"|message| {EOS_SLOT}\",\n", + " \"slots\": {\n", + " \"message\": Modality.Text,\n", + " },\n", + " },\n", + " }\n", + "\n" + ] + } + ], + "source": [ + "import inspect\n", + "print(inspect.getsource(Llama2PromptFormatter))" + ] + }, + { + "cell_type": "markdown", + "id": "b24e9310-b8ed-4e35-9dda-d24aa62cfb6a", + "metadata": {}, + "source": [ + "As you can see, the definition consist of the following key components:\n", + "* Derives `PromptFormatter` parent class.\n", + "* Specifies `NAME`, which is used for dynamic resolution of string to class via `cls = PromptFormatter.resolve(name)`.\n", + "* Specifies `OUTPUT_ROLE`, which is the name for the role with assistant's responses (typically `\"assistant\"`).\n", + "* Specifies `TEMPLATE` which defines the dialog structure and how user-provided values (slots) are applied to prompts. Notably:\n", + " * The slots are wrapped into pipe operators `\"|\"` in the prompt template definition, and substituted with user provided values before tokenization.\n", + " * `\"system_and_user`\" role has two slots, `\"system\"` and `\"message\"`, and a template that wraps them with Llama2 special tokens.\n", + " * We use `BOS_SLOT` and `EOS_SLOT` to insert sentencepiece tokenizer's `bos_id` and `eos_id` in the right places (remember that sentencepiece won't tokenize them from text, they need to be inserted programmatically).\n", + " * The slots have a type, currently supported types are `Modality.Text` and `Modality.TextLiteral(value1, value2, ...)` that allows to restrict the set of slots values." + ] + }, + { + "cell_type": "markdown", + "id": "8cbdca6c-6c0f-42a9-a4a7-b936684c6e12", + "metadata": {}, + "source": [ + "## Defining your own prompt formatter" + ] + }, + { + "cell_type": "markdown", + "id": "25a9b6d2-d004-4f7f-8b24-4fd6d4eae244", + "metadata": {}, + "source": [ + "Generally you can follow the definition of existing prompt formatters to define your own. \n", + "We have several prompt formats implemented for Llama, Gemma, Phi, etc. \n", + "\n", + "We'll define a custom simple prompt format that has no system prompt below as an illustration:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b69f6532-24d8-4419-b1da-42184c3d72de", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.common.prompts.formatter import PromptFormatter, Modality\n", + "\n", + "class MyPrompt(PromptFormatter):\n", + " NAME = \"myprompt\"\n", + " OUTPUT_ROLE = \"assistant\"\n", + " TEMPLATE = {\n", + " \"user\": {\n", + " \"template\": \"User: |message|\\n\",\n", + " \"slots\": {\"message\": Modality.Text},\n", + " },\n", + " \"assistant\": {\n", + " \"template\": \"Assistant: |message|\\n\",\n", + " \"slots\": {\"message\": Modality.Text},\n", + " },\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a97c6589-1303-446c-952f-d2b4007ca7e9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? Assistant: In order to build your own audio amplifier, start with ... \n", + "\n", + "context_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? \n", + "\n", + "answer_ids -- Assistant: In order to build your own audio amplifier, start with ... \n", + "\n" + ] + } + ], + "source": [ + "my_prompt_cls = PromptFormatter.resolve(\"myprompt\") # it is auto-registered\n", + "my_prompt = my_prompt_cls(tokenizer)\n", + "display(my_prompt.encode_dialog(chat))" + ] + }, + { + "cell_type": "markdown", + "id": "30f9c96a-6cf8-4cd3-b0e8-6b461c86100f", + "metadata": {}, + "source": [ + "## Applying prompt formatter to multimodal data\n", + "\n", + "We refer the reader to our other tutorial, [Multimodal Lhotse Dataloading](./Multimodal Lhotse Dataloading.ipynb), where this is discussed in detail." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 5cbca90e45590d494c2da7740264e9e7e34facc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 29 Oct 2024 11:05:39 -0400 Subject: [PATCH 57/63] Review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- ...dio_gpt_config_cross_llama_lhotse_multi.yaml | 5 ++++- .../salm/modular_audio_t5_multi_config.yaml | 4 ++++ .../speech_llm/models/modular_t5_models.py | 17 ++++------------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml index 3d0c1c43bf4a..14b471448b9e 100644 --- a/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml +++ b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - # This configuration is similar to modular_audio_gpt_config_cross_llama_lhotse.yaml, # with the difference being in how it performs multimodal sampling. # The changes are in model.data.train_ds section. @@ -25,6 +24,10 @@ # or zip (sample mini-batch from each and combine them). name: megatron_audio_gpt_bestow_lhotse_multi_sampler +# Note: This config has been updated to work with PromptFormatter API. +# If you used an older version that defined a `train_ds.prompt_template` field, +# you should specify the prompt format using `train_ds..prompt_format` now instead. + trainer: devices: 1 accelerator: gpu diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml index e0b262db3adb..09f9987eee8f 100644 --- a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml @@ -25,6 +25,10 @@ # or zip (sample mini-batch from each and combine them). name: megatron_audio_t5_salm_lhotse_multi_sampler +# Note: This config has been updated to work with PromptFormatter API. +# If you used an older version that defined a `train_ds.prompt_template` field, +# you should specify the prompt format using `train_ds..prompt_format` now instead. + trainer: devices: 1 accelerator: gpu diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 00cda52539a4..e3315e6f0025 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -344,13 +344,8 @@ def prepare_llm_input(self, audio_batch): input_signal = audio_batch['audio_signal'] input_signal_length = audio_batch['audio_signal_length'] - - input_ids, input_length, labels, loss_mask = ( - audio_batch['contexts'], - audio_batch['context_lengths'], - audio_batch['answers'], - audio_batch['loss_mask'], - ) + input_ids = audio_batch['contexts'] + input_length = audio_batch['context_lengths'] # [b, t, c] encoded, encoded_len = self.perception( @@ -387,12 +382,8 @@ def forward( # enc_input = speech and text prompt # dec_input and label = text output label b = audio_batch['answers'].shape[0] - device = audio_batch['answers'].device - dec_input = ( - audio_batch['masked_answer_ids'] if 'masked_answer_ids' in audio_batch else audio_batch['answers'] - ) - dec_input = torch.cat([torch.full([b, 1], self.bos_id, device=device), dec_input[:, :-1]], dim=-1) labels = audio_batch['answers'] + dec_input = torch.cat([torch.full([b, 1], self.bos_id, device=labels.device), labels[:, :-1]], dim=-1) dec_mask = (dec_input != self.tokenizer.pad_id).long().contiguous() output = self.frozen_model.enc_dec_model( enc_input_ids=None, @@ -994,7 +985,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A audio_batch or text_batch and not (audio_batch and text_batch) ), f"Expecting only text or audio batch, got {len(text_batch)=} and {len(audio_batch)=}" - if 'audio_signal' in audio_batch: + if audio_batch: input_text = audio_batch['contexts'] labels = audio_batch['answers'] encoder_input, attention_mask, enc_mask = self.prepare_llm_input(audio_batch) From eca816f6bd1a2c0d50e68bffe2f396c2c449dd73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 20 Nov 2024 12:56:30 -0500 Subject: [PATCH 58/63] Fixes for sampling filters None values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/sampling.py | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 8de49bbaf500..9eba1cee0c69 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -7,6 +7,7 @@ from lhotse.cut import Cut from lhotse.dataset import SamplingConstraint, TokenConstraint from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint +from lhotse.utils import ifnone from nemo.collections.common.data.lhotse.text_adapters import Formattable @@ -193,9 +194,9 @@ class DurationFilter: Acts as a pass-through for objects of other type than Cut. """ - def __init__(self, d_min: float, d_max: float) -> None: - self.d_min = d_min - self.d_max = d_max + def __init__(self, d_min: float | None, d_max: float | None) -> None: + self.d_min = ifnone(d_min, -1) + self.d_max = ifnone(d_max, float("inf")) def __call__(self, example) -> bool: if isinstance(example, Cut): @@ -221,14 +222,12 @@ class TokenCountFilter: and enable ``TokenPerTokenFilter`` for additional filtering on the output sequence length. """ - def __init__(self, t_min: float, t_max: float, measure_total_length: bool) -> None: - self.t_min = t_min - self.t_max = t_max + def __init__(self, t_min: float | None, t_max: float | None, measure_total_length: bool) -> None: + self.t_min = ifnone(t_min, -1) + self.t_max = ifnone(t_max, float("inf")) self.measure_total_length = measure_total_length def __call__(self, example) -> bool: - if self.t_min is None and self.t_max is None: - return True # disabled if isinstance(example, Cut): return True # does not apply to Cuts assert isinstance(example, Formattable), ( @@ -254,10 +253,10 @@ class TokenPerSecondFilter: Acts as a pass-through for objects of other type than Cut. """ - def __init__(self, tps_min: float, tps_max: float) -> None: - assert tps_min <= tps_max - self.tps_min = tps_min - self.tps_max = tps_max + def __init__(self, tps_min: float | None, tps_max: float | None) -> None: + self.tps_min = ifnone(tps_min, -1) + self.tps_max = ifnone(tps_max, float("inf")) + assert tps_min <= tps_max, f"{tps_min=} {tps_max=}" self.enabled = tps_min > 0 or tps_max < float("inf") def __call__(self, example) -> bool: @@ -274,10 +273,10 @@ class TokenPerTokenFilter: Acts as a pass-through for audio examples (Cuts). """ - def __init__(self, tpt_min: float, tpt_max: float) -> None: - assert tpt_min <= tpt_max - self.tpt_min = tpt_min - self.tpt_max = tpt_max + def __init__(self, tpt_min: float | None, tpt_max: float | None) -> None: + self.tpt_min = ifnone(tpt_min, -1) + self.tpt_max = ifnone(tpt_max, float("inf")) + assert tpt_min <= tpt_max, f"{tpt_min=} {tpt_max=}" self.enabled = tpt_min > 0 or tpt_max < float("inf") def __call__(self, example) -> bool: From a46ad651eea35a32431305ef3574fa08a5ba6557 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 26 Nov 2024 14:55:55 -0500 Subject: [PATCH 59/63] Changes requested by Steve: moving some args to main config namespace in multi config sampler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/common/data/lhotse/cutset.py | 9 +- .../common/data/lhotse/dataloader.py | 148 +++++++++------ .../common/data/lhotse/sampling.py | 3 +- .../common/test_lhotse_dataloading.py | 174 ++++++++++++++++-- 4 files changed, 262 insertions(+), 72 deletions(-) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 65027e366fbe..406aa558bb15 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -49,15 +49,18 @@ def read_cutset_from_config(config: DictConfig | dict) -> Tuple[CutSet, bool]: # Now, we'll figure out if we should read Lhotse manifest or NeMo manifest. use_nemo_manifest = all(config.get(opt) is None for opt in ("cuts_path", "shar_path")) if use_nemo_manifest: - assert ( - config.get("manifest_filepath") is not None - ), "You must specify either: manifest_filepath, cuts_path, or shar_path" + if config.get("manifest_filepath") is None: + raise IncompleteConfigError("You must specify either: manifest_filepath, cuts_path, or shar_path") cuts, is_tarred = read_nemo_manifest(config) else: cuts, is_tarred = read_lhotse_manifest(config) return cuts, is_tarred +class IncompleteConfigError(RuntimeError): + pass + + KNOWN_DATA_CONFIG_TYPES = {} diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 61451ef09e1f..b68b3b3c2599 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -38,7 +38,11 @@ from lhotse.utils import fastcopy, fix_random_seed from omegaconf import DictConfig, OmegaConf -from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset, read_cutset_from_config +from nemo.collections.common.data.lhotse.cutset import ( + IncompleteConfigError, + guess_parse_cutset, + read_cutset_from_config, +) from nemo.collections.common.data.lhotse.sampling import ( DurationFilter, FixedBucketBatchSizeConstraint2D, @@ -92,8 +96,12 @@ class LhotseDataLoadingConfig: shard_seed: int | str = "trng" max_open_streams: int | None = None cuda_expandable_segments: bool = True - sampler_fusion: str = "mux" # mux | zip | round_robin | randomized_round_robin - sampler_weights: list[float] | None = None # only applicable to randomized_round_robin + # e. Multi-config related options. + # Setting multi_config=True will scan the config for keys with DictConfig values, + # create a separate sampler for each, and fuse the samplers according to sampler_fusion. + multi_config: bool = False + sampler_fusion: str = "round_robin" # round_robin | randomized_round_robin | zip + sampler_weights: dict[str, float] | None = None # only applicable to randomized_round_robin # 2.1 Multimodal sampling override options pretokenize: bool = True # should we apply tokenizer before data sampling @@ -186,7 +194,7 @@ class LhotseDataLoadingConfig: def get_lhotse_dataloader_from_config( - config: DictConfig, + config: dict | DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, @@ -211,9 +219,15 @@ def get_lhotse_dataloader_from_config( If "prompt_format" is additionally provided in the config, we will also apply a prompt formatter. Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work). """ - if config.get("multi_config"): + if not isinstance(config, DictConfig): + config = OmegaConf.create(config) + + # Providing default value because we haven't filled the config defaults yet. + maybe_set_cuda_expandable_segments(enabled=config.get("cuda_expandable_segments", True)) + + if config.get("multi_config", False): return get_lhotse_dataloader_from_multi_config( - configs=config, global_rank=global_rank, world_size=world_size, dataset=dataset, tokenizer=tokenizer + config=config, global_rank=global_rank, world_size=world_size, dataset=dataset, tokenizer=tokenizer ) else: return get_lhotse_dataloader_from_single_config( @@ -253,16 +267,10 @@ def get_lhotse_dataloader_from_single_config( config = make_structured_with_schema_warnings(config) - maybe_set_cuda_expandable_segments(enabled=config.cuda_expandable_segments) - # First, resolve the random seed in case a string value was provided. config.seed = resolve_seed(config.seed) fix_random_seed(config.seed) - assert config.sampler_fusion == "mux", ( - "In order to use a sampler_fusion strategy different than 'mux', " - "create your dataloader using 'get_lhotse_dataloader_from_multi_config' instead." - ) sampler, use_iterable_dataset = get_lhotse_sampler_from_config( config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer ) @@ -297,7 +305,7 @@ def get_lhotse_dataloader_from_single_config( def get_lhotse_dataloader_from_multi_config( - configs: DictConfig, + config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, @@ -313,44 +321,80 @@ def get_lhotse_dataloader_from_multi_config( The first config is treated as a "main" config that determines the RNG, CUDA allocator, and sampler fusion settings. """ - configs = [make_structured_with_schema_warnings(c) for c in configs.values() if isinstance(c, DictConfig)] - main_config = configs[0] - maybe_set_cuda_expandable_segments(enabled=main_config.cuda_expandable_segments) - seed = resolve_seed(main_config.seed) - fix_random_seed(seed) - - source_samplers, source_use_iterable_dataset = [], [] - for config in configs: - # TODO(pzelasko): perhaps emit a warning in the unlikely case somebody defines different seeds explicitly. - config.seed = seed - config.shard_seed = main_config.shard_seed - s, t = get_lhotse_sampler_from_config( - config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer - ) - source_samplers.append(s) + + def gather_shared_opts(): + """ + In multi-config setting, the top-level config defines several attributes that overwrite + the ones present in sub-configs. + """ + assert all( + k in config for k in ["seed", "shard_seed", "shuffle"] + ), "In a multi-config setting (multi_config=True), the top-level namespace (typically train_ds) must define at least 'seed', 'shard_seed', and 'shuffle' keys that will be shared by all sub-configs." + overwriting_opts = [ + "seed", + "shard_seed", + "num_workers", + "pin_memory", + "shuffle", + "sampler_fusion", + "sampler_weights", + "multi_config", + "metadata_only", + "force_finite", + ] + defaults = OmegaConf.structured(LhotseDataLoadingConfig) + config["seed"] = resolve_seed(config["seed"]) + return OmegaConf.create({k: config.get(k, defaults[k]) for k in overwriting_opts}) + + shared_opts = gather_shared_opts() + fix_random_seed(shared_opts.seed) + + configs = { + name: c + for name, c in config.items() + if isinstance(c, DictConfig) and name not in ("sampler_weights",) # exclude dict opts + } + for k, v in shared_opts.items(): + for config in configs.values(): + config[k] = v + + source_samplers, source_use_iterable_dataset = {}, [] + for name, config in configs.items(): + try: + expanded_config = make_structured_with_schema_warnings(config) + s, t = get_lhotse_sampler_from_config( + config=expanded_config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer + ) + except IncompleteConfigError as e: + raise IncompleteConfigError( + f"Cannot create a sampler for one of the sub-configs in a multi_config setup. The problematic config is under key={name} and has the following contents: {config}" + ) from e + source_samplers[name] = s source_use_iterable_dataset.append(t) - assert all( - st == source_use_iterable_dataset[0] for st in source_use_iterable_dataset[1:] - ), "When using multiple input_cfg sources ensure they are all tarred or non-tarred (can't mix)." + assert all(st == source_use_iterable_dataset[0] for st in source_use_iterable_dataset[1:]), ( + "When using multiple input_cfg sources ensure they are all tarred or non-tarred (can't mix). " + "You can provide force_iterable_dataset=True to each namespace to fix." + ) use_iterable_dataset = all(source_use_iterable_dataset) - if main_config.sampler_fusion == "zip": - sampler = ZipSampler(*source_samplers) - elif main_config.sampler_fusion == "round_robin": - sampler = RoundRobinSampler(*source_samplers) - elif main_config.sampler_fusion == "randomized_round_robin": - sampler = RoundRobinSampler( - *source_samplers, - randomize=True if main_config.sampler_weights is None else main_config.sampler_weights, - seed=seed, - ) - elif main_config.sampler_fusion == "mux": - raise RuntimeError( - "In order to use a sampler_fusion strategy 'mux', " - "create your dataloader using 'get_lhotse_dataloader_from_config' instead." - ) - else: - raise RuntimeError(f"Unsupported sampler fusion strategy: {main_config.sampler_fusion}") + match shared_opts.sampler_fusion: + case "zip": + sampler = ZipSampler(*source_samplers.values()) + case "round_robin": + sampler = RoundRobinSampler(*source_samplers.values()) + case "randomized_round_robin": + _samplers, _weights = [], [] + for key in source_samplers.keys(): + _samplers.append(source_samplers[key]) + if shared_opts.sampler_weights is not None: + _weights.append(shared_opts.sampler_weights[key]) + sampler = RoundRobinSampler( + *_samplers, + randomize=_weights if len(_weights) > 0 else True, + seed=shared_opts.seed, + ) + case unknown_value: + raise RuntimeError(f"Unsupported sampler fusion strategy: {unknown_value}") # 4. Creating dataloader. if use_iterable_dataset: @@ -363,8 +407,8 @@ def get_lhotse_dataloader_from_multi_config( # This together with infinite datasets removes the need to split data across nodes/workers. dloader_kwargs = dict( dataset=IterableDatasetWrapper(dataset=dataset, sampler=sampler), - worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size, seed=seed), - persistent_workers=main_config.num_workers > 0, # helps Lhotse Shar maintain shuffling state + worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size, seed=shared_opts.seed), + persistent_workers=shared_opts.num_workers > 0, # helps Lhotse Shar maintain shuffling state ) else: # For non-tarred data, the sampler resides in the training loop process and @@ -374,8 +418,8 @@ def get_lhotse_dataloader_from_multi_config( dloader = torch.utils.data.DataLoader( **dloader_kwargs, batch_size=None, - num_workers=main_config.num_workers, - pin_memory=main_config.pin_memory, + num_workers=shared_opts.num_workers, + pin_memory=shared_opts.pin_memory, ) return dloader diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py index 9eba1cee0c69..5206b9b1dec0 100644 --- a/nemo/collections/common/data/lhotse/sampling.py +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -226,9 +226,10 @@ def __init__(self, t_min: float | None, t_max: float | None, measure_total_lengt self.t_min = ifnone(t_min, -1) self.t_max = ifnone(t_max, float("inf")) self.measure_total_length = measure_total_length + self.enabled = self.t_min > 0 or self.t_max < float("inf") def __call__(self, example) -> bool: - if isinstance(example, Cut): + if not self.enabled or isinstance(example, Cut): return True # does not apply to Cuts assert isinstance(example, Formattable), ( f"TokenCountFilter can only be applied to data examples that derive Formattable class. " diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index 93bacb426adf..b5eb1017f1e2 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -21,14 +21,14 @@ import numpy as np import pytest import torch -from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording, SupervisionSegment, compute_num_samples +from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording, compute_num_samples from lhotse.audio import AudioLoadingError from lhotse.cut import Cut, MixedCut from lhotse.dataset import RoundRobinSampler, ZipSampler from lhotse.testing.dummies import dummy_recording +from lhotse.testing.random import deterministic_rng from omegaconf import OmegaConf -from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample, TextExample from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model @@ -1456,10 +1456,12 @@ def test_multimodal_text_audio_dataloading_zip_strategy( config = OmegaConf.create( { "multi_config": True, + "sampler_fusion": "zip", # <---- !!! this option is being tested here !!! + "seed": 0, + "shard_seed": 0, + "shuffle": True, + "num_workers": 0, "audio": { - "sampler_fusion": "zip", # <---- !!! this option is being tested here !!! - "seed": 0, - "shard_seed": 0, "input_cfg": [ { "type": "nemo_tarred", @@ -1470,8 +1472,6 @@ def test_multimodal_text_audio_dataloading_zip_strategy( }, }, ], - "shuffle": True, - "num_workers": 0, "prompt_format": "plain", "use_multimodal_sampling": True, "batch_tokens": BT, @@ -1499,8 +1499,6 @@ def test_multimodal_text_audio_dataloading_zip_strategy( }, }, ], - "shuffle": True, - "num_workers": 0, "use_multimodal_sampling": True, "prompt_format": "plain", "batch_tokens": 64, @@ -1586,10 +1584,12 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( config = OmegaConf.create( { "multi_config": True, + "sampler_fusion": "round_robin", # <---- !!! this option is being tested here !!! + "seed": 0, + "shard_seed": 0, + "shuffle": True, + "num_workers": 0, "audio": { - "sampler_fusion": "round_robin", # <---- !!! this option is being tested here !!! - "seed": 0, - "shard_seed": 0, "input_cfg": [ { "type": "nemo_tarred", @@ -1600,8 +1600,6 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( }, }, ], - "shuffle": True, - "num_workers": 0, "use_multimodal_sampling": True, "prompt_format": "plain", "batch_tokens": BT, @@ -1629,9 +1627,7 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( }, }, ], - "shuffle": True, "prompt_format": "plain", - "num_workers": 0, "use_multimodal_sampling": True, "batch_tokens": BT, # How to set token equivalent duration in actual training? @@ -1688,6 +1684,152 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( assert torch.is_tensor(ex.mask) +def test_multimodal_text_audio_dataloading_randomized_round_robin_strategy( + deterministic_rng, + txt_pair_paths_shards: tuple[str, str], + nemo_tarred_manifest_path_multi: tuple[str, str], + en_es_tokenizer: SentencePieceTokenizer, + questions_path: str, +): + en_paths, es_paths = txt_pair_paths_shards + manifest_filepath, tarred_audio_filepaths = nemo_tarred_manifest_path_multi + QF, BT = 50, 64 + config = OmegaConf.create( + { + "multi_config": True, + "sampler_fusion": "randomized_round_robin", # <---- !!! this option is being tested here !!! + "sampler_weights": { + "audio": 0.5, + "text": 0.5, + }, + "seed": 0, + "shard_seed": 0, + "shuffle": True, + "num_workers": 0, + "audio": { + "input_cfg": [ + { + "type": "nemo_tarred", + "manifest_filepath": manifest_filepath, + "tarred_audio_filepaths": tarred_audio_filepaths, + "tags": { + "modality": "audio", + }, + }, + ], + "use_multimodal_sampling": True, + "prompt_format": "plain", + "batch_tokens": BT, + # How to set token equivalent duration in actual training? + # assuming fbank frames: 0.01 is the base due to frame shift; + # + subsampling x8 gives us 0.08 + # assuming discrete audio tokens, with frame rate 50Hz, + # we'd get 0.02 + # in this test we'll just use 0.1 for simplicity + "token_equivalent_duration": 0.1, + "quadratic_factor": QF, + }, + "text": { + "input_cfg": [ + { + "type": "txt_pair", + "source_paths": en_paths, + "target_paths": es_paths, + "source_language": "en", + "target_language": "es", + "questions_path": questions_path, + "questions_language": "en", + "tags": { + "modality": "text", + }, + }, + ], + "prompt_format": "plain", + "use_multimodal_sampling": True, + "batch_tokens": BT, + # How to set token equivalent duration in actual training? + # assuming fbank frames: 0.01 is the base due to frame shift; + # + subsampling x8 gives us 0.08 + # assuming discrete audio tokens, with frame rate 50Hz, + # we'd get 0.02 + # in this test we'll just use 0.1 for simplicity + "token_equivalent_duration": 0.1, + "quadratic_factor": QF, + }, + } + ) + + dl = get_lhotse_dataloader_from_config( + config=config, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=en_es_tokenizer, + ) + + assert isinstance(dl.dataset.sampler, RoundRobinSampler) + + # Note: we use islice here because the dataloader will be infinite. + batches = [batch for batch in islice(dl, 2)] + + # Batch 0 is audio-only + b = batches[0] + assert isinstance(b, lhotse.CutSet) + assert len(b) + assert all(isinstance(ex, Cut) for ex in b) + # Batch tokens is not exceeded after applying the quadratic factor correction + assert sum(ex.num_tokens**2 / QF for ex in b) <= BT + for ex in b: + assert ex.modality == "audio" + assert isinstance(ex.load_audio(), np.ndarray) + assert isinstance(ex.supervisions[0].text, str) + + # Batch 1 is text-only + b = batches[1] + assert isinstance(b, lhotse.CutSet) + assert len(b) + assert all(isinstance(ex, SourceTargetTextExample) for ex in b) + # Batch tokens is not exceeded after applying the quadratic factor correction + assert sum(ex.num_tokens**2 / QF for ex in b) <= BT + for ex in b: + assert ex.modality == "text" + assert ex.source.language == "en" + assert ex.target.language == "es" + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) + + +def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): + config = OmegaConf.create( + { + "cuts_path": str(cutset_path), + "noise_path": str(nemo_manifest_path), + "noise_mix_prob": 1.0, + "noise_snr": [-5.0, 5.0], + "batch_size": 2, + "seed": 0, + "shard_seed": 0, + } + ) + dl = get_lhotse_dataloader_from_config( + config=config, + global_rank=0, + world_size=1, + dataset=Identity(), + ) + batch = next(iter(dl)) + assert isinstance(batch, CutSet) + assert len(batch) == 2 + cut = batch[0] + assert isinstance(cut, MixedCut) + assert -5.0 < cut.tracks[1].snr < 5.0 + cut = batch[1] + assert isinstance(cut, MixedCut) + assert -5.0 < cut.tracks[1].snr < 5.0 + + def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): config = OmegaConf.create( { From be47226763d27eb45a53d63e6cec0e4625085e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 26 Nov 2024 11:59:56 -0800 Subject: [PATCH 60/63] fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/multimodal/speech_llm/data/build_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py index 65812f03cd0e..8d64632210a4 100644 --- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -111,7 +111,7 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict # for eval, we need to create separate dataset so as to report splitted numbers else: dls = [] - if hasattr(data_cfg, 'manifest_filepath'): + if data_cfg.get('manifest_filepath') is not None: manifest_filepath = data_cfg.manifest_filepath for cur_manifest_filepath in manifest_filepath: conf = copy.deepcopy(data_cfg) From bc87935de8a567cec87f909d7a84326347cf9274 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 26 Nov 2024 15:11:14 -0500 Subject: [PATCH 61/63] Update default configs to the modified config schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- ...dio_gpt_config_cross_llama_lhotse_multi.yaml | 17 ++++++++++------- .../salm/modular_audio_t5_multi_config.yaml | 17 ++++++++++------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml index 14b471448b9e..12b568f55f45 100644 --- a/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml +++ b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml @@ -240,20 +240,25 @@ model: end_string: "[EOG]" train_ds: use_lhotse: true + seed: 0 + shard_seed: "trng" + num_workers: 4 + shuffle: true + multi_config: true + sampler_fusion: randomized_round_robin + sampler_weights: + audio: 0.5 + text: 0.5 + audio: input_cfg: ??? - sampler_fusion: round_robin - seed: 0 - shard_seed: "trng" batch_size: null batch_duration: 360 quadratic_factor: 15 use_bucketing: true num_buckets: 30 bucket_buffer_size: 20000 - num_workers: 4 - shuffle: true prompt_format: llama2 text: input_cfg: ??? @@ -263,8 +268,6 @@ model: use_bucketing: true num_buckets: 30 bucket_buffer_size: 20000 - num_workers: 4 - shuffle: true prompt_format: llama2 global_batch_size: ${model.global_batch_size} diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml index 09f9987eee8f..857c2f2a1c8a 100644 --- a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml @@ -220,21 +220,26 @@ model: data: train_ds: use_lhotse: true + seed: 0 + shard_seed: "trng" + num_workers: 4 + shuffle: true + multi_config: true + sampler_fusion: randomized_round_robin + sampler_weights: + audio: 0.5 + text: 0.5 + audio: input_cfg: ??? - sampler_fusion: round_robin prompt_format: t5nmt - seed: 0 - shard_seed: "trng" batch_size: null batch_duration: 360 quadratic_factor: 15 use_bucketing: true num_buckets: 30 bucket_buffer_size: 20000 - num_workers: 4 - shuffle: true text: input_cfg: ??? prompt_format: t5nmt @@ -244,8 +249,6 @@ model: use_bucketing: true num_buckets: 30 bucket_buffer_size: 20000 - num_workers: 4 - shuffle: true global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} From 6a02e984bda7c1c29259b47ccf161a43924ade83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 26 Nov 2024 15:18:07 -0500 Subject: [PATCH 62/63] Fix omegaconf use issue MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- .../common/data/lhotse/dataloader.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index b68b3b3c2599..7ad5eb3114a6 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -19,6 +19,7 @@ from typing import Any, Optional, Sequence import numpy as np +import omegaconf import torch from lhotse import CutSet, RecordingSet from lhotse.cut import Cut @@ -227,7 +228,11 @@ def get_lhotse_dataloader_from_config( if config.get("multi_config", False): return get_lhotse_dataloader_from_multi_config( - config=config, global_rank=global_rank, world_size=world_size, dataset=dataset, tokenizer=tokenizer + top_level_config=config, + global_rank=global_rank, + world_size=world_size, + dataset=dataset, + tokenizer=tokenizer, ) else: return get_lhotse_dataloader_from_single_config( @@ -305,7 +310,7 @@ def get_lhotse_dataloader_from_single_config( def get_lhotse_dataloader_from_multi_config( - config: DictConfig, + top_level_config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, @@ -328,7 +333,7 @@ def gather_shared_opts(): the ones present in sub-configs. """ assert all( - k in config for k in ["seed", "shard_seed", "shuffle"] + k in top_level_config for k in ["seed", "shard_seed", "shuffle"] ), "In a multi-config setting (multi_config=True), the top-level namespace (typically train_ds) must define at least 'seed', 'shard_seed', and 'shuffle' keys that will be shared by all sub-configs." overwriting_opts = [ "seed", @@ -343,25 +348,24 @@ def gather_shared_opts(): "force_finite", ] defaults = OmegaConf.structured(LhotseDataLoadingConfig) - config["seed"] = resolve_seed(config["seed"]) - return OmegaConf.create({k: config.get(k, defaults[k]) for k in overwriting_opts}) + top_level_config["seed"] = resolve_seed(top_level_config["seed"]) + return OmegaConf.create({k: top_level_config.get(k, defaults[k]) for k in overwriting_opts}) shared_opts = gather_shared_opts() fix_random_seed(shared_opts.seed) configs = { name: c - for name, c in config.items() + for name, c in top_level_config.items() if isinstance(c, DictConfig) and name not in ("sampler_weights",) # exclude dict opts } - for k, v in shared_opts.items(): - for config in configs.values(): - config[k] = v source_samplers, source_use_iterable_dataset = {}, [] for name, config in configs.items(): try: expanded_config = make_structured_with_schema_warnings(config) + for k, v in shared_opts.items(): + expanded_config[k] = v s, t = get_lhotse_sampler_from_config( config=expanded_config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer ) From 2d1243af057867f273321dfd32b37774439f82c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 26 Nov 2024 15:34:37 -0500 Subject: [PATCH 63/63] Update the docs to the modified multi config format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- docs/source/asr/datasets.rst | 9 ++++---- .../Multimodal Lhotse Dataloading.ipynb | 21 +++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/source/asr/datasets.rst b/docs/source/asr/datasets.rst index b8fc1cf56441..5214ef31f673 100644 --- a/docs/source/asr/datasets.rst +++ b/docs/source/asr/datasets.rst @@ -964,11 +964,11 @@ Example. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so t ... train_ds: multi_config: True, - audio: - sampler_fusion: zip - shuffle: true - num_workers: 4 + sampler_fusion: zip + shuffle: true + num_workers: 4 + audio: prompt_format: t5nmt use_bucketing: true min_duration: 0.5 @@ -983,6 +983,7 @@ Example. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so t weight: 0.5 tags: context: "Translate the following to English" + text: prompt_format: t5nmt use_multimodal_sampling: true diff --git a/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb b/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb index 79104f21a3ba..b9ddf350cdca 100644 --- a/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb +++ b/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb @@ -768,6 +768,16 @@ " # To enable this behaviour, set multi_config to True.\n", " \"multi_config\": True,\n", " \n", + " # The following fields are shared by all groups.\n", + " # sampler_fusion key determines how to yield batches from different samplers:\n", + " # * \"round_robin\" will just yield one type at a time\n", + " # * \"zip\" will sample a batch for each type and concatenate them, yielding a larger multimodal batch\n", + " # * \"randomized_round_robin\" expects an extra \"sampler_weights\" option which will define sampling probs for each group.:\n", + " \"sampler_fusion\": \"round_robin\",\n", + " \"shuffle\": True,\n", + " \"num_workers\": 0,\n", + " \"seed\": 0,\n", + " \"shard_seed\": \"trng\",\n", " \n", " \"asr\": {\n", " \"input_cfg\": [\n", @@ -787,17 +797,6 @@ " \"concurrent_bucketing\": False,\n", " \"bucket_buffer_size\": 50,\n", " \"shuffle_buffer_size\": 50,\n", - "\n", - " # The first group defines a number of fields that will be later shared by all groups.\n", - " # sampler_fusion key determines how to yield batches from different samplers:\n", - " # * \"round_robin\" will just yield one type at a time\n", - " # * \"zip\" will sample a batch for each type and concatenate them, yielding a larger multimodal batch\n", - " # * \"randomized_round_robin\" expects an extra \"sampler_weights\" option which will define sampling probs for each group.:\n", - " \"sampler_fusion\": \"round_robin\",\n", - " \"shuffle\": True,\n", - " \"num_workers\": 0,\n", - " \"seed\": 0,\n", - " \"shard_seed\": \"trng\",\n", " },\n", "\n", " \"nmt\": {\n",