Merged
67 commits
38651d0
Preliminary support for oomptimizer
pzelasko Aug 16, 2024
a1754be
OOMptimizer for SpeechLLM
pzelasko Aug 19, 2024
97a543e
Initial version of estimate token bins script
pzelasko Aug 19, 2024
c6f0b3d
Initial support for multimodal 2d bucketing
pzelasko Aug 19, 2024
7b52d5b
Extend to text-to-text oomptimizer
pzelasko Aug 19, 2024
f63a110
Preliminary support for Llama2 prompt format in ast+mt
pzelasko Aug 19, 2024
ef96459
Support for 1D estimate token bins
pzelasko Aug 19, 2024
2cae09b
Support for 1D estimate token bins
pzelasko Aug 19, 2024
bdec618
Fix
pzelasko Aug 19, 2024
a7ce8b6
Fix
pzelasko Aug 19, 2024
b3ed44c
Minor tweaks
pzelasko Aug 20, 2024
f7809d6
Add min/max tokens filter
pzelasko Aug 20, 2024
2b26cb0
Change to bisect_left for bucket idx selection
pzelasko Aug 22, 2024
9589023
Add reconfigure_num_microbatches_calculator at the start of train epo…
pzelasko Aug 22, 2024
4f6d4fa
Update lhotse multi-sampler config and make validation datasets finite
pzelasko Aug 22, 2024
049bad5
Initial implementation of text+audio training for T5 modular models
pzelasko Aug 22, 2024
8ca73d2
megatron t5 nmt prompt formatter
pzelasko Aug 26, 2024
b26b5dd
Fixes for MT+AST T5 oomptimizer and training
pzelasko Aug 26, 2024
850e494
configs, fixes, token-per-token filtering
pzelasko Sep 6, 2024
ffd32b1
Support text modality in predict_step
pzelasko Sep 6, 2024
024701f
Support text data in val/test dl
pzelasko Sep 6, 2024
f574e70
fix
pzelasko Sep 6, 2024
2e2b396
fix
pzelasko Sep 6, 2024
dfdac5e
fix
pzelasko Sep 6, 2024
81bd732
fix
pzelasko Sep 6, 2024
00edb6d
fix
pzelasko Sep 6, 2024
2eb6331
fix
pzelasko Sep 6, 2024
cbaed3c
fix
pzelasko Sep 6, 2024
e8ec5a4
fix
pzelasko Sep 6, 2024
8c597b5
fix infinite
pzelasko Sep 7, 2024
14a1896
prompt format fixes
pzelasko Sep 7, 2024
5c382cd
Fixes in audio supervision
pzelasko Sep 10, 2024
6e276ca
remove superficial padding
pzelasko Sep 10, 2024
2fae9f9
test config and prompt context fetching fixes
pzelasko Sep 10, 2024
34f8526
support text-only decoding for salm/bestow
pzelasko Sep 10, 2024
878fef6
fix tests
pzelasko Oct 4, 2024
949f1ef
Add unit tests for EMMETT / refactor prompt_format_fn
pzelasko Oct 4, 2024
d91348e
make t5nmt prompt formatter auto discoverable
pzelasko Oct 4, 2024
39543a9
include token count / tpt filtering in estimate_token_bins
pzelasko Oct 4, 2024
b684750
fix max token filter
pzelasko Oct 4, 2024
6064bb4
some fixes
pzelasko Oct 9, 2024
92c81bb
custom mixin for text adapters
pzelasko Oct 15, 2024
68e27db
Warmup in oomptimizer-speechlm
pzelasko Oct 16, 2024
0c33146
Move oomptimizer-speechllm to separate directory
pzelasko Oct 16, 2024
6667e03
Merge branch 'speechllm-develop-oomptimizer' of https://github.com/nv…
pzelasko Oct 16, 2024
c3ea064
Initial cleanup
pzelasko Oct 16, 2024
2a16008
Refactoring of prompt format fn and length measurement and filtering …
pzelasko Oct 16, 2024
e2245c7
Refactor sampler constraints / filters into sampling.py
pzelasko Oct 17, 2024
27d0386
Tests and support for sampler length measurement of multimodal conver…
pzelasko Oct 17, 2024
30bda10
Update estimate_token_bins.py
pzelasko Oct 17, 2024
968238c
Move estimate_token_bins.py to speech_llm scripts
pzelasko Oct 17, 2024
f7d7453
Minor tweaks
pzelasko Oct 17, 2024
2519249
Fixes for SpeechLLM dataset
pzelasko Oct 17, 2024
69bb7e1
Apply isort and black reformatting
pzelasko Oct 17, 2024
bc7dcde
Add missing emmett tests
pzelasko Oct 17, 2024
c5a02c2
Merge branch 'speechllm-develop' into speechllm-develop-oomptimizer
pzelasko Oct 18, 2024
fb29842
Add tutorial about multimodal lhotse dataloading
pzelasko Oct 18, 2024
e37ffbe
Updated documentation for multimodal dataloading
pzelasko Oct 23, 2024
c6ffb40
Prompt Formatter tutorial
pzelasko Oct 23, 2024
5cbca90
Review comments
pzelasko Oct 29, 2024
4a465e3
Merge branch 'speechllm-develop' into speechllm-develop-oomptimizer
pzelasko Nov 1, 2024
eca816f
Fixes for sampling filters None values
pzelasko Nov 20, 2024
a46ad65
Changes requested by Steve: moving some args to main config namespace…
pzelasko Nov 26, 2024
be47226
fix
pzelasko Nov 26, 2024
bc87935
Update default configs to the modified config schema
pzelasko Nov 26, 2024
6a02e98
Fix omegaconf use issue
pzelasko Nov 26, 2024
2d1243a
Update the docs to the modified multi config format
pzelasko Nov 26, 2024
279 changes: 246 additions & 33 deletions docs/source/asr/datasets.rst
@@ -744,53 +744,266 @@ The final weight is the product of outer and inner weight:
source_lang: pl
target_lang: en

Configuring multimodal dataloading
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Our configuration format supports specifying data sources from other modalities than just audio.
At this time, this support covers the audio and text modalities. We provide the following parser types:

**Raw text files.** Simple text files where each line is an individual text example. This can represent standard language modeling data.
This parser is registered under ``type: txt``.

Data format examples::

    # file: document_0.txt
    This is a language modeling example.
    Wall Street is expecting major news tomorrow.

    # file: document_1.txt
    Invisible bats have stormed the city.
    What an incredible event!

Dataloading configuration example::

    input_cfg:
      - type: txt
        paths: /path/to/document_{0..1}.txt
        language: en  # optional

Python object example::

    from nemo.collections.common.data.lhotse.text_adapters import TextExample

    example = TextExample(
        text="This is a language modeling example.",
        language="en",  # optional
    )

Python dataloader instantiation example::

    from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config

    dl = get_lhotse_dataloader_from_config(
        {
            "input_cfg": [
                {"type": "txt", "paths": "/path/to/document_{0..1}.txt", "language": "en"},
            ],
            "use_multimodal_dataloading": True,
            "batch_size": 4,
        },
        global_rank=0,
        world_size=1,
        dataset=MyDatasetClass(),  # converts CutSet -> dict[str, Tensor]
        tokenizer=my_tokenizer,
    )

**Raw text file pairs.** Pairs of raw text files with corresponding lines. This can represent machine translation data.
This parser is registered under ``type: txt_pair``.

Data format examples::

    # file: document_en_0.txt
    This is a machine translation example.
    Wall Street is expecting major news tomorrow.

    # file: document_pl_0.txt
    To jest przykład tłumaczenia maszynowego.
    Wall Street spodziewa się jutro ważnych wiadomości.

Dataloading configuration example::

    input_cfg:
      - type: txt_pair
        source_path: /path/to/document_en_{0..N}.txt
        target_path: /path/to/document_pl_{0..N}.txt
        source_language: en  # optional
        target_language: pl  # optional

Python object example::

    from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample, TextExample

    example = SourceTargetTextExample(
        source=TextExample(
            text="This is a machine translation example.",
            language="en",  # optional
        ),
        target=TextExample(
            text="To jest przykład tłumaczenia maszynowego.",
            language="pl",  # optional
        ),
    )

Python dataloader instantiation example::

    from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config

    dl = get_lhotse_dataloader_from_config(
        {
            "input_cfg": [
                {
                    "type": "txt_pair",
                    "source_path": "/path/to/document_en_{0..N}.txt",
                    "target_path": "/path/to/document_pl_{0..N}.txt",
                    "source_language": "en",
                    "target_language": "pl",
                },
            ],
            "use_multimodal_dataloading": True,
            "prompt_format": "t5nmt",
            "batch_size": 4,
        },
        global_rank=0,
        world_size=1,
        dataset=MyDatasetClass(),  # converts CutSet -> dict[str, Tensor]
        tokenizer=my_tokenizer,
    )

**NeMo multimodal conversations.** A JSON-Lines (JSONL) file that defines multi-turn conversations with mixed text and audio turns.
This parser is registered under ``type: multimodal_conversation``.

Data format examples::

    # file: chat_0.jsonl
    {"id": "conv-0", "conversations": [{"from": "user", "value": "speak to me", "type": "text"}, {"from": "assistant", "value": "/path/to/audio.wav", "duration": 17.1, "type": "audio"}]}

Dataloading configuration example::

    token_equivalent_duration: 0.08
    input_cfg:
      - type: multimodal_conversation
        manifest_filepath: /path/to/chat_{0..N}.jsonl
        audio_locator_tag: "[audio]"

Python object example::

    from lhotse import Recording
    from nemo.collections.common.data.lhotse.text_adapters import NeMoMultimodalConversation, TextTurn, AudioTurn

    conversation = NeMoMultimodalConversation(
        id="conv-0",
        turns=[
            TextTurn(value="speak to me", role="user"),
            AudioTurn(cut=Recording.from_file("/path/to/audio.wav").to_cut(), role="assistant", audio_locator_tag="[audio]"),
        ],
        token_equivalent_duration=0.08,  # this value will be auto-inserted by the dataloader
    )

Python dataloader instantiation example::

    from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config

    dl = get_lhotse_dataloader_from_config(
        {
            "input_cfg": [
                {
                    "type": "multimodal_conversation",
                    "manifest_filepath": "/path/to/chat_{0..N}.jsonl",
                    "audio_locator_tag": "[audio]",
                },
            ],
            "use_multimodal_dataloading": True,
            "token_equivalent_duration": 0.08,
            "prompt_format": "llama2",
            "batch_size": 4,
        },
        global_rank=0,
        world_size=1,
        dataset=MyDatasetClass(),  # converts CutSet -> dict[str, Tensor]
        tokenizer=my_tokenizer,
    )

**Dataloading and bucketing of text and multimodal data.** When loading text or multimodal data, pay attention to the following config options (we provide example values for convenience; a combined sketch follows the list):

* ``use_multimodal_sampling: true`` tells Lhotse to switch from measuring audio duration to measuring token counts; required for text.

* ``prompt_format: "prompt-name"`` will apply the specified PromptFormatter during data sampling so that each example's token count is measured accurately.

* ``measure_total_length: true`` customizes length measurement for decoder-only and encoder-decoder models. Decoder-only models consume a linear sequence of context + answer, so we should measure the total length (``true``). On the other hand, encoder-decoder models deal with two different sequence lengths: input (context) sequence length for the encoder, and output (answer) sequence length for the decoder. For such models set this to ``false``.

* ``min_tokens: 1``/``max_tokens: 4096`` filters examples based on their token count (after applying the prompt format).

* ``min_tpt: 0.1``/``max_tpt: 10`` filter examples based on their output-token-per-input-token ratio. For example, ``max_tpt: 10`` means we'll filter out every example that has more than 10 output tokens per input token. This is very useful for removing sequence length outliers that lead to OOM. Use ``estimate_token_bins.py`` to view token count distributions when calibrating this value.

* (multimodal-only) ``token_equivalent_duration: 0.08`` lets us measure audio examples in a number of "tokens". For example, if we're using fbank features with a 0.01 s frame shift and an acoustic model with a subsampling factor of 8, a reasonable setting could be 0.08 (which means every subsampled frame counts as one token). Calibrate this value to fit your needs. Note that it acts as a "balancer" between how much audio data vs text data gets sampled into a mini-batch.
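
A minimal configuration sketch combining the options above, assuming a decoder-only model and the ``llama2`` prompt format (the values are illustrative and should be calibrated for your data)::

    use_multimodal_sampling: true
    prompt_format: llama2
    measure_total_length: true        # decoder-only model: measure context + answer as one sequence
    min_tokens: 1
    max_tokens: 4096
    min_tpt: 0.1
    max_tpt: 10
    token_equivalent_duration: 0.08   # only relevant when audio examples are present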

**Text/multimodal bucketing and OOMptimizer.** Analogous to bucketing for audio data, we provide two scripts to support efficient bucketing:

* ``scripts/speech_llm/estimate_token_bins.py`` which estimates 1D or 2D buckets based on the input config, tokenizer, and prompt format. It also estimates input/output token count distribution and suggested ``max_tpt`` (token-per-token) filtering values.

* (experimental) ``scripts/speech_llm/oomptimizer.py`` which works with SALM/BESTOW GPT/T5 models and estimates the optimal ``bucket_batch_size`` for a given model config and bucket bins value. Given the complexity of Speech LLM models, some configurations may not be supported yet at the time of writing (e.g., model parallelism).

To enable bucketing, set ``batch_size: null`` and use the following options (see the sketch after this list):

* ``use_bucketing: true``

* ``bucket_duration_bins`` - the output of ``estimate_token_bins.py``. If ``null``, it will be estimated at the start of training at the cost of some run time (not recommended).

* (oomptimizer-only) ``bucket_batch_size`` - the output of OOMptimizer.

* (non-oomptimizer-only) ``batch_tokens`` is the maximum total number of tokens we want in a mini-batch. As with ``batch_duration``, this number includes padding tokens, so enabling bucketing is recommended to maximize the ratio of real to padding tokens. Note that it is just a heuristic for determining the optimal batch size of each bucket and may be less efficient than using OOMptimizer.

* (non-oomptimizer-only) ``quadratic_factor`` is a quadratic penalty to equalize the GPU memory usage between buckets of short and long sequence lengths for models with quadratic memory usage. It is only a heuristic and may not be as efficient as using OOMptimizer.
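
A minimal bucketing sketch, assuming bins produced by ``estimate_token_bins.py`` and batch sizes produced by OOMptimizer (the values below are placeholders rather than recommendations)::

    batch_size: null
    use_bucketing: true
    bucket_duration_bins: [[3.16, 10], [3.16, 22], [5.18, 15], ...]  # output of estimate_token_bins.py
    bucket_batch_size: [1024, 768, 832, ...]                         # output of oomptimizer.py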

**Joint dataloading of text/audio/multimodal data.** The key strength of this approach is that we can easily combine audio datasets and text datasets,
and benefit from every other technique described in this doc, such as dynamic data mixing, data weighting, dynamic bucketing, and so on.

This approach is described in the `EMMeTT`_ paper. There's also a notebook tutorial called Multimodal Lhotse Dataloading. We construct a separate sampler (with its own batching settings) for each modality,
and specify how the samplers should be fused together via the option ``sampler_fusion``:

* ``sampler_fusion: "round_robin"`` will iterate a single sampler per step, taking turns. For example: step 0 - audio batch, step 1 - text batch, step 2 - audio batch, etc.

* ``sampler_fusion: "randomized_round_robin"`` is similar, but at each step it chooses a sampler randomly using ``sampler_weights: [w0, w1]`` (the weights can be unnormalized).

* ``sampler_fusion: "zip"`` will draw a mini-batch from each sampler at every step, and merge them into a single ``CutSet``. This approach combines well with multimodal gradient accumulation (run forward+backward for one modality, then the other, then the update step).

.. _EMMeTT: https://arxiv.org/abs/2409.13523

Example. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so that mini-batches have some examples from both datasets:

.. code-block:: yaml

    model:
      ...
      train_ds:
        multi_config: true
        sampler_fusion: zip
        shuffle: true
        num_workers: 4

        audio:
          prompt_format: t5nmt
          use_bucketing: true
          min_duration: 0.5
          max_duration: 30.0
          max_tps: 12.0
          bucket_duration_bins: [[3.16, 10], [3.16, 22], [5.18, 15], ...]
          bucket_batch_size: [1024, 768, 832, ...]
          input_cfg:
            - type: nemo_tarred
              manifest_filepath: /path/to/manifest__OP_0..512_CL_.json
              tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar
              weight: 0.5
              tags:
                context: "Translate the following to English"

        text:
          prompt_format: t5nmt
          use_multimodal_sampling: true
          min_tokens: 1
          max_tokens: 256
          min_tpt: 0.333
          max_tpt: 3.0
          measure_total_length: false
          use_bucketing: true
          bucket_duration_bins: [[10, 4], [10, 26], [15, 10], ...]
          bucket_batch_size: [512, 128, 192, ...]
          input_cfg:
            - type: txt_pair
              source_path: /path/to/en__OP_0..512_CL_.txt
              target_path: /path/to/pl__OP_0..512_CL_.txt
              source_language: en
              target_language: pl
              weight: 0.5
              tags:
                question: "Translate the following to Polish"

.. caution:: We strongly recommend using multiple shards for text files as well, so that different nodes and dataloading workers are able to randomize the order of text iteration. Otherwise, multi-GPU training has a high risk of duplicating text examples.
