diff --git a/docs/source/asr/datasets.rst b/docs/source/asr/datasets.rst index 2c0657d1c6ce..5214ef31f673 100644 --- a/docs/source/asr/datasets.rst +++ b/docs/source/asr/datasets.rst @@ -744,53 +744,266 @@ The final weight is the product of outer and inner weight: source_lang: pl target_lang: en -Configuring multi-modal dataloading +Configuring multimodal dataloading ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Our configuration format supports specifying data sources from other modalities than just audio. -At this time, this support is extended to text-only data. We provide the following parser types: +At this time, this support is extended to audio and text modalities. We provide the following parser types: -* ``txt`` for raw text files, sharded or unsharded. This can represent, for example, language modeling data. -* ``txt_pair`` for pairs of raw text files, sharded or unsharded. This can represent, for example, machine translation data. +**Raw text files.** Simple text files where each line is an individual text example. This can represent standard language modeling data. +This parser is registered under ``type: txt``. -The key strength of this approach is that we can easily combine audio datasets and text datasets, -and benefit from every other technique we described above such as dynamic data mixing, data weighting, dynamic bucketing, and so on. -To enable multimodal dataloading, we provide several configuration options: +Data format examples:: -* ``use_multimodal_sampling`` when set to True, we'll discard the settings of ``batch_duration`` and ``quadratic_duration`` and consider the settings below instead. + # file: document_0.txt + This is a language modeling example. + Wall Street is expecting major news tomorrow. -* ``batch_tokens`` is the maximum number of tokens we want to find inside a mini-batch. Similarly to ``batch_duration``, this number does consider padding tokens too, therefore enabling bucketing is recommended to maximize the ratio of real vs padding tokens. + # file: document_1.txt + Invisible bats have stormed the city. + What an incredible event! -* ``token_equivalent_duration`` is used to be able to measure audio examples in the number of "tokens". For example, if we're using fbank with 0.01s frame shift and an acoustic model that has a subsampling factor of 0.08, then a reasonable setting for this could be 0.08 (which means every subsampled frame counts as one token). Calibrate this value to fit your needs. Note that this value acts as a "balancer" between how much audio data vs text data gets sampled into a mini-batch. +Dataloading configuration example:: -* ``quadratic_factor`` works the same way as ``quadratic_duration``, but is defined in the number of tokens. + input_cfg: + - type: txt + paths: /path/to/document_{0..1}.txt + language: en # optional -Example 3. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so that mini-batches have some examples from both datasets. Provide a custom prompt field for both datasets (to be leveraged by a relevant dataset class): +Python object example:: -.. code-block:: yaml + from nemo.collections.common.data.lhotse.text_adapters import TextExample + + example = TextExample( + text="This is a language modeling example.", + language="en", # optional + ) + +Python dataloader instantiation example:: + + from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config + + dl = get_lhotse_dataloader_from_config({ + "input_cfg": [ + {"type": "txt", "paths": "/path/to/document_{0..1}.txt", "language": "en"}, + ], + "use_multimodal_dataloading": True, + "batch_size": 4, + }, + global_rank=0, + world_size=1, + dataset=MyDatasetClass(), # converts CutSet -> dict[str, Tensor] + tokenizer=my_tokenizer, + ) + +**Raw text file pairs.** Pairs of raw text files with corresponding lines. This can represent machine translation data. +This parser is registered under ``type: txt_pair``. + +Data format examples:: + + # file: document_en_0.txt + This is a machine translation example. + Wall Street is expecting major news tomorrow. + + # file: document_pl_0.txt + To jest przykład tłumaczenia maszynowego. + Wall Street spodziewa się jutro ważnych wiadomości. + +Dataloading configuration example:: - use_multimodal_sampling: true - batch_tokens: 1024 - token_equivalent_duration: 0.08 # 0.01 frame shift * 8 subsampling factor - quadratic_factor: 50 - num_buckets: 30 - use_bucketing: true input_cfg: - - type: nemo_tarred - manifest_filepath: /path/to/manifest__OP_0..512_CL_.json - tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar - weight: 0.5 - tags: - lang: en - prompt: "Given the following recording, transcribe what the person is saying:" - type: txt_pair - source_path: /path/to/en__OP_0..512_CL_.txt - target_path: /path/to/pl__OP_0..512_CL_.txt - source_language: en - target_language: pl - weight: 0.5 - tags: - prompt: "Translate the following text to Polish:" + source_path: /path/to/document_en_{0..N}.txt + target_path: /path/to/document_pl_{0..N}.txt + source_language: en # optional + target_language: pl # optional + +Python object example:: + + from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample + + example = SourceTargetTextExample( + source=TextExample( + text="This is a language modeling example.", + language="en", # optional + ), + target=TextExample( + text="To jest przykład tłumaczenia maszynowego.", + language="pl", # optional + ), + ) + +Python dataloader instantiation example:: + + from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config + + dl = get_lhotse_dataloader_from_config({ + "input_cfg": [ + { + "type": "txt_pair", + "source_path": "/path/to/document_en_{0..N}.txt", + "target_path": "/path/to/document_pl_{0..N}.txt", + "source_language": "en" + "target_language": "en" + }, + ], + "use_multimodal_dataloading": True, + "prompt_format": "t5nmt", + "batch_size": 4, + }, + global_rank=0, + world_size=1, + dataset=MyDatasetClass(), # converts CutSet -> dict[str, Tensor] + tokenizer=my_tokenizer, + ) + +**NeMo multimodal conversations.** A JSON-Lines (JSONL) file that defines multi-turn conversations with mixed text and audio turns. +This parser is registered under ``type: multimodal_conversation``. + +Data format examples:: + + # file: chat_0.jsonl + {"id": "conv-0", "conversations": [{"from": "user", "value": "speak to me", "type": "text"}, {"from": "assistant": "value": "/path/to/audio.wav", "duration": 17.1, "type": "audio"}]} + +Dataloading configuration example:: + + token_equivalent_duration: 0.08 + input_cfg: + - type: multimodal_conversation + manifest_filepath: /path/to/chat_{0..N}.jsonl + audio_locator_tag: [audio] + +Python object example:: + + from lhotse import Recording + from nemo.collections.common.data.lhotse.text_adapters import MultimodalConversation, TextTurn, AudioTurn + + conversation = NeMoMultimodalConversation( + id="conv-0", + turns=[ + TextTurn(value="speak to me", role="user"), + AudioTurn(cut=Recording.from_file("/path/to/audio.wav").to_cut(), role="assistant", audio_locator_tag="[audio]"), + ], + token_equivalent_duration=0.08, # this value will be auto-inserted by the dataloader + ) + +Python dataloader instantiation example:: + + from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config + + dl = get_lhotse_dataloader_from_config({ + "input_cfg": [ + { + "type": "multimodal_conversation", + "manifest_filepath": "/path/to/chat_{0..N}.jsonl", + "audio_locator_tag": "[audio]", + }, + ], + "use_multimodal_dataloading": True, + "token_equivalent_duration": 0.08, + "prompt_format": "llama2", + "batch_size": 4, + }, + global_rank=0, + world_size=1, + dataset=MyDatasetClass(), # converts CutSet -> dict[str, Tensor] + tokenizer=my_tokenizer, + ) + +**Dataloading and bucketing of text and multimodal data.** When dataloading text or multimodal data, pay attention to the following config options (we provide example values for convenience): + +* ``use_multimodal_sampling: true`` tells Lhotse to switch from measuring audio duration to measuring token counts; required for text. + +* ``prompt_format: "prompt-name"`` will apply a specified PromptFormatter during data sampling to accurately reflect its token counts. + +* ``measure_total_length: true`` customizes length measurement for decoder-only and encoder-decoder models. Decoder-only models consume a linear sequence of context + answer, so we should measure the total length (``true``). On the other hand, encoder-decoder models deal with two different sequence lengths: input (context) sequence length for the encoder, and output (answer) sequence length for the decoder. For such models set this to ``false``. + +* ``min_tokens: 1``/``max_tokens: 4096`` filters examples based on their token count (after applying the prompt format). + +* ``min_tpt: 0.1``/``max_tpt: 10`` filter examples based on their output-token-per-input-token-ratio. For example, a ``max_tpt: 10`` means we'll filter every example that has more than 10 output tokens per 1 input token. Very useful for removing sequence length outliers that lead to OOM. Use ``estimate_token_bins.py`` to view token count distributions for calbirating this value. + +* (multimodal-only) ``token_equivalent_duration: 0.08`` is used to be able to measure audio examples in the number of "tokens". For example, if we're using fbank with 0.01s frame shift and an acoustic model that has a subsampling factor of 0.08, then a reasonable setting for this could be 0.08 (which means every subsampled frame counts as one token). Calibrate this value to fit your needs. + +**Text/multimodal bucketing and OOMptimizer.** Analogous to bucketing for audio data, we provide two scripts to support efficient bucketing: + +* ``scripts/speech_llm/estimate_token_bins.py`` which estimates 1D or 2D buckets based on the input config, tokenizer, and prompt format. It also estimates input/output token count distribution and suggested ``max_tpt`` (token-per-token) filtering values. + +* (experimental) ``scripts/speech_llm/oomptimizer.py`` which works with SALM/BESTOW GPT/T5 models and estimates the optimal ``bucket_batch_size`` for a given model config and bucket bins value. Given the complexity of Speech LLM some configurations may not be supported yet at the time of writing (e.g., model parallelism). + +To enable bucketing, set ``batch_size: null`` and use the following options: + +* ``use_bucketing: true`` + +* ``bucket_duration_bins`` - the output of ``estimate_token_bins.py``. If ``null``, it will be estimated at the start of training at the cost of some run time (not recommended). + +* (oomptimizer-only) ``bucket_batch_size`` - the output of OOMptimizer. + +* (non-oomptimizer-only) ``batch_tokens`` is the maximum number of tokens we want to find inside a mini-batch. Similarly to ``batch_duration``, this number does consider padding tokens too, therefore enabling bucketing is recommended to maximize the ratio of real vs padding tokens. Note that it's just a heuristic for determining the optimal batch sizes for different buckets, and may be less efficient than using OOMptimizer. + +* (non-oomptimizer-only) ``quadratic_factor`` is a quadratic penalty to equalize the GPU memory usage between buckets of short and long sequence lengths for models with quadratic memory usage. It is only a heuristic and may not be as efficient as using OOMptimizer. + +**Joint dataloading of text/audio/multimodal data.** The key strength of this approach is that we can easily combine audio datasets and text datasets, +and benefit from every other technique we described in this doc, such as: dynamic data mixing, data weighting, dynamic bucketing, and so on. + +This approach is described in the `EMMeTT`_ paper. There's also a notebook tutorial called Multimodal Lhotse Dataloading. We construct a separate sampler (with its own batching settings) for each modality, +and specify how the samplers should be fused together via the option ``sampler_fusion``: + +* ``sampler_fusion: "round_robin"`` will iterate single sampler per step, taking turns. For example: step 0 - audio batch, step 1 - text batch, step 2 - audio batch, etc. + +* ``sampler_fusion: "randomized_round_robin"`` is similar, but at each chooses a sampler randomly using ``sampler_weights: [w0, w1]`` (weights can be unnormalized). + +* ``sampler_fusion: "zip"`` will draw a mini-batch from each sampler at every step, and merge them into a single ``CutSet``. This approach combines well with multimodal gradient accumulation (run forward+backward for one modality, then the other, then the update step). + +.. _EMMeTT: https://arxiv.org/abs/2409.13523 + +Example. Combine an ASR (audio-text) dataset with an MT (text-only) dataset so that mini-batches have some examples from both datasets: + +.. code-block:: yaml + + model: + ... + train_ds: + multi_config: True, + sampler_fusion: zip + shuffle: true + num_workers: 4 + + audio: + prompt_format: t5nmt + use_bucketing: true + min_duration: 0.5 + max_duration: 30.0 + max_tps: 12.0 + bucket_duration_bins: [[3.16, 10], [3.16, 22], [5.18, 15], ...] + bucket_batch_size: [1024, 768, 832, ...] + input_cfg: + - type: nemo_tarred + manifest_filepath: /path/to/manifest__OP_0..512_CL_.json + tarred_audio_filepath: /path/to/tarred_audio/audio__OP_0..512_CL_.tar + weight: 0.5 + tags: + context: "Translate the following to English" + + text: + prompt_format: t5nmt + use_multimodal_sampling: true + min_tokens: 1 + max_tokens: 256 + min_tpt: 0.333 + max_tpt: 3.0 + measure_total_length: false + use_bucketing: true + bucket_duration_bins: [[10, 4], [10, 26], [15, 10], ...] + bucket_batch_size: [512, 128, 192, ...] + input_cfg: + - type: txt_pair + source_path: /path/to/en__OP_0..512_CL_.txt + target_path: /path/to/pl__OP_0..512_CL_.txt + source_language: en + target_language: pl + weight: 0.5 + tags: + question: "Translate the following to Polish" .. caution:: We strongly recommend to use multiple shards for text files as well so that different nodes and dataloading workers are able to randomize the order of text iteration. Otherwise, multi-GPU training has a high risk of duplication of text examples. diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml similarity index 87% rename from examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml rename to examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml index 52149da6a570..12b568f55f45 100644 --- a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml +++ b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse_multi.yaml @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - # This configuration is similar to modular_audio_gpt_config_cross_llama_lhotse.yaml, # with the difference being in how it performs multimodal sampling. # The changes are in model.data.train_ds section. @@ -25,11 +24,15 @@ # or zip (sample mini-batch from each and combine them). name: megatron_audio_gpt_bestow_lhotse_multi_sampler +# Note: This config has been updated to work with PromptFormatter API. +# If you used an older version that defined a `train_ds.prompt_template` field, +# you should specify the prompt format using `train_ds..prompt_format` now instead. + trainer: devices: 1 accelerator: gpu num_nodes: 1 - precision: 16 + precision: bf16-mixed logger: False # logger provided by exp_manager enable_checkpointing: False use_distributed_sampler: False @@ -237,20 +240,26 @@ model: end_string: "[EOG]" train_ds: use_lhotse: true + seed: 0 + shard_seed: "trng" + num_workers: 4 + shuffle: true + multi_config: true + sampler_fusion: randomized_round_robin + sampler_weights: + audio: 0.5 + text: 0.5 + audio: input_cfg: ??? - sampler_fusion: round_robin - seed: 0 - shard_seed: "trng" batch_size: null batch_duration: 360 quadratic_factor: 15 use_bucketing: true num_buckets: 30 bucket_buffer_size: 20000 - num_workers: 4 - shuffle: true + prompt_format: llama2 text: input_cfg: ??? use_multimodal_sampling: true @@ -259,15 +268,14 @@ model: use_bucketing: true num_buckets: 30 bucket_buffer_size: 20000 - num_workers: 4 - shuffle: true + prompt_format: llama2 global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} max_seq_length: 2048 min_seq_length: 1 - context_key: 'input' - label_key: 'output' + context_key: 'context' + answer_key: 'answer' add_eos: True # add_eos: False end_string: ${model.data.end_string} @@ -276,20 +284,20 @@ model: separate_prompt_and_response_with_newline: False truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: "[INST]\n<>\nPlease answer the following based on the previous speech feature.\n<>\n\n{input}[/INST] {output}" validation_ds: manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + force_finite: true # workaround to allow using input_cfg global_batch_size: ${model.global_batch_size} micro_batch_size: ${model.micro_batch_size} shuffle: False - num_workers: 0 + num_workers: 1 pin_memory: True max_seq_length: 2048 min_seq_length: 1 drop_last: False context_key: ${model.data.train_ds.context_key} - label_key: ${model.data.train_ds.label_key} + answer_key: ${model.data.train_ds.answer_key} add_eos: ${model.data.train_ds.add_eos} end_string: ${model.data.end_string} add_sep: ${model.data.train_ds.add_sep} @@ -299,7 +307,6 @@ model: output_file_path_prefix: null # Prefix of the file to write predictions to. truncation_field: "context" # Options: ['context', 'answer'] index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" tokens_to_generate: 128 # ASR configs sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} @@ -310,31 +317,31 @@ model: average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. num_classes: null - # test_ds: - # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. - # names: null # Names of the corresponding datasets used to log metrics. - # global_batch_size: ${model.global_batch_size} - # micro_batch_size: ${model.micro_batch_size} - # shuffle: False - # num_workers: 4 - # pin_memory: True - # max_seq_length: 2048 - # min_seq_length: 1 - # drop_last: False - # context_key: 'input' - # label_key: 'output' - # add_eos: ${model.data.train_ds.add_eos} - # end_string: ${model.data.end_string} - # add_sep: ${model.data.train_ds.add_sep} - # add_bos: ${model.data.train_ds.add_bos} - # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} - # write_predictions_to_file: False - # output_file_path_prefix: null # Prefix of the file to write predictions to. - # truncation_field: "context" # Options: ['context', 'answer'] - # index_mapping_dir: null # Path to a directory to write index mapping files. - # prompt_template: ${model.data.train_ds.prompt_template} - # # ASR configs - # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + test_ds: + manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + force_finite: true # workaround to allow using input_cfg + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 1 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} # metric: # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml index 62b9030b4708..658485aa6807 100644 --- a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml @@ -104,7 +104,7 @@ model: prompt_template: ${data.train_ds.prompt_template} # don't change, let hydra resolve from saved config tokens_to_generate: 512 log_every_n_steps: 1 - sample_rate: ${data.train_ds.sample_rate} # don't change, let hydra resolve from saved config + sample_rate: 16000 # don't change, let hydra resolve from saved config audio_locator: null # set it to allow multiple audios in a sample, e.g. '|audio|', and use it in the context field of manifest to specify the locations of audios (`audio_filepath` is a list of audios). metric: diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml new file mode 100644 index 000000000000..857c2f2a1c8a --- /dev/null +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_multi_config.yaml @@ -0,0 +1,342 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# This configuration is similar to modular_audio_t5_multi_config.yaml, +# with the difference being in how it performs multimodal sampling. +# The changes are in model.data.train_ds section. +# You'll notice that it defines two sub-sections: audio and text. +# Their names are arbitrary in the sense that you may define more subsections as you like, also with repeated modalities. +# We still set up a single dataloader, but each sub-section produces its own sampler with its own batch size related settings. +# That means each sub-section may decide about its own static/dynamic batch sizes, bucketing, etc. +# These different samplers are later combined into a single sampler using one of three available sampler fusion strategies: +# round_robin (taking turns), randomized_round_robin (at each step select a sampler according to weights), +# or zip (sample mini-batch from each and combine them). +name: megatron_audio_t5_salm_lhotse_multi_sampler + +# Note: This config has been updated to work with PromptFormatter API. +# If you used an older version that defined a `train_ds.prompt_template` field, +# you should specify the prompt format using `train_ds..prompt_format` now instead. + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16-mixed + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 1000000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + limit_train_batches : 1000 + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 1000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +model_target: nemo.collections.multimodal.speech_llm.models.modular_t5_models.ModularizedAudioT5Model + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + virtual_prompt_style: 'no-prompts' # make cls happy + audio_prompt_first: False + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: stt_en_fastconformer_transducer_large + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + load_audio_encoder: True + + global_batch_size: 128 + micro_batch_size: 4 + language_model_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: True + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + # use_am_tokenizer: True + # override_vocab_size: 1024 + + lora_tuning: + kqv_adapter_dim: 128 + kv_adapter_dim: 64 + q_adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + + peft: + peft_scheme: "adapter" # can be either adapter,ia3, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + perception: + target: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule + use_multi_layer_feat: false + + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained AM: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + train_ds: + use_lhotse: true + seed: 0 + shard_seed: "trng" + num_workers: 4 + shuffle: true + + multi_config: true + sampler_fusion: randomized_round_robin + sampler_weights: + audio: 0.5 + text: 0.5 + + audio: + input_cfg: ??? + prompt_format: t5nmt + batch_size: null + batch_duration: 360 + quadratic_factor: 15 + use_bucketing: true + num_buckets: 30 + bucket_buffer_size: 20000 + text: + input_cfg: ??? + prompt_format: t5nmt + use_multimodal_sampling: true + batch_tokens: 8000 + quadratic_factor: 192 + use_bucketing: true + num_buckets: 30 + bucket_buffer_size: 20000 + + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + max_seq_length: 2048 + min_seq_length: 1 + context_key: 'context' + answer_key: 'answer' + add_eos: True + # add_eos: False + add_sep: True + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + + validation_ds: + force_finite: true # workaround to allow using input_cfg + prompt_format: t5nmt + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 1 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 + + log_every_n_steps: 10 + metric: + name: "wer" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + test_ds: + manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + force_finite: true # workaround to allow using input_cfg + prompt_format: t5nmt + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 1 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + # ASR configs + sample_rate: 16000 + + # metric: + # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + # average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + # num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index 89dcc61655e8..51935bbbfdcd 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -19,6 +19,8 @@ from lhotse.dataset import AudioSamples from lhotse.dataset.collation import collate_vectors +from nemo.collections.common.data import apply_prompt_format_fn +from nemo.collections.common.prompts import CanaryPromptFormatter, PromptFormatter from nemo.collections.common.tokenizers import TokenizerSpec @@ -62,28 +64,27 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): def __init__( self, tokenizer: TokenizerSpec, - prompt_format_fn: Callable[ - [CutSet, TokenizerSpec], tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]] - ], + prompt: PromptFormatter, ): super().__init__() self.tokenizer = tokenizer self.load_audio = AudioSamples(fault_tolerant=True) self.padding_value = self.tokenizer.pad - self.prompt_format_fn = prompt_format_fn + self.prompt = prompt def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: audio, audio_lens, cuts = self.load_audio(cuts) # Fast-path: the tokenization and prompt formatting was already done before sampling. - attrs = ("tokenized_prompt", "tokenized_transcript", "tokenized_prompted_transcript") + attrs = ("input_ids", "context_ids", "answer_ids") pre_formatted = all(hasattr(c, a) for c in cuts for a in attrs) if pre_formatted: - prompts_with_answers, prompts, answers = zip( - *((c.tokenized_prompted_transcript, c.tokenized_prompt, c.tokenized_transcript) for c in cuts) - ) + prompts_with_answers, prompts, answers = zip(*((c.input_ids, c.context_ids, c.answer_ids) for c in cuts)) else: - prompts_with_answers, prompts, answers = self.prompt_format_fn(cuts, self.tokenizer) + formatted = [apply_prompt_format_fn(cut, self.prompt) for cut in cuts] + prompts_with_answers = [ex["input_ids"] for ex in formatted] + prompts = [ex["context_ids"] for ex in formatted] + answers = [ex["answer_ids"] for ex in formatted] transcript, transcript_lens = self._collate_tokens(answers) prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers) diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index d2d2213be6e6..454c79ee4e87 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -44,10 +44,10 @@ from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common import tokenizers from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config +from nemo.collections.common.data.prompt_fn import get_prompt_format_fn from nemo.collections.common.metrics import GlobalAverageLossMetric from nemo.collections.common.parts import transformer_weights_init from nemo.collections.common.parts.preprocessing.manifest import get_full_path -from nemo.collections.common.prompts.fn import get_prompt_format_fn from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.core.classes.common import typecheck from nemo.core.neural_types import ( @@ -510,7 +510,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): world_size=self.world_size, dataset=PromptedAudioToTextLhotseDataset( tokenizer=self.tokenizer, - prompt_format_fn=get_prompt_format_fn(self.prompt_format), + prompt=self.prompt, ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/common/data/__init__.py b/nemo/collections/common/data/__init__.py index ecc67ef05ea5..d4b43d2b4edc 100644 --- a/nemo/collections/common/data/__init__.py +++ b/nemo/collections/common/data/__init__.py @@ -13,3 +13,5 @@ # limitations under the License. from nemo.collections.common.data.dataset import CodeSwitchedDataset, ConcatDataset, ConcatMapDataset +from nemo.collections.common.data.lhotse import * +from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn, get_prompt_format_fn diff --git a/nemo/collections/common/data/lhotse/__init__.py b/nemo/collections/common/data/lhotse/__init__.py index 6bbe9e991236..95f0d01db297 100644 --- a/nemo/collections/common/data/lhotse/__init__.py +++ b/nemo/collections/common/data/lhotse/__init__.py @@ -13,4 +13,15 @@ # limitations under the License. from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config -from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config +from nemo.collections.common.data.lhotse.dataloader import ( + LhotseDataLoadingConfig, + get_lhotse_dataloader_from_config, + get_lhotse_sampler_from_config, +) +from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator, LazyNeMoTarredIterator +from nemo.collections.common.data.lhotse.text_adapters import ( + NeMoMultimodalConversation, + NeMoSFTExample, + SourceTargetTextExample, + TextExample, +) diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 24c0ffaf59b7..406aa558bb15 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -17,7 +17,7 @@ from functools import partial from itertools import repeat from pathlib import Path -from typing import Sequence, Tuple, Union +from typing import KeysView, Sequence, Tuple, Union import omegaconf from lhotse import CutSet, Features, Recording @@ -35,46 +35,84 @@ from nemo.collections.common.parts.preprocessing.manifest import get_full_path -def read_cutset_from_config(config: DictConfig) -> Tuple[CutSet, bool]: +def read_cutset_from_config(config: DictConfig | dict) -> Tuple[CutSet, bool]: """ Reads NeMo configuration and creates a CutSet either from Lhotse or NeMo manifests. Returns a tuple of ``CutSet`` and a boolean indicating whether the data is tarred (True) or not (False). """ # First, check if the dataset is specified in the new configuration format and use it if possible. + if not isinstance(config, DictConfig): + config = DictConfig(config) if config.get("input_cfg") is not None: return read_dataset_config(config) # Now, we'll figure out if we should read Lhotse manifest or NeMo manifest. use_nemo_manifest = all(config.get(opt) is None for opt in ("cuts_path", "shar_path")) if use_nemo_manifest: - assert ( - config.get("manifest_filepath") is not None - ), "You must specify either: manifest_filepath, cuts_path, or shar_path" - is_tarred = config.get("tarred_audio_filepaths") is not None + if config.get("manifest_filepath") is None: + raise IncompleteConfigError("You must specify either: manifest_filepath, cuts_path, or shar_path") + cuts, is_tarred = read_nemo_manifest(config) else: - is_tarred = config.get("shar_path") is not None - if use_nemo_manifest: - # Read NeMo manifest -- use the right wrapper depending on tarred/non-tarred. - cuts = read_nemo_manifest(config, is_tarred) - else: - # Read Lhotse manifest (again handle both tarred(shar)/non-tarred). - cuts = read_lhotse_manifest(config, is_tarred) + cuts, is_tarred = read_lhotse_manifest(config) return cuts, is_tarred -KNOWN_DATASET_CONFIG_TYPES = frozenset( - ( - "nemo", - "nemo_tarred", - "lhotse", - "lhotse_shar", - "txt", - "txt_pair", - "nemo_sft_jsonl", - "multimodal_conversation", - "group", - ) -) +class IncompleteConfigError(RuntimeError): + pass + + +KNOWN_DATA_CONFIG_TYPES = {} + + +def get_known_config_data_types() -> KeysView[str]: + """ + Return the names of all registered data type parsers. + + Example: + + >>> get_known_config_data_types() + ["nemo", "nemo_tarred", "lhotse", ...] + """ + return KNOWN_DATA_CONFIG_TYPES.keys() + + +def get_parser_fn(data_type_name: str): + """ + Return the parsing function for a given data type name. + Parsing function reads a dataloading config and returns a tuple + of lhotse ``CutSet`` and boolean indicating whether we should use + iterable dataset (True) or map dataset (False) mechanism ("is tarred"). + """ + return KNOWN_DATA_CONFIG_TYPES[data_type_name] + + +def data_type_parser(name: str | list[str]): + """ + Decorator used to register data type parser functions. + Parsing function reads a dataloading config and returns a tuple + of lhotse ``CutSet`` and boolean indicating whether we should use + iterable dataset (True) or map dataset (False) mechanism ("is tarred"). + + Example: + + >>> @data_type_parser("my_new_format") + ... def my_new_format(config): + ... return CutSet(read_my_format(**config)), True + ... + ... fn = get_parser_fn("my_new_format") + ... cuts, is_tarred = fn({"my_arg_0": ..., "my_arg_1": ..., ...}) + """ + + def _decorator(fn): + global KNOWN_DATA_CONFIG_TYPES + if isinstance(name, str): + KNOWN_DATA_CONFIG_TYPES[name] = fn + else: + for n in name: + KNOWN_DATA_CONFIG_TYPES[n] = fn + return fn + + return _decorator def read_dataset_config(config) -> tuple[CutSet, bool]: @@ -140,13 +178,14 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: tgt_lang: en """ propagate_attrs = { - "shuffle": config.shuffle, - "shard_seed": config.shard_seed, - "text_field": config.text_field, - "lang_field": config.lang_field, - "metadata_only": config.metadata_only, - "force_finite": config.force_finite, - "max_open_streams": config.max_open_streams, + "shuffle": config.get("shuffle", False), + "shard_seed": config.get("shard_seed", "trng"), + "text_field": config.get("text_field", "text"), + "lang_field": config.get("lang_field", "lang"), + "metadata_only": config.get("metadata_only", False), + "force_finite": config.get("force_finite", False), + "max_open_streams": config.get("max_open_streams", None), + "token_equivalent_duration": config.get("token_equivalent_duration", None), } input_cfg = config.input_cfg if isinstance(input_cfg, (str, Path)): @@ -157,99 +196,89 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: def parse_group(grp_cfg: DictConfig, propagate_attrs: dict) -> [CutSet, bool]: - assert grp_cfg.type in KNOWN_DATASET_CONFIG_TYPES, f"Unknown item type in dataset config list: {grp_cfg.type=}" - if grp_cfg.type == "nemo_tarred": - is_tarred = True - cuts = read_nemo_manifest(grp_cfg, is_tarred=is_tarred) - elif grp_cfg.type == "nemo": - is_tarred = False - cuts = read_nemo_manifest(grp_cfg, is_tarred=is_tarred) - elif grp_cfg.type == "lhotse_shar": - is_tarred = True - cuts = read_lhotse_manifest(grp_cfg, is_tarred=is_tarred) - elif grp_cfg.type == "lhotse": - is_tarred = False - cuts = read_lhotse_manifest(grp_cfg, is_tarred=is_tarred) - # Note: "txt" and "txt_pair" have "is_tarred" set to True. - # The main reason is to enable combination of tarred audio and text dataloading, - # since we don't allow combination of tarred and non-tarred datasets. - # We choose to treat text as-if it was tarred, which also tends to be more + assert grp_cfg.type in get_known_config_data_types(), f"Unknown item type in dataset config list: {grp_cfg.type=}" + + # Note: Text data types will return is_tarred=True. + # We choose to treat text as-if it was tarred, which tends to be more # efficient as it moves the text file iteration into dataloading subprocess. - elif grp_cfg.type == "txt": - is_tarred = True - cuts = read_txt_paths(grp_cfg) - elif grp_cfg.type == "txt_pair": - is_tarred = True - cuts = read_txt_pair_paths(grp_cfg) - elif grp_cfg.type == "nemo_sft_jsonl": - is_tarred = True - cuts = read_nemo_sft_jsonl(grp_cfg) - elif grp_cfg.type == "multimodal_conversation": - is_tarred = True - cuts = read_multimodal_conversation_jsonl(grp_cfg) - elif grp_cfg.type == "group": + if grp_cfg.type != "group": + parser_fn = get_parser_fn(grp_cfg.type) + cuts, is_tarred = parser_fn(grp_cfg) + else: cuts, is_tarred = parse_and_combine_datasets( grp_cfg.input_cfg, propagate_attrs=propagate_attrs, ) - else: - raise ValueError(f"Unrecognized group: {grp_cfg.type}") # Attach extra tags to every utterance dynamically, if provided. if (extra_tags := grp_cfg.get("tags")) is not None: cuts = cuts.map(partial(attach_tags, tags=extra_tags), apply_fn=None) return cuts, is_tarred -def read_txt_paths(config: DictConfig) -> CutSet: - return CutSet( +@data_type_parser("txt") +def read_txt_paths(config: DictConfig) -> tuple[CutSet, bool]: + cuts = CutSet( LhotseTextAdapter( paths=config.paths, language=config.language, shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) - ).repeat() + ) + if not config.get("force_finite", False): + cuts = cuts.repeat() + return cuts, True -def read_txt_pair_paths(config: DictConfig) -> CutSet: - return CutSet( +@data_type_parser("txt_pair") +def read_txt_pair_paths(config: DictConfig) -> tuple[CutSet, bool]: + cuts = CutSet( LhotseTextPairAdapter( source_paths=config.source_paths, target_paths=config.target_paths, - source_language=config.source_language, - target_language=config.target_language, - questions_path=config.questions_path, - questions_language=config.questions_language, + source_language=config.get("source_language"), + target_language=config.get("target_language"), + questions_path=config.get("questions_path"), + questions_language=config.get("questions_language"), shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) - ).repeat() + ) + if not config.get("force_finite", False): + cuts = cuts.repeat() + return cuts, True -def read_nemo_sft_jsonl(config: DictConfig) -> CutSet: - return CutSet( +@data_type_parser("nemo_sft_jsonl") +def read_nemo_sft_jsonl(config: DictConfig) -> tuple[CutSet, bool]: + cuts = CutSet( NeMoSFTJsonlAdapter( paths=config.paths, - language=config.language, + language=config.get("language"), shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) - ).repeat() + ) + if not config.get("force_finite", False): + cuts = cuts.repeat() + return cuts, True -def read_multimodal_conversation_jsonl(config: DictConfig) -> CutSet: +@data_type_parser("multimodal_conversation") +def read_multimodal_conversation_jsonl(config: DictConfig) -> tuple[CutSet, bool]: cuts = CutSet( NeMoMultimodalConversationJsonlAdapter( manifest_filepath=config.manifest_filepath, tarred_audio_filepaths=config.get("tarred_audio_filepaths"), audio_locator_tag=config.audio_locator_tag, + token_equivalent_duration=config.get("token_equivalent_duration"), shuffle_shards=config.shuffle, shard_seed=config.shard_seed, ) ) if not config.get("force_finite", False): cuts = cuts.repeat() - return cuts + return cuts, True def attach_tags(cut, tags: dict): @@ -258,6 +287,7 @@ def attach_tags(cut, tags: dict): return cut +@data_type_parser("group") def parse_and_combine_datasets( config_list: Union[list[DictConfig], ListConfig], propagate_attrs: dict ) -> tuple[CutSet, bool]: @@ -303,7 +333,9 @@ def parse_and_combine_datasets( return cuts, tarred_status[0] -def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: +@data_type_parser(["lhotse", "lhotse_shar"]) +def read_lhotse_manifest(config) -> tuple[CutSet, bool]: + is_tarred = config.get("shar_path") is not None if is_tarred: # Lhotse Shar is the equivalent of NeMo's native "tarred" dataset. # The combination of shuffle_shards, and repeat causes this to @@ -372,7 +404,7 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: # Regular Lhotse manifest points to individual audio files (like native NeMo manifest). path = config.cuts_path cuts = CutSet.from_file(path).map(partial(resolve_relative_paths, manifest_path=path)) - return cuts + return cuts, is_tarred def _resolve_shar_inputs(path: str | Path, only_metadata: bool) -> dict: @@ -430,7 +462,8 @@ def resolve_array(value): return cut -def read_nemo_manifest(config, is_tarred: bool) -> CutSet: +@data_type_parser(["nemo", "nemo_tarred"]) +def read_nemo_manifest(config) -> tuple[CutSet, bool]: common_kwargs = { "text_field": config.text_field, "lang_field": config.lang_field, @@ -447,6 +480,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: notar_kwargs = {"metadata_only": config.metadata_only} metadata_only = config.metadata_only force_finite = config.force_finite + is_tarred = config.get("tarred_audio_filepaths") is not None if isinstance(config.manifest_filepath, (str, Path)): logging.info(f"Initializing Lhotse CutSet from a single NeMo manifest (tarred): '{config.manifest_filepath}'") if is_tarred and not metadata_only: @@ -526,7 +560,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: seed=config.shard_seed, force_finite=force_finite or metadata_only, ) - return cuts + return cuts, is_tarred def mux( diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index de5bd83263a2..7ad5eb3114a6 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import bisect import os import random import warnings from dataclasses import dataclass from functools import partial -from typing import Any, List, Optional, Sequence, TypeVar, Union +from typing import Any, Optional, Sequence import numpy as np +import omegaconf import torch from lhotse import CutSet, RecordingSet from lhotse.cut import Cut @@ -34,20 +34,27 @@ make_worker_init_fn, ) from lhotse.dataset.dataloading import resolve_seed -from lhotse.dataset.sampling.base import CutSampler, SamplingConstraint, TimeConstraint, TokenConstraint -from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint +from lhotse.dataset.sampling.base import CutSampler, TimeConstraint from lhotse.lazy import LazyFlattener from lhotse.utils import fastcopy, fix_random_seed -from omegaconf import DictConfig, ListConfig, OmegaConf - -from nemo.collections.common.data.lhotse.cutset import guess_parse_cutset, read_cutset_from_config -from nemo.collections.common.data.lhotse.text_adapters import ( - NeMoMultimodalConversation, - NeMoSFTExample, - SourceTargetTextExample, - TextExample, +from omegaconf import DictConfig, OmegaConf + +from nemo.collections.common.data.lhotse.cutset import ( + IncompleteConfigError, + guess_parse_cutset, + read_cutset_from_config, +) +from nemo.collections.common.data.lhotse.sampling import ( + DurationFilter, + FixedBucketBatchSizeConstraint2D, + MultimodalFixedBucketBatchSizeConstraint2D, + MultimodalSamplingConstraint, + TokenCountFilter, + TokenPerSecondFilter, + TokenPerTokenFilter, ) -from nemo.collections.common.prompts.fn import get_prompt_format_fn +from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn +from nemo.collections.common.prompts import PromptFormatter from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper from nemo.utils import logging @@ -90,8 +97,12 @@ class LhotseDataLoadingConfig: shard_seed: int | str = "trng" max_open_streams: int | None = None cuda_expandable_segments: bool = True - sampler_fusion: str = "mux" # mux | zip | round_robin | randomized_round_robin - sampler_weights: list[float] | None = None # only applicable to randomized_round_robin + # e. Multi-config related options. + # Setting multi_config=True will scan the config for keys with DictConfig values, + # create a separate sampler for each, and fuse the samplers according to sampler_fusion. + multi_config: bool = False + sampler_fusion: str = "round_robin" # round_robin | randomized_round_robin | zip + sampler_weights: dict[str, float] | None = None # only applicable to randomized_round_robin # 2.1 Multimodal sampling override options pretokenize: bool = True # should we apply tokenizer before data sampling @@ -101,17 +112,28 @@ class LhotseDataLoadingConfig: batch_tokens: int | None = None quadratic_factor: float | None = None + # 2.2 Filters on sequence lengths. + # * Speech input + min_duration: float | None = -1 + max_duration: float | None = float("inf") + min_tps: int = -1 # allowed tokens per second (audio-only) + max_tps: float = float("inf") + # * Text input + min_tokens: int | None = None + max_tokens: int | None = None + # When true, combine context+answer lengths into a total length; otherwise report context length. + # For 2D bucketing it's always false, as we report a tuple of (context_len, answer_len). + measure_total_length: bool = True + min_tpt: int = -1 # allowed tokens per token (text-only) + max_tpt: float = float("inf") + # 3. Supported existing NeMo options. shuffle: bool = False sample_rate: int = 16000 - min_duration: float | None = -1 - max_duration: float | None = float("inf") seed: int | str = 0 num_workers: int = 0 pin_memory: bool = False channel_selector: int | str | None = None - min_tps: int = -1 # allowed tokens per second - max_tps: float = float("inf") # 4. Optional Lhotse data augmentation. # a. On-the-fly noise/audio mixing. @@ -173,7 +195,7 @@ class LhotseDataLoadingConfig: def get_lhotse_dataloader_from_config( - config: DictConfig, + config: dict | DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, @@ -198,9 +220,19 @@ def get_lhotse_dataloader_from_config( If "prompt_format" is additionally provided in the config, we will also apply a prompt formatter. Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work). """ - if config.get("multi_config"): + if not isinstance(config, DictConfig): + config = OmegaConf.create(config) + + # Providing default value because we haven't filled the config defaults yet. + maybe_set_cuda_expandable_segments(enabled=config.get("cuda_expandable_segments", True)) + + if config.get("multi_config", False): return get_lhotse_dataloader_from_multi_config( - configs=config, global_rank=global_rank, world_size=world_size, dataset=dataset, tokenizer=tokenizer + top_level_config=config, + global_rank=global_rank, + world_size=world_size, + dataset=dataset, + tokenizer=tokenizer, ) else: return get_lhotse_dataloader_from_single_config( @@ -240,16 +272,10 @@ def get_lhotse_dataloader_from_single_config( config = make_structured_with_schema_warnings(config) - maybe_set_cuda_expandable_segments(enabled=config.cuda_expandable_segments) - # First, resolve the random seed in case a string value was provided. config.seed = resolve_seed(config.seed) fix_random_seed(config.seed) - assert config.sampler_fusion == "mux", ( - "In order to use a sampler_fusion strategy different than 'mux', " - "create your dataloader using 'get_lhotse_dataloader_from_multi_config' instead." - ) sampler, use_iterable_dataset = get_lhotse_sampler_from_config( config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer ) @@ -284,7 +310,7 @@ def get_lhotse_dataloader_from_single_config( def get_lhotse_dataloader_from_multi_config( - configs: DictConfig, + top_level_config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, @@ -300,46 +326,79 @@ def get_lhotse_dataloader_from_multi_config( The first config is treated as a "main" config that determines the RNG, CUDA allocator, and sampler fusion settings. """ - logging.info(f"We will be using a multi config Lhotse DataLoader with groups: {list(configs.keys())}.") - - configs = [make_structured_with_schema_warnings(c) for c in configs.values() if isinstance(c, DictConfig)] - main_config = configs[0] - maybe_set_cuda_expandable_segments(enabled=main_config.cuda_expandable_segments) - seed = resolve_seed(main_config.seed) - fix_random_seed(seed) - - source_samplers, source_use_iterable_dataset = [], [] - for config in configs: - # TODO(pzelasko): perhaps emit a warning in the unlikely case somebody defines different seeds explicitly. - config.seed = seed - config.shard_seed = main_config.shard_seed - s, t = get_lhotse_sampler_from_config( - config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer - ) - source_samplers.append(s) + + def gather_shared_opts(): + """ + In multi-config setting, the top-level config defines several attributes that overwrite + the ones present in sub-configs. + """ + assert all( + k in top_level_config for k in ["seed", "shard_seed", "shuffle"] + ), "In a multi-config setting (multi_config=True), the top-level namespace (typically train_ds) must define at least 'seed', 'shard_seed', and 'shuffle' keys that will be shared by all sub-configs." + overwriting_opts = [ + "seed", + "shard_seed", + "num_workers", + "pin_memory", + "shuffle", + "sampler_fusion", + "sampler_weights", + "multi_config", + "metadata_only", + "force_finite", + ] + defaults = OmegaConf.structured(LhotseDataLoadingConfig) + top_level_config["seed"] = resolve_seed(top_level_config["seed"]) + return OmegaConf.create({k: top_level_config.get(k, defaults[k]) for k in overwriting_opts}) + + shared_opts = gather_shared_opts() + fix_random_seed(shared_opts.seed) + + configs = { + name: c + for name, c in top_level_config.items() + if isinstance(c, DictConfig) and name not in ("sampler_weights",) # exclude dict opts + } + + source_samplers, source_use_iterable_dataset = {}, [] + for name, config in configs.items(): + try: + expanded_config = make_structured_with_schema_warnings(config) + for k, v in shared_opts.items(): + expanded_config[k] = v + s, t = get_lhotse_sampler_from_config( + config=expanded_config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer + ) + except IncompleteConfigError as e: + raise IncompleteConfigError( + f"Cannot create a sampler for one of the sub-configs in a multi_config setup. The problematic config is under key={name} and has the following contents: {config}" + ) from e + source_samplers[name] = s source_use_iterable_dataset.append(t) - assert all( - st == source_use_iterable_dataset[0] for st in source_use_iterable_dataset[1:] - ), "When using multiple input_cfg sources ensure they are all tarred or non-tarred (can't mix)." + assert all(st == source_use_iterable_dataset[0] for st in source_use_iterable_dataset[1:]), ( + "When using multiple input_cfg sources ensure they are all tarred or non-tarred (can't mix). " + "You can provide force_iterable_dataset=True to each namespace to fix." + ) use_iterable_dataset = all(source_use_iterable_dataset) - if main_config.sampler_fusion == "zip": - sampler = ZipSampler(*source_samplers) - elif main_config.sampler_fusion == "round_robin": - sampler = RoundRobinSampler(*source_samplers) - elif main_config.sampler_fusion == "randomized_round_robin": - sampler = RoundRobinSampler( - *source_samplers, - randomize=True if main_config.sampler_weights is None else main_config.sampler_weights, - seed=seed, - ) - elif main_config.sampler_fusion == "mux": - raise RuntimeError( - "In order to use a sampler_fusion strategy 'mux', " - "create your dataloader using 'get_lhotse_dataloader_from_config' instead." - ) - else: - raise RuntimeError(f"Unsupported sampler fusion strategy: {main_config.sampler_fusion}") + match shared_opts.sampler_fusion: + case "zip": + sampler = ZipSampler(*source_samplers.values()) + case "round_robin": + sampler = RoundRobinSampler(*source_samplers.values()) + case "randomized_round_robin": + _samplers, _weights = [], [] + for key in source_samplers.keys(): + _samplers.append(source_samplers[key]) + if shared_opts.sampler_weights is not None: + _weights.append(shared_opts.sampler_weights[key]) + sampler = RoundRobinSampler( + *_samplers, + randomize=_weights if len(_weights) > 0 else True, + seed=shared_opts.seed, + ) + case unknown_value: + raise RuntimeError(f"Unsupported sampler fusion strategy: {unknown_value}") # 4. Creating dataloader. if use_iterable_dataset: @@ -352,8 +411,8 @@ def get_lhotse_dataloader_from_multi_config( # This together with infinite datasets removes the need to split data across nodes/workers. dloader_kwargs = dict( dataset=IterableDatasetWrapper(dataset=dataset, sampler=sampler), - worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size, seed=seed), - persistent_workers=main_config.num_workers > 0, # helps Lhotse Shar maintain shuffling state + worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size, seed=shared_opts.seed), + persistent_workers=shared_opts.num_workers > 0, # helps Lhotse Shar maintain shuffling state ) else: # For non-tarred data, the sampler resides in the training loop process and @@ -363,8 +422,8 @@ def get_lhotse_dataloader_from_multi_config( dloader = torch.utils.data.DataLoader( **dloader_kwargs, batch_size=None, - num_workers=main_config.num_workers, - pin_memory=main_config.pin_memory, + num_workers=shared_opts.num_workers, + pin_memory=shared_opts.pin_memory, ) return dloader @@ -410,6 +469,7 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No tokenizer = TokenizerWrapper(tokenizer) cuts = cuts.map(partial(tokenize, tokenizer=tokenizer), apply_fn=None) cuts = cuts.filter(TokenPerSecondFilter(config.min_tps, config.max_tps)) + cuts = cuts.filter(TokenPerTokenFilter(config.min_tpt, config.max_tpt)) # 2. Optional augmentations. # 2.a. Noise mixing. @@ -451,40 +511,14 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No # Duration filtering, same as native NeMo dataloaders. # We can filter after the augmentations because they are applied only when calling load_audio(). cuts = cuts.filter(DurationFilter(config.min_duration, config.max_duration)) + cuts = cuts.filter( + TokenCountFilter(config.min_tokens, config.max_tokens, measure_total_length=config.measure_total_length) + ) + # Select the strategy customizing Lhotse sampler behaviour. + # Provides support for dynamic batch sizes, multimodal dataloading, 2D bucketing, etc. bucket_duration_bins = determine_bucket_duration_bins(config) - if config.use_multimodal_sampling: - if config.bucket_batch_size is not None: - assert ( - bucket_duration_bins is not None - ), "Cannot use bucket_batch_size option if bucket_duration_bins are not provided." - constraint = MultimodalFixedBucketBatchSizeConstraint2D( - max_seq_len_buckets=bucket_duration_bins, - batch_sizes=config.bucket_batch_size, - token_equivalent_duration=config.token_equivalent_duration, - ) - else: - constraint = MultimodalSamplingConstraint( - token_equivalent_duration=config.token_equivalent_duration, - batch_size=config.batch_size, - batch_tokens=config.batch_tokens, - quadratic_factor=config.quadratic_factor, - ) - else: - if config.bucket_batch_size is not None: - assert ( - bucket_duration_bins is not None - ), "Cannot use bucket_batch_size option if bucket_duration_bins are not provided." - constraint = FixedBucketBatchSizeConstraint2D( - max_seq_len_buckets=bucket_duration_bins, - batch_sizes=config.bucket_batch_size, - ) - else: - constraint = TimeConstraint( - max_cuts=config.batch_size, - max_duration=config.batch_duration, - quadratic_duration=config.quadratic_duration, - ) + constraint = determine_sampling_constraint(bucket_duration_bins, config) # 3. The sampler. if config.use_bucketing: @@ -562,7 +596,59 @@ def get_lhotse_sampler_from_config(config, global_rank, world_size, tokenizer=No return sampler, use_iterable_dataset +def determine_sampling_constraint(bucket_duration_bins, config): + """ + Select an appropriate sampling strategy (constraint) for Lhotse samplers based on the configuration. + Sampling constraint affects the batch size (static/dynamic) and bucketing behaviour (1D/2D). + It is the appropriate customization point to introduce support of other modalities, + as it defines a method for example sequence length measurement (audio duration, text tokens, etc.). + + Lhotse's default is :class:`TimeConstraint` for regular audio data, other available options are + multimodal constraints (joint text + audio) and their 2D bucketing extensions. + """ + if config.use_multimodal_sampling: + if config.bucket_batch_size is not None: + assert ( + bucket_duration_bins is not None + ), "Cannot use bucket_batch_size option if bucket_duration_bins are not provided." + constraint = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=bucket_duration_bins, + batch_sizes=config.bucket_batch_size, + token_equivalent_duration=config.token_equivalent_duration, + ) + else: + constraint = MultimodalSamplingConstraint( + token_equivalent_duration=config.token_equivalent_duration, + batch_size=config.batch_size, + batch_tokens=config.batch_tokens, + quadratic_factor=config.quadratic_factor, + ) + else: + if config.bucket_batch_size is not None: + assert ( + bucket_duration_bins is not None + ), "Cannot use bucket_batch_size option if bucket_duration_bins are not provided." + constraint = FixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=bucket_duration_bins, + batch_sizes=config.bucket_batch_size, + ) + else: + constraint = TimeConstraint( + max_cuts=config.batch_size, + max_duration=config.batch_duration, + quadratic_duration=config.quadratic_duration, + ) + return constraint + + def determine_bucket_duration_bins(config): + """ + Returns appropriate bucket bins based on configuration. + If user provided them explicitly, we just pass them along; + otherwise, we try to create provisional bins when min/max duration is available. + We might return None if it's impossible to determine the bins without computing data statistics, + in which case it will be automatically done at the start of training (but may take a few minutes). + """ if config.bucket_duration_bins is not None: # Bucket duration bins are provided: just use them. ans = OmegaConf.to_container(config.bucket_duration_bins) @@ -590,13 +676,15 @@ def determine_bucket_duration_bins(config): return None -def make_structured_with_schema_warnings(config: DictConfig) -> DictConfig: +def make_structured_with_schema_warnings(config: DictConfig | dict) -> DictConfig: """ Checks the schema and fills missing default option values. Warns the user if any of the fields are not supported by the current schema but does not raise exceptions. """ default = OmegaConf.structured(LhotseDataLoadingConfig) + if not isinstance(config, DictConfig): + config = DictConfig(config) # Remove unsupported keys and warn about them. supported_keys = set(OmegaConf.to_container(default).keys()) @@ -620,127 +708,7 @@ def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfi return use_iterable_dataset -@dataclass -class MultimodalSamplingConstraint(SamplingConstraint): - # how many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch - token_equivalent_duration: float - - # defines maximum batch size (may be lower than that if batch_length is also specified) - batch_size: int | None = None - - # defines the total number of tokens in a mini-batch - # setting this enables dynamic batch sizes - # we will use ``token_equivalent_duration`` to convert audio examples to token sizes - batch_tokens: int | None = None - - # when specified, this value is inversely proportional to the penalty we assign - # to longer examples when measuring their length/duration; - # i.e. large quadratic factor is a small penalty, small quadratic factor is a large penalty - # tweaking this helps equalize the GPU memory usage for dynamic batch sizes when using bucketing - quadratic_factor: float | None = None - - _internal = None - - def __post_init__(self): - self._internal = TokenConstraint( - max_tokens=self.batch_tokens, - max_examples=self.batch_size, - quadratic_length=self.quadratic_factor, - ) - - def add(self, example: Any) -> None: - if isinstance(example, Cut): - num_tokens = self.measure_length(example) - example.num_tokens = num_tokens - self._internal.add(example) - - def exceeded(self) -> bool: - return self._internal.exceeded() - - def close_to_exceeding(self) -> bool: - return self._internal.close_to_exceeding() - - def reset(self) -> None: - self._internal.reset() - - def measure_length(self, example: Any) -> float: - if isinstance(example, Cut): - # "length" of a Cut (audio+text example) is counted as the sum of: - # * num_tokens in each supervision segment ("utterance") in the Cut - # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) - text_tokens = 0 - for s in example.supervisions: - if s.has_custom("tokens"): - text_tokens += len(s.tokens) - return example.duration / self.token_equivalent_duration + text_tokens - if isinstance(example, (TextExample, SourceTargetTextExample, NeMoSFTExample)): - return example.num_tokens - raise RuntimeError(f"Unsupported example type: {type(example)}") - - -@dataclass -class FixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint): - @property - def bucketing_2d_enabled(self) -> bool: - return isinstance(self.max_seq_len_buckets[0], Sequence) and len(self.max_seq_len_buckets[0]) == 2 - - def measure_length(self, example: Any) -> tuple[float, float]: - if self.bucketing_2d_enabled: - return example.duration, _measure_tokens(example) - else: - return example.duration - - def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = None) -> int: - if not self.bucketing_2d_enabled: - return super().select_bucket(buckets=buckets, example=example, example_len=example_len) - if example_len is None: - example_len = self.measure_length(example) - bucket_idx = bisect.bisect_right(buckets, example_len) - # For 2D bucketing we have to refine the initially found bucket_idx, as bisect - # looks primarily at the first index of a tuple (i.e. duration). - # For example, with buckets [(1, 1), (1, 2), (2, 2), (2, 4)] and example (1.5, 3) - # bisect would allocate it to bucket_idx=2 instead of bucket_idx=3. - # To refine, we'll try to push the example to as many buckets to the right as possible, - # as long as they have the same dim0 length (e.g. audio duration) and the example's dim1 - # is smaller than the bin's dim1 (e.g., output token sequence length). - bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] - num_buckets = len(self.max_seq_len_buckets) - while ( - (next_idx := bucket_idx + 1) < num_buckets # There is a next bucket - and (bin := self.max_seq_len_buckets[next_idx])[0] == bin_dim0 # The next bucket has the same 1st dim. - # The example's 2nd dim is between that of the current and the next bucket; or, - # the next bucket's 2nd dim is still smaller than example. - and (bin_dim1 < example_len[1] <= bin[1] or bin[1] < example_len[1]) - ): - bucket_idx = next_idx - bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] - return bucket_idx - - -@dataclass -class MultimodalFixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint2D): - token_equivalent_duration: float | None = None - - def measure_length(self, example: Any) -> float: - assert not self.bucketing_2d_enabled, "2D bucketing for multimodal sampling is not yet supported." - if hasattr(example, "num_tokens"): - return example.num_tokens - if isinstance(example, Cut): - assert ( - self.token_equivalent_duration is not None - ), "Cannot use MultimodalFixedBucketBatchSizeConstraint with speech data when token_equivalent_duration was not specified." - return example.duration / self.token_equivalent_duration - raise RuntimeError(f"Unsupported example type: {type(example)}") - - -def is_text(example) -> bool: - return isinstance(example, (TextExample, SourceTargetTextExample, NeMoSFTExample)) - - -Example = TypeVar("Example", bound=Union[Cut, TextExample, SourceTargetTextExample, NeMoSFTExample]) - - -def tokenize(example: Example, tokenizer) -> Example: +def tokenize(example, tokenizer): if isinstance(example, Cut): for s in example.supervisions: if s.text is not None: @@ -752,28 +720,12 @@ def tokenize(example: Example, tokenizer) -> Example: return example -def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str) -> Example: - # TODO(pzelasko): This mechanism makes it possible to measure the actual output sequence length - # for prompted models such as AED MultiTask (Canary), which includes the transcript and the prompt. - # We intend to extend it for text modality in follow-up work. - if isinstance(example, Cut): - prompt_format_fn = get_prompt_format_fn(prompt_format) - ans = prompt_format_fn(CutSet([example]), tokenizer) - if isinstance(ans, tuple): - (tokenized_prompted_transcript,), (tokenized_prompt,), (tokenized_transcript,) = ans - example.tokenized_prompted_transcript = tokenized_prompted_transcript - example.tokenized_prompt = tokenized_prompt - example.tokenized_transcript = tokenized_transcript - elif isinstance(ans, dict): - example.tokenized_prompted_transcript = ans["input_ids"][0] - example.tokenized_prompt = ans["context_ids"][0] - example.tokenized_transcript = ans["answer_ids"][0] - else: - raise RuntimeError(f"Unexpected return type from prompt_format_fn (must be dict or tuple): {ans}") - elif isinstance(example, NeMoMultimodalConversation): - example = example.tokenize(tokenizer, prompt_format) - else: - raise RuntimeError(f"Currently we only support tokenization + prompting during sampling for audio modality.") +def tokenize_with_prompt(example, tokenizer, prompt_format: str | PromptFormatter): + if isinstance(prompt_format, str): + prompt_format = PromptFormatter.resolve(prompt_format)(tokenizer) + encoded = apply_prompt_format_fn(example, prompt_format) + for key, value in encoded.items(): + setattr(example, key, value) return example @@ -783,55 +735,6 @@ def tokenize_with_prompt(example: Example, tokenizer, prompt_format: str) -> Exa # to support pickling lambdas if its ever truly necessary. -class DurationFilter: - """Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise.""" - - def __init__(self, d_min: float, d_max: float) -> None: - self.d_min = d_min - self.d_max = d_max - - def __call__(self, example) -> bool: - if isinstance(example, Cut): - return self.d_min <= example.duration <= self.d_max - else: - return True # does not apply to text etc. - - -class TokenPerSecondFilter: - """ - Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) - is in range [tps_min, tps_max] and ``False`` otherwise. - """ - - def __init__(self, tps_min: float, tps_max: float) -> None: - assert tps_min <= tps_max - self.tps_min = tps_min - self.tps_max = tps_max - self.enabled = tps_min > 0 or tps_max < float("inf") - - def __call__(self, example) -> bool: - if not isinstance(example, Cut) or not self.enabled: - return True # pass-through for non-audio examples. - tps = _measure_tps(example) - return self.tps_min <= tps <= self.tps_max - - -def _measure_tokens(cut: Cut) -> int: - if hasattr(cut, "tokenized_prompted_transcript"): - return len(cut.tokenized_prompted_transcript) # tokenized with prompt formatter - supervisions_with_tokens = [s for s in cut.supervisions if hasattr(s, "tokens")] - assert len(supervisions_with_tokens) > 0, ( - "Cannot measure tokens-per-second with untokenized supervisions. " - "Did you forget to provide the tokenizer argument to get_lhotse_dataloader_from_config() method?" - ) - return sum(len(s.tokens) for s in supervisions_with_tokens) - - -def _measure_tps(cut: Cut) -> float: - num_tokens = _measure_tokens(cut) - return num_tokens / cut.duration - - def _normalize_loudness(cuts: CutSet, db_norm: float) -> CutSet: return cuts.normalize_loudness(target=db_norm, mix_first=False) diff --git a/nemo/collections/common/data/lhotse/sampling.py b/nemo/collections/common/data/lhotse/sampling.py new file mode 100644 index 000000000000..5206b9b1dec0 --- /dev/null +++ b/nemo/collections/common/data/lhotse/sampling.py @@ -0,0 +1,303 @@ +import bisect +import logging +import math +from dataclasses import dataclass +from typing import Any, Sequence + +from lhotse.cut import Cut +from lhotse.dataset import SamplingConstraint, TokenConstraint +from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint +from lhotse.utils import ifnone + +from nemo.collections.common.data.lhotse.text_adapters import Formattable + + +@dataclass +class MultimodalSamplingConstraint(SamplingConstraint): + """ + Sampling strategy that customizes Lhotse samplers to measure sequence lengths as token counts. + It provides a unified interface for audio and text examples - audio duration is converted to + an equivalent token count. + """ + + # How many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch. + # Generally set this to frame_shift * total_subsampling_factor of your audio encoder. + token_equivalent_duration: float | None = None + + # Defines maximum batch size (may be lower than that if batch_length is also specified). + batch_size: int | None = None + + # Defines the total number of tokens in a mini-batch. + # Setting this enables dynamic batch sizes. + # We will use ``token_equivalent_duration`` to convert audio examples to token sizes. + batch_tokens: int | None = None + + # When specified, this value is inversely proportional to the penalty we assign + # to longer examples when measuring their length/duration; + # i.e. large quadratic factor is a small penalty, small quadratic factor is a large penalty. + # Tweaking this helps equalize the GPU memory usage for dynamic batch sizes when using bucketing. + quadratic_factor: float | None = None + + # When False (default), we only consider the input part of the example to determine its length, + # e.g. for a Cut that means its audio duration converted to tokens, for text that means len(context_ids), etc. + # When True, we consider the sum of input and output lengths together (useful mostly for decoder-only models). + measure_total_length: bool = False + + _internal = None + + def __post_init__(self): + self._internal = TokenConstraint( + max_tokens=self.batch_tokens, + max_examples=self.batch_size, + quadratic_length=self.quadratic_factor, + ) + + def add(self, example: Any) -> None: + num_tokens = self.measure_length(example) + example.num_tokens = num_tokens + self._internal.add(example) + + def exceeded(self) -> bool: + return self._internal.exceeded() + + def close_to_exceeding(self) -> bool: + return self._internal.close_to_exceeding() + + def reset(self) -> None: + self._internal.reset() + + def measure_length(self, example: Any) -> float: + if isinstance(example, Cut): + audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) + if self.measure_total_length: + # Total length of a Cut (audio+text example) is counted as the sum of: + # * num_tokens in each supervision segment ("utterance") in the Cut + # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) + text_tokens = 0 + for s in example.supervisions: + if s.has_custom("tokens"): + text_tokens += len(s.tokens) + return audio_len_in_tokens + text_tokens + else: + return audio_len_in_tokens + elif isinstance(example, Formattable): + try: + return example.total_length if self.measure_total_length else example.input_length + except (AttributeError, AssertionError) as e: + raise RuntimeError( + "Couldn't determine the length of a text example; " + "have you provided both prompt_format and tokenizer when instantiating the dataloader?" + ) from e + raise RuntimeError(f"Unsupported example type: {type(example)}") + + +@dataclass +class FixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint): + """ + Sampling strategy that customizes Lhotse samplers to support 2D bucket selection (it also supports 1D). + It is intended only for audio examples (i.e., Lhotse Cut objects). + """ + + @property + def bucketing_2d_enabled(self) -> bool: + return isinstance(self.max_seq_len_buckets[0], Sequence) and len(self.max_seq_len_buckets[0]) == 2 + + def measure_length(self, example: Cut) -> tuple[float, float] | float: + if self.bucketing_2d_enabled: + return example.duration, _measure_tokens(example) + else: + return example.duration + + def select_bucket(self, buckets: Any, example: Any = None, example_len: Any = None) -> int: + if not self.bucketing_2d_enabled: + return super().select_bucket(buckets=buckets, example=example, example_len=example_len) + if example_len is None: + example_len = self.measure_length(example) + bucket_idx = bisect.bisect_left(buckets, example_len) + # For 2D bucketing we have to refine the initially found bucket_idx, as bisect + # looks primarily at the first index of a tuple (i.e. duration). + # For example, with buckets [(1, 1), (1, 2), (2, 2), (2, 4)] and example (1.5, 3) + # bisect would allocate it to bucket_idx=2 instead of bucket_idx=3. + # To refine, we'll try to push the example to as many buckets to the right as possible, + # as long as they have the same dim0 length (e.g. audio duration) and the example's dim1 + # is smaller than the bin's dim1 (e.g., output token sequence length). + bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] + num_buckets = len(self.max_seq_len_buckets) + while ( + (next_idx := bucket_idx + 1) < num_buckets # There is a next bucket + and (bin := self.max_seq_len_buckets[next_idx])[0] == bin_dim0 # The next bucket has the same 1st dim. + # The example's 2nd dim is between that of the current and the next bucket; or, + # the next bucket's 2nd dim is still smaller than example. + and (bin_dim1 < example_len[1] <= bin[1] or bin[1] < example_len[1]) + ): + bucket_idx = next_idx + bin_dim0, bin_dim1 = self.max_seq_len_buckets[bucket_idx] + + if example_len[0] > bin_dim0 or example_len[1] > bin_dim1: + logging.warning( + f"Data sample exceeds 2D bucket specification: lengths={example_len} bucket=({bin_dim0}, {bin_dim1}) " + f"(there is no larger bucket that would fit this example). " + f"We will keep it but expect OutOfMemoryError to happen during the training. " + f"You can fix this by stricter filtering with max_duration, max_tokens, max_tps, max_tpt; " + f"or re-estimating your bucket bins to match the actual data length distribution. " + f"Details: {example=}" + ) + + return bucket_idx + + +@dataclass +class MultimodalFixedBucketBatchSizeConstraint2D(FixedBucketBatchSizeConstraint2D): + """ + Sampling strategy that customizes Lhotse samplers to support both multimodal sampling and 2D bucket selection. + It combines the capabilities of :class:`FixedBucketBatchSizeConstraint2D` and :class:`MultimodalSamplingConstraint`. + """ + + # How many seconds of audio is a text token worth; balances audio to text ratio in a mini-batch. + # Generally set this to frame_shift * total_subsampling_factor of your audio encoder. + token_equivalent_duration: float | None = None + + # When False (default), we only consider the input part of the example to determine its length, + # e.g. for a Cut that means its audio duration converted to tokens, for text that means len(context_ids), etc. + # When True, we consider the sum of input and output lengths together (useful mostly for decoder-only models). + measure_total_length: bool = False + + def measure_length(self, example: Any) -> float | tuple[float, float]: + if isinstance(example, Cut): + # Total length of a Cut (audio+text example) is counted as the sum of: + # * num_tokens in each supervision segment ("utterance") in the Cut + # * num_frames of audio (frame=token) given a token-equivalent-duration (basically a frame shift) + audio_len_in_tokens = math.ceil(example.duration / self.token_equivalent_duration) + text_tokens = _measure_tokens(example) + + if self.bucketing_2d_enabled: + return audio_len_in_tokens, text_tokens + + else: + if self.measure_total_length: + return audio_len_in_tokens + text_tokens + else: + return audio_len_in_tokens + + elif isinstance(example, Formattable): + if self.bucketing_2d_enabled: + return example.input_length, example.output_length + else: + return example.total_length if self.measure_total_length else example.input_length + + raise RuntimeError(f"Unsupported example type: {type(example)}") + + +class DurationFilter: + """ + Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__(self, d_min: float | None, d_max: float | None) -> None: + self.d_min = ifnone(d_min, -1) + self.d_max = ifnone(d_max, float("inf")) + + def __call__(self, example) -> bool: + if isinstance(example, Cut): + return self.d_min <= example.duration <= self.d_max + else: + return True # does not apply to text etc. + + +class TokenCountFilter: + """ + Callable, returns ``True`` if an example's number of tokens is in range [t_min, t_max] and ``False`` otherwise. + + It is only applicable to data types that derive from class ``Formattable`` and lhotse ``Cut`` objects. + Acts as a passthrough for Cuts. + Raises exception if a non-Formattable and non-Cut data are provided. + + The ``measure_total_length`` option allows to select whether we should filter on context_ids length (=False) + or input_ids length (=True). + The difference is that for decoder-only models, we collapse input and output into a single sequence, + so we should measure the example length using input_ids (measure_total_length=True). + However, for models which have separate inputs and outputs such as encoder-decoder models, + we want to measure the input lengths only here (measure_total_length=False), + and enable ``TokenPerTokenFilter`` for additional filtering on the output sequence length. + """ + + def __init__(self, t_min: float | None, t_max: float | None, measure_total_length: bool) -> None: + self.t_min = ifnone(t_min, -1) + self.t_max = ifnone(t_max, float("inf")) + self.measure_total_length = measure_total_length + self.enabled = self.t_min > 0 or self.t_max < float("inf") + + def __call__(self, example) -> bool: + if not self.enabled or isinstance(example, Cut): + return True # does not apply to Cuts + assert isinstance(example, Formattable), ( + f"TokenCountFilter can only be applied to data examples that derive Formattable class. " + f"Formattable objects define properties input_length, output_length, and total_length that " + f"allow us to select the right sequence length for filtering. We got: {example}" + ) + try: + length = example.total_length if self.measure_total_length else example.input_length + except (AttributeError, AssertionError) as e: + raise RuntimeError( + f"Cannot measure token count for example: {example} " + f"-- did you forget to apply prompt formatting? If instantiating Lhotse dataloader, " + f"make sure you provided 'prompt_format' option and passed the tokenizer." + ) from e + return self.t_min <= length <= self.t_max + + +class TokenPerSecondFilter: + """ + Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) + is in range [tps_min, tps_max] and ``False`` otherwise. + Acts as a pass-through for objects of other type than Cut. + """ + + def __init__(self, tps_min: float | None, tps_max: float | None) -> None: + self.tps_min = ifnone(tps_min, -1) + self.tps_max = ifnone(tps_max, float("inf")) + assert tps_min <= tps_max, f"{tps_min=} {tps_max=}" + self.enabled = tps_min > 0 or tps_max < float("inf") + + def __call__(self, example) -> bool: + if not isinstance(example, Cut) or not self.enabled: + return True # pass-through for non-audio examples. + tps = _measure_tps(example) + return self.tps_min <= tps <= self.tps_max + + +class TokenPerTokenFilter: + """ + Callable, returns ``True`` if a cut's num_tokens (sum of len(tokens) for each supervision) + is in range [tps_min, tps_max] and ``False`` otherwise. + Acts as a pass-through for audio examples (Cuts). + """ + + def __init__(self, tpt_min: float | None, tpt_max: float | None) -> None: + self.tpt_min = ifnone(tpt_min, -1) + self.tpt_max = ifnone(tpt_max, float("inf")) + assert tpt_min <= tpt_max, f"{tpt_min=} {tpt_max=}" + self.enabled = tpt_min > 0 or tpt_max < float("inf") + + def __call__(self, example) -> bool: + if isinstance(example, Cut) or not self.enabled: + return True # pass-through for non-text examples. + tpt = example.answer_ids.shape[0] / example.context_ids.shape[0] + return self.tpt_min <= tpt <= self.tpt_max + + +def _measure_tokens(cut: Cut) -> int: + if hasattr(cut, "input_ids"): + return len(cut.input_ids) # tokenized with prompt formatter + supervisions_with_tokens = [s for s in cut.supervisions if hasattr(s, "tokens")] + assert len(supervisions_with_tokens) > 0, ( + "Cannot measure the number of tokens with untokenized supervisions. " + "Did you forget to provide the tokenizer argument to get_lhotse_dataloader_from_config() method?" + ) + return sum(len(s.tokens) for s in supervisions_with_tokens) + + +def _measure_tps(cut: Cut) -> float: + num_tokens = _measure_tokens(cut) + return num_tokens / cut.duration diff --git a/nemo/collections/common/data/lhotse/text_adapters.py b/nemo/collections/common/data/lhotse/text_adapters.py index 9e64b37bffef..9503d2eb1ffa 100644 --- a/nemo/collections/common/data/lhotse/text_adapters.py +++ b/nemo/collections/common/data/lhotse/text_adapters.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import copy +import math import random from collections import deque -from dataclasses import asdict, dataclass +from dataclasses import dataclass from itertools import groupby from pathlib import Path from typing import Iterator, Literal, Optional, Union @@ -23,26 +22,65 @@ import numpy as np import torch from lhotse import Recording +from lhotse.custom import CustomFieldMixin from lhotse.cut import Cut from lhotse.dataset.dataloading import resolve_seed from lhotse.serialization import load_jsonl -from lhotse.shar import AudioTarWriter, JsonlShardWriter, TarIterator, TarWriter -from lhotse.utils import Pathlike, asdict_nonull, is_valid_url +from lhotse.shar import AudioTarWriter, JsonlShardWriter, TarIterator +from lhotse.utils import Pathlike, is_valid_url from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths +from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn, registered_prompt_format_fn from nemo.collections.common.parts.preprocessing.manifest import get_full_path -from nemo.collections.common.prompts import PromptFormatter -from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer, TokenizerWrapper -from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec -from nemo.utils import logging +from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper """ -Basic text example, adequate for pretraining-style language modeling. +Formattable: mixin class with data fields for prompt formatter outputs and method for +applying prompt formatters to derived data types. +""" + + +class Formattable: + def __init__(self): + self.input_ids: np.ndarray | torch.Tensor | None = None + self.context_ids: np.ndarray | torch.Tensor | None = None + self.answer_ids: np.ndarray | torch.Tensor | None = None + self.mask: np.ndarray | torch.Tensor | None = None + + @property + def input_length(self) -> int | None: + if self.context_ids is None: + return None + return self.context_ids.shape[0] + + @property + def output_length(self) -> int | None: + if self.answer_ids is None: + return None + return self.answer_ids.shape[0] + + @property + def total_length(self) -> int | None: + if self.input_ids is None: + return None + return self.input_ids.shape[0] + + def apply_prompt_format(self, prompt) -> "Formattable": + ans = apply_prompt_format_fn(self, prompt) + self.input_ids = ans["input_ids"] + self.context_ids = ans["context_ids"] + self.answer_ids = ans["answer_ids"] + self.mask = ans["mask"] + return self + + +""" +TextExample: data types, file parser, default prompt formatting logic. """ @dataclass -class TextExample: +class TextExample(Formattable, CustomFieldMixin): """ Represents a single text example. Useful e.g. for language modeling. """ @@ -50,12 +88,7 @@ class TextExample: text: str language: str | None = None tokens: Optional[np.ndarray] = None - - @property - def num_tokens(self) -> Optional[int]: - if self.tokens is None: - return None - return len(self.tokens) + custom: dict = None def tokenize(self, tokenizer: TokenizerWrapper) -> "TextExample": self.tokens = np.asarray(tokenizer(self.text, self.language)) @@ -88,13 +121,26 @@ def __iter__(self) -> Iterator[TextExample]: yield TextExample(line, language=self.language) +@registered_prompt_format_fn(TextExample) +def default_text_example_prompt_format_fn(example: TextExample, prompt): + # It doesn't really make sense to prompt format a single line text example, + # but we implement some default logic for the sake of completeness. + # The default logic here is to treat the whole example as an assistant turn, + # so that the mask is all set to true for the training loss. + return prompt.encode_dialog( + [ + {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.text}}, + ] + ) + + """ -Source-target text examples (e.g., machine translation). +SourceTargetTextExample: data types, file parser, default prompt formatting logic. """ @dataclass -class SourceTargetTextExample: +class SourceTargetTextExample(Formattable, CustomFieldMixin): """ Represents a pair of text examples. Useful e.g. for sequence-to-sequence tasks. Supports a ``question`` field, used as the prompt for LLM. @@ -103,38 +149,13 @@ class SourceTargetTextExample: source: TextExample target: TextExample question: TextExample | None = None - input_ids: np.ndarray | None = None - context_ids: np.ndarray | None = None - answer_ids: np.ndarray | None = None - mask: np.ndarray | None = None - - @property - def num_tokens(self) -> Optional[int]: - if self.input_ids is not None: - return self.input_ids.shape[0] - return None - - def tokenize(self, tokenizer: TokenizerWrapper) -> "TextExample": - input_ids = [] - context_ids = [] - if self.question: - ans = tokenizer(self.question.text, self.question.language) - input_ids.extend(ans) - context_ids.extend(ans) - ans = tokenizer(self.source.text, self.source.language) - input_ids.extend(ans) - context_ids.extend(ans) - - answer_ids = tokenizer(self.target.text, self.target.language) - input_ids.extend(answer_ids) - - self.input_ids = np.asarray(input_ids) - self.context_ids = np.asarray(context_ids) - self.answer_ids = np.asarray(answer_ids) - mask = np.zeros_like(self.input_ids, dtype=np.bool_) - mask[self.context_ids.shape[0] :] = True - self.mask = mask + custom: dict = None + def tokenize(self, tokenizer: TokenizerWrapper) -> "SourceTargetTextExample": + self.source = self.source.tokenize(tokenizer) + self.target = self.target.tokenize(tokenizer) + if self.question is not None: + self.question = self.question.tokenize(tokenizer) return self @@ -194,75 +215,46 @@ def __iter__(self) -> Iterator[SourceTargetTextExample]: ) +@registered_prompt_format_fn(SourceTargetTextExample) +def default_src_tgt_prompt_format_fn(example: SourceTargetTextExample, prompt): + if example.question is not None: + ctx = f"{example.question.text} {example.source.text}" + else: + ctx = example.source.text + return prompt.encode_dialog( + [ + {"role": "user", "slots": {"message": ctx}}, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}}, + ] + ) + + +""" +NeMoSFTExample: data types, file parser, default prompt formatting logic. +""" + + @dataclass -class NeMoSFTExample: +class NeMoSFTExample(Formattable, CustomFieldMixin): data: dict language: str | None = None - input_ids: np.ndarray | None = None - context_ids: np.ndarray | None = None - answer_ids: np.ndarray | None = None - mask: np.ndarray | None = None metadata: dict | None = None + custom: dict = None - def tokenize(self, tokenizer: TokenizerWrapper | TokenizerSpec) -> "NeMoSFTExample": - """ - Create a tokenized variant of this example given a tokenizer (i.e. fill the optional fields). - Supports BPE tokenizers and aggregate tokenizers. - - The tokenization is compatible with Megatron's :class:`GPTSFTChatDataset`. - """ - special_tokens = { - "system_turn_start": "", - "turn_start": "", - "label_start": "", - "end_of_turn": "\n", - "end_of_name": "\n", - } - if isinstance(tokenizer, TokenizerWrapper): - tokenizer = tokenizer._tokenizer - if isinstance(tokenizer, AggregateTokenizer): - assert self.language is not None, ( - f"Error: attempted to use AggregateTokenizer for NeMoSFTExample which did not specify language. " - f"Problematic example: {self}" - ) - assert self.language in tokenizer.tokenizers_dict, ( - f"Error: attempted to use AggregateTokenizer for NeMoSFTExample with unsupported language: {self.language}. " - f"The set of supported languages is: {' '.join(tokenizer.tokenizers_dict.keys())}. " - f"Problematic example: {self}" - ) - tokenizer = tokenizer.tokenizers_dict[self.language] - - label_start_tokens, name_end_token_ids, num_turn_start_tokens = _build_samples_mapping( - tokenizer, special_tokens +@registered_prompt_format_fn(NeMoSFTExample) +def default_sft_prompt_format_fn(example: NeMoSFTExample, prompt): + if "system" in example.data and example.data["system"]: + raise RuntimeError( + f"Default prompt format for NeMoSFTExample doesn't support 'system' prompt. " + f"Please specialize the prompt_format_fn for PromptFormatter of type {prompt}" ) - - tokenized = preprocess( - source=self.data, - tokenizer=tokenizer, - name_end_token_ids=name_end_token_ids, - label_start_ids=label_start_tokens, - special_tokens=special_tokens, - num_turn_start_tokens=num_turn_start_tokens, - ) - self.input_ids = tokenized["input_ids"].numpy() - self.context_ids = tokenized["context_ids"].numpy() - self.answer_ids = tokenized["answer_ids"].numpy() - self.mask = tokenized["mask"].numpy() - self.metadata = {k: v for k, v in self.data.items() if k not in ['conversations']} - - return self - - # TODO(pzelasko): for mini-batch sampling purposes, should we consider input_ids or answer_ids - # as representative of the sequence length? Putting input_ids here for now. - - @property - def tokens(self) -> np.ndarray: - return self.input_ids - - @property - def num_tokens(self) -> int: - return self.input_ids.shape[0] + return prompt.encode_dialog( + [ + {"role": "user" if turn["from"] == "User" else prompt.OUTPUT_ROLE, "slots": {"message": turn["value"]}} + for turn in example.data["conversations"] + ] + ) @dataclass @@ -288,11 +280,6 @@ class NeMoSFTJsonlAdapter: "dataset": str, "category": str, } - - Refer to examples of this format here: - - * TODO: links to examples? - * TODO: links to more detailed schema definition? """ paths: Union[Pathlike, list[Pathlike]] @@ -313,6 +300,11 @@ def __iter__(self) -> Iterator[NeMoSFTExample]: yield NeMoSFTExample(data, language=self.language) +""" +NeMoMultimodalConversation: data types, file parser, default prompt formatting logic. +""" + + @dataclass class TextTurn: value: str @@ -342,56 +334,40 @@ def to_dict(self): @dataclass -class NeMoMultimodalConversation: +class NeMoMultimodalConversation(Formattable, CustomFieldMixin): id: str turns: list[TextTurn | AudioTurn] - input_ids: np.ndarray | None = None - context_ids: np.ndarray | None = None - answer_ids: np.ndarray | None = None - mask: np.ndarray | None = None - - def tokenize( - self, - tokenizer: TokenizerWrapper | TokenizerSpec, - prompt: PromptFormatter = None, - ) -> "NeMoMultimodalConversation": - """ - Create a tokenized variant of this example given a tokenizer (i.e. fill the optional fields). - Supports BPE tokenizers and aggregate tokenizers. - - The tokenization is compatible with Megatron's :class:`GPTSFTChatDataset`. - """ - if isinstance(tokenizer, TokenizerWrapper): - tokenizer = tokenizer._tokenizer - if isinstance(tokenizer, AggregateTokenizer): - raise NotImplementedError("NeMoMultimodalConversation does not support AggregateTokenizer yet.") - if prompt is None: - prompt = PromptFormatter.resolve("plain")(tokenizer) - elif isinstance(prompt, str): - prompt = PromptFormatter.resolve(prompt)(tokenizer) - - # Collapse consecutive same-role turns into single turn for proper prompt formatting. - turns = groupby( - [ - { - "role": turn.role, - "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag}, - } - for turn in self.turns - ], - key=lambda turn: turn["role"], - ) - turns = [ - {"role": role, "slots": {"message": " ".join(t["slots"]["message"] for t in turn_grp)}} - for role, turn_grp in turns - ] - ans = prompt.encode_dialog(turns) - self.input_ids = ans["input_ids"] - self.context_ids = ans["context_ids"] - self.answer_ids = ans["answer_ids"] - self.mask = ans["mask"] + token_equivalent_duration: float = None + custom: dict = None - return self + @property + def input_length(self) -> int | None: + if self.context_ids is None: + return None + extra = _compute_num_audio_tokens(self, "context") + return self.context_ids.shape[0] + extra + + @property + def output_length(self) -> int | None: + if self.answer_ids is None: + return None + extra = _compute_num_audio_tokens(self, "answer") + return self.answer_ids.shape[0] + extra + + @property + def total_length(self) -> int | None: + if self.input_ids is None: + return None + extra = _compute_num_audio_tokens(self, "all") + return self.input_ids.shape[0] + extra + + @property + def has_audio_turns(self) -> bool: + return any(isinstance(t, AudioTurn) for t in self.turns) + + @property + def has_text_turns(self) -> bool: + return any(isinstance(t, TextTurn) for t in self.turns) def to_dict(self): return {"id": self.id, "conversations": [t.to_dict() for t in self.turns]} @@ -400,6 +376,54 @@ def list_cuts(self) -> list[Cut]: return [turn.cut for turn in self.turns if isinstance(turn, AudioTurn)] +def _compute_num_audio_tokens(example: NeMoMultimodalConversation, mode: Literal["context", "answer", "all"]) -> int: + if not example.has_audio_turns: + return 0 + assert example.token_equivalent_duration is not None, ( + "Cannot compute the length of a NeMoMultimodalConversation: " + "token_equivalent_duration must be set in order to estimate the number of tokens equivalent to audio turns. " + "Did you forget to set token_equivalent_duration option in your dataloading config? " + "Tip: generally it should be set to frame_shift * total_subsampling_factor of your audio encoder model." + ) + match mode: + case "context": + turns = example.turns[:-1] + case "answer": + turns = example.turns[-1:] + case "all": + turns = example.turns + case _: + raise RuntimeError(f"invalid mode for number of audio token computation: {mode}") + return sum( + [ + # subtract 1 for each audio locator tag as its token will be replaced + math.ceil(turn.cut.duration / example.token_equivalent_duration) - 1 + for turn in turns + if isinstance(turn, AudioTurn) + ] + ) + + +@registered_prompt_format_fn(NeMoMultimodalConversation) +def default_multimodal_conversation_prompt_format_fn(example: NeMoMultimodalConversation, prompt): + # Collapse consecutive same-role turns into single turn for proper prompt formatting. + turns = groupby( + [ + { + "role": turn.role, + "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag}, + } + for turn in example.turns + ], + key=lambda turn: turn["role"], + ) + turns = [ + {"role": role, "slots": {"message": " ".join(t["slots"]["message"] for t in turn_grp)}} + for role, turn_grp in turns + ] + return prompt.encode_dialog(turns) + + @dataclass class NeMoMultimodalConversationJsonlAdapter: """ @@ -425,6 +449,7 @@ class NeMoMultimodalConversationJsonlAdapter: manifest_filepath: str | list[str] audio_locator_tag: str tarred_audio_filepaths: str | list[str] = None + token_equivalent_duration: float = None shuffle_shards: bool = False shard_seed: Union[int, Literal["trng", "randomized"]] = "trng" @@ -515,6 +540,7 @@ def _iter_jsonl(self): ) for turn in data["conversations"] ], + token_equivalent_duration=self.token_equivalent_duration, ) @@ -567,297 +593,3 @@ def _setup_writers(self): Path(self.output_dir).mkdir(exist_ok=True) self.manifest_writer = JsonlShardWriter(f"{self.output_dir}/manifest_{self.shard_idx}.jsonl", shard_size=None) self.tar_writer = AudioTarWriter(f"{self.output_dir}/audio_{self.shard_idx}.tar", shard_size=None) - - -""" -The code below is copied from nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_chat_dataset.py -with minimal modifications in order to avoid importing the NLP collection. - -We require this code for on-the-fly text example tokenization in a compatible way with Megatron, -so that we can determine the mini-batch sizes using the token counts. -""" - - -def preprocess( - source: dict, - tokenizer: TokenizerSpec, - name_end_token_ids: int, - label_start_ids: list, - special_tokens: dict, - num_turn_start_tokens: int, -): - """ - Given a conversation list. This transform: - 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; - 2. Concatenate conversations together; - 3. Tokenize the concatenated conversation; - 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. - """ - header, conversation, data_type, mask_role = _get_header_conversation_type_mask_role(source, special_tokens) - # tokenize conversations - input_ids = tokenizer.text_to_ids(conversation) - target = copy.deepcopy(input_ids) - header_tokens = tokenizer.text_to_ids(header) - header_len = len(header_tokens) - - ids = [] - tokenized_lens = [] - assert torch.equal(torch.tensor(target[:header_len]), torch.tensor(header_tokens)) - for s in source['conversations']: - # hack to remove the extra empty token in front - id1 = tokenizer.text_to_ids(PREFIX_STR + s["value"]) - id2 = tokenizer.text_to_ids(PREFIX_STR) - tokenized_sentence = id1[len(id2) :] - ids.append(torch.tensor(tokenized_sentence)) - tokenized_lens.append(len(tokenized_sentence)) - speakers = [sentence["from"] for sentence in source['conversations']] - assert mask_role in speakers, "mask role not in the conversation" - target = torch.LongTensor(target) - # not going to train on the header - target[:header_len] = IGNORE_INDEX - input_ids = torch.LongTensor(input_ids) - _mask_targets( - target, - tokenized_lens, - speakers, - header_len, - ids, - tokenizer, - mask_role, - data_type, - name_end_token_ids, - special_tokens, - label_start_ids, - num_turn_start_tokens, - ) - mask = (target != IGNORE_INDEX).bool() - assert mask.sum().item() != 0, "mask is empty" - # Choose the last conversation as answer other history are context - last_ignore_index_pos = torch.nonzero(target == IGNORE_INDEX)[-1].item() + 1 - context_ids = input_ids[:last_ignore_index_pos] - answer_ids = input_ids[last_ignore_index_pos:] - return dict(input_ids=input_ids, mask=mask, context_ids=context_ids, answer_ids=answer_ids) - - -def _build_samples_mapping(tokenizer, special_tokens): - # Copied from gpt_sft_chat_dataset.py - LABEL_START = special_tokens['label_start'] - END_NAME_SIGNAL = special_tokens['end_of_name'] - - id1 = tokenizer.text_to_ids(PREFIX_STR) - id2 = tokenizer.text_to_ids(PREFIX_STR + LABEL_START) - label_start_tokens = id2[len(id1) :] - - id1 = tokenizer.text_to_ids(PREFIX_STR + END_NAME_SIGNAL) - id2 = tokenizer.text_to_ids(PREFIX_STR) - name_end_token_ids = id1[len(id2) :] - - id1 = tokenizer.text_to_ids(PREFIX_STR + special_tokens['turn_start']) - id2 = tokenizer.text_to_ids(PREFIX_STR) - num_turn_start_tokens = len(id1) - len(id2) - - return label_start_tokens, name_end_token_ids, num_turn_start_tokens - - -PREFIX_STR = ( - "\x00" # the prefix string used in the tokenizer to deal with the added empty token for some of the tokenizers -) - -IGNORE_INDEX = -100 -SYSTEM_TOKEN = "System" - -TYPE_INSTRUCTION = { - 'TEXT_TO_VALUE': "", - 'VALUE_TO_TEXT': '', -} - - -def _get_header_conversation_type_mask_role(source, special_tokens): - END_SIGNAL = special_tokens['end_of_turn'] - END_NAME_SIGNAL = special_tokens['end_of_name'] - - data_type = None - if 'type' in source: - data_type = source['type'] - if data_type is not None: - assert data_type in TYPE_INSTRUCTION, f"source type {data_type} not supported" - # add end signal and concatenate together - conversation = source['system'] - if data_type is not None: - if TYPE_INSTRUCTION[data_type] != '': - conversation = conversation + '\n' + TYPE_INSTRUCTION[data_type] - mask_role = source.get('mask', 'User') - header = f"{special_tokens['system_turn_start']}{SYSTEM_TOKEN}{END_NAME_SIGNAL}{conversation}{END_SIGNAL}" - conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type, special_tokens) - return header, conversation, data_type, mask_role - - -def identify_start_index_of_subsequence(subsequence, sequence): - """find the location of the small tensor in the large tensor. - e.g. small = [1,3], large = [2,3,1,3], returns 2 - small = [3,2], large = [2,3,1,3], returns -1 - Args: - small (tensor): small tensor - large (tensor): large tensor - """ - for i in range(sequence.size(0) - subsequence.size(0) + 1): - if torch.equal(sequence[i : i + subsequence.size(0)], subsequence): - return i - return -1 - - -def _mask_targets( - target, - tokenized_lens, - speakers, - header_len, - s_ids, - tokenizer, - mask_role, - gtype, - name_end_token_ids, - special_tokens, - label_start_ids, - num_turn_start_tokens, -): - """This function masks the tokens so the loss is computed only on the non-masked role's responses. - For 'TEXT_TO_VALUE' type, the loss is computed on the value attributes. - - Args: - target (Tensor): input ids - tokenized_lens (List[int]): array of lengths of each turns - speakers (List[str]): array of speakers of each turns - header_len (int): the system prompt length - s_ids (List[Tensor]): array of tokenized ids of each turns - tokenizer (TokenizerSpec): tokenizer object - mask_role (str): the speaker id to be masked from loss computation - gtype (str): either 'TEXT_TO_VALUE' or 'VALUE_TO_TEXT' - name_end_token_ids (int): end of name token ids - special_tokens (dict): special tokens used for the chat prompt. It has the keys: system_turn_start, turn_start, label_start, end_of_turn - label_start_ids (list): list of label start token ids, - num_turn_start_tokens (int): number of tokens of the turn_start str - """ - TURN_TOKEN = special_tokens['turn_start'] - END_NAME_SIGNAL = special_tokens['end_of_name'] - label_start_ids = torch.tensor(label_start_ids) - name_end_token_ids = torch.tensor(name_end_token_ids) - - cur_idx = header_len - tgt_len = target.shape[0] - for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)): - # note, sentence piece will add extra empty token in front. has to compute the diff - id1 = tokenizer.text_to_ids(PREFIX_STR) - id2 = tokenizer.text_to_ids(PREFIX_STR + TURN_TOKEN + speaker + END_NAME_SIGNAL) - skip_name_len = len(id2) - len( - id1 - ) # s_ids[:skip_name_len] is the name part of the prompt 'TURN_TOKEN + speaker + END_NAME_SIGNAL' - # get the position of the label start string in this turn - location = identify_start_index_of_subsequence(label_start_ids, s_id) - - if location >= 0: - # if it contains the label start tokens - if gtype == 'VALUE_TO_TEXT': - # handles the case that condition on labels to generate respone - # the next token after the name part of the prompt is the beginning of the label start tokens - assert skip_name_len == location - # find the first new line token after the label part, which indicates the end of the whole label string - # newline_loc = torch.where((s_id[skip_name_len:] == name_end_token_ids))[0] - newline_loc = identify_start_index_of_subsequence(name_end_token_ids, s_id[skip_name_len:]) - if newline_loc < 0: - # cannot find new line token, which means the the whole turn is just a partial label string. Mask the whole turn - target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX - continue - # skip the label part and the new line token - more_skip_len = newline_loc + len(name_end_token_ids) - # skip the name part and the label part - skip_name_len += more_skip_len - elif gtype == 'TEXT_TO_VALUE': - # handles the case that condition on response to generate label - # skip the name part, response and the label start tokens part, the remainder is the label string without label start, e.g. 'quality:9,toxicity:8...' - skip_name_len = location + len(label_start_ids) - if cur_idx >= tgt_len: - break - # elif cur_idx + tokenized_len < tgt_len: - # # Check whether the mask is applied to the correct position, the first token is turn start tokens - # if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[1:]): - # logging.warning("a sentence mismatches the corresponding piece " "in the conversation") - if i == 0 and (gtype == 'VALUE_TO_TEXT' or gtype is None): - # mask the first turn completely to provide at least one turn as context for the rest - target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX - elif speaker == mask_role and i == 1 and gtype == 'TEXT_TO_VALUE': - # leave the first turn start tag unmasked, servers severs as the end of turn signal - target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX - elif speaker == mask_role and (i > 1): - # leave the first turn start tag unmasked, which severs as the end of turn signal - target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX - elif speaker == mask_role and (i <= 1): - # mask out everything in the second turn - target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX - else: - # mask up to name part, label part for VALUE_TO_TEXT, or name part, response and label start tokens for TEXT_TO_VALUE, or just the name part if gtype is None - target[cur_idx : cur_idx + skip_name_len] = IGNORE_INDEX - cur_idx += tokenized_len - - -def _add_speaker_and_signal(header, source, mask_role, gtype, special_tokens): - TURN_TOKEN = special_tokens['turn_start'] - END_SIGNAL = special_tokens['end_of_turn'] - LABEL_START = special_tokens['label_start'] - END_NAME_SIGNAL = special_tokens['end_of_name'] - - """Add speaker and start/end signal on each round.""" - BEGIN_SIGNAL = "" - conversation = header - for i, sentence in enumerate(source): - sentence_from = sentence["from"] - role_token = TURN_TOKEN - if gtype is None: - sentence["value"] = ( - BEGIN_SIGNAL + role_token + sentence_from + END_NAME_SIGNAL + sentence["value"] + END_SIGNAL - ) - elif gtype == "VALUE_TO_TEXT": - sentence["value"] = ( - BEGIN_SIGNAL - + role_token - + sentence_from - + END_NAME_SIGNAL - + ( - response_value_formater(sentence['label'], LABEL_START, END_NAME_SIGNAL) - if 'label' in sentence - else '' - ) - + sentence["value"] - + END_SIGNAL - ) - elif gtype == "TEXT_TO_VALUE": - sentence["value"] = ( - BEGIN_SIGNAL - + role_token - + sentence_from - + END_NAME_SIGNAL - + sentence["value"] - + END_SIGNAL - + ( - response_value_formater(sentence['label'], LABEL_START, END_NAME_SIGNAL) - if 'label' in sentence - else '' - ) - ) - else: - raise ValueError( - f"source type {gtype} not supported, only 'VALUE_TO_TEXT' and 'TEXT_TO_VALUE' are supported" - ) - conversation += sentence["value"] - # if the last turn is not masked, add next token start token to the end, which will be included for loss calculation - if sentence_from != mask_role and i == len(source) - 1: - conversation += TURN_TOKEN - return conversation - - -def response_value_formater(label, label_start, end_signal): - if isinstance(label, str): - return label_start + label + end_signal - elif label is None: - return '' - else: - raise ValueError(f'Unknown label type {type(label)}, only str type is supported') diff --git a/nemo/collections/common/data/prompt_fn.py b/nemo/collections/common/data/prompt_fn.py new file mode 100644 index 000000000000..bd1e45ea92e2 --- /dev/null +++ b/nemo/collections/common/data/prompt_fn.py @@ -0,0 +1,78 @@ +from typing import Callable, Type + +import torch + + +PromptFormatFnReturnType = dict[str, list[torch.Tensor]] +PromptFormatSignature = Callable[[object, object], PromptFormatFnReturnType] +PROMPT_FORMAT_FNS: dict[tuple[Type, Type] | Type, PromptFormatSignature] = {} + + +def registered_prompt_format_fn(example_type: Type, formatter_type: Type = None): + """ + Decorator for registering text prompt functions. + It allows to select the right prompt formatting function based on the types of the + example and the prompt formatter, allowing different strategies for formatting different + types of data with different prompt formats. + + When formatter_type is set None, registers a default prompt format function for a given data type. + + Example:: + + >>> @registered_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter) + ... def my_src_tgt_text_prompt(example, formatter): + ... pass + ... + ... @registered_prompt_format_fn(Cut, Llama2PromptFormatter) + ... def my_audio_prompt(example, formatter): + ... pass + ... + ... prompt_fn = get_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter) + """ + + def _decorator(prompt_fn: Callable[[object, object], dict[str, list[torch.Tensor]]]): + global PROMPT_FORMAT_FNS + if formatter_type is None: + PROMPT_FORMAT_FNS[example_type] = prompt_fn + else: + PROMPT_FORMAT_FNS[(example_type, formatter_type)] = prompt_fn + return prompt_fn + + return _decorator + + +def get_prompt_format_fn(example: Type | object, prompt: Type | object = None) -> PromptFormatSignature: + """See the documentation of ``text_prompt_formatter`` above.""" + + # If the user provided objects, resolve their types. + if not isinstance(example, type): + example = type(example) + if not isinstance(prompt, type): + prompt = type(prompt) + + # For the example type, first try to match it directly, then fall back to its parent classes. + for example_subtype in example.mro(): + + # First check the match for specific example type and a specific prompt format, + # and all parent types of that specific prompt formatter type. + for prompt_subtype in prompt.mro(): + if (example_subtype, prompt_subtype) in PROMPT_FORMAT_FNS: + return PROMPT_FORMAT_FNS[(example_subtype, prompt_subtype)] + + # Then for the same specific example type, fall back to its default prompt formatter implementation. + if example_subtype in PROMPT_FORMAT_FNS: + return PROMPT_FORMAT_FNS[example_subtype] + + raise ValueError( + f"Unknown prompt format function for ({example}, {prompt}). " + f"Available choices are: {list(PROMPT_FORMAT_FNS.keys())}" + ) + + +def apply_prompt_format_fn(example: object | Type, prompt: object | Type) -> PromptFormatFnReturnType: + """ + Utility for resolving the prompt format function and applying it to an example in one go. + See the documentation of ``text_prompt_formatter`` above. + """ + fn = get_prompt_format_fn(example, prompt) + return fn(example, prompt) diff --git a/nemo/collections/common/prompts/__init__.py b/nemo/collections/common/prompts/__init__.py index 99950f2b3a98..77e6c346fd4d 100644 --- a/nemo/collections/common/prompts/__init__.py +++ b/nemo/collections/common/prompts/__init__.py @@ -1,5 +1,4 @@ from nemo.collections.common.prompts.canary import CanaryPromptFormatter -from nemo.collections.common.prompts.fn import get_prompt_format_fn, registered_prompt_format_fn from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.collections.common.prompts.gemma import GemmaPromptFormatter from nemo.collections.common.prompts.llama import Llama2PromptFormatter, Llama3PromptFormatter @@ -10,3 +9,4 @@ Phi2QAPromptFormatter, ) from nemo.collections.common.prompts.plain import PlainPromptFormatter +from nemo.collections.common.prompts.t5nmt import T5NMTPromptFormatter diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py index 0eb3296bcff9..eb7412920576 100644 --- a/nemo/collections/common/prompts/canary.py +++ b/nemo/collections/common/prompts/canary.py @@ -1,13 +1,12 @@ from typing import Any import torch -from lhotse import CutSet, MonoCut -from lhotse.cut import MixedCut +from lhotse import MonoCut +from lhotse.cut import Cut, MixedCut from lhotse.utils import ifnone -from nemo.collections.common.prompts.fn import registered_prompt_format_fn +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec from nemo.collections.common.tokenizers.canary_tokenizer import ( CANARY_BOS, CANARY_EOS, @@ -94,10 +93,8 @@ def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[s return slot_values -@registered_prompt_format_fn -def canary( - cuts: CutSet, tokenizer: TokenizerSpec -) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: +@registered_prompt_format_fn(Cut, CanaryPromptFormatter) +def canary(cut: Cut, prompt: CanaryPromptFormatter) -> dict[str, torch.Tensor]: """ Prepend and append control tokens to the token sequence as per Canary format. @@ -120,62 +117,51 @@ def canary( (i.e., spoken language in the recording) and the second occurrence is for the "target" language (i.e., the language in which we are going to output the text). """ - formatter = CanaryPromptFormatter(tokenizer) - - prompts_with_answers, prompts, answers = [], [], [] - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut._first_non_padding_cut - if not isinstance(cut, MonoCut): - raise TypeError( - f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" - ) - - # first, validate the utterance - expected_slots = set(formatter.get_slots("user")) - missing_keys = expected_slots - set(cut.custom) - if "task" in missing_keys and "taskname" in cut.custom: - # Compatibility with "old" Canary manifest format. - # For compatbility with inference options, this slot is now called "task". - cut.custom["task"] = cut.custom["taskname"] - missing_keys.remove("task") - if missing_keys: - raise RuntimeError( - f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" - f"Please ensure that every utterance in the input manifests contains these keys." - ) - - turns = [ - dict( - role="user", - slots={ - **{slot: cut.custom[slot] for slot in expected_slots}, - formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, - }, - ) - ] - # If data has no transcript, create empty response with only. - text = ' '.join(s.text for s in cut.supervisions if s.text is not None) - turns.append( - dict( - role="assistant", - slots={ - "text": text, - formatter.PROMPT_LANGUAGE_SLOT: ifnone( - cut.supervisions[0].language, cut.custom.get("target_lang") - ), - }, - ), + if isinstance(cut, MixedCut): + cut = cut._first_non_padding_cut + if not isinstance(cut, MonoCut): + raise TypeError( + f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" ) - encoded = formatter.encode_dialog(turns) - prompts_with_answers.append(encoded["input_ids"]) - prompts.append(encoded["context_ids"]) - if "answer_ids" in encoded: - assert ( - encoded["answer_ids"][-1].item() == formatter.tokenizer.eos - ), f"Expected the last token in answer_ids to be EOS, but we got {encoded['answer_ids']=}" - answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS - else: - answers.append([]) - - return prompts_with_answers, prompts, answers + + # first, validate the utterance + expected_slots = set(prompt.get_slots("user")) + missing_keys = expected_slots - set(cut.custom) + if "task" in missing_keys and "taskname" in cut.custom: + # Compatibility with "old" Canary manifest format. + # For compatbility with inference options, this slot is now called "task". + cut.custom["task"] = cut.custom["taskname"] + missing_keys.remove("task") + if missing_keys: + raise RuntimeError( + f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}" + f"Please ensure that every utterance in the input manifests contains these keys." + ) + + turns = [ + dict( + role="user", + slots={ + **{slot: cut.custom[slot] for slot in expected_slots}, + prompt.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, + }, + ) + ] + # If data has no transcript, create empty response with only. + text = ' '.join(s.text for s in cut.supervisions if s.text is not None) + turns.append( + dict( + role="assistant", + slots={ + "text": text, + prompt.PROMPT_LANGUAGE_SLOT: ifnone(cut.supervisions[0].language, cut.custom.get("target_lang")), + }, + ), + ) + + ans = prompt.encode_dialog(turns) + assert ( + ans["answer_ids"][-1].item() == prompt.tokenizer.eos + ), f"Expected the last token in answer_ids to be EOS, but we got {ans['answer_ids']}" + ans["answer_ids"] = ans["answer_ids"][:-1] # Strip Canary's EOS + return ans diff --git a/nemo/collections/common/prompts/fn.py b/nemo/collections/common/prompts/fn.py deleted file mode 100644 index ce7d2fc8a69a..000000000000 --- a/nemo/collections/common/prompts/fn.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Callable, Sequence - -import torch -from lhotse import CutSet - -from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec - -PROMPT_FORMAT_FNS = {} - - -def registered_prompt_format_fn( - prompt_fn: Callable[[CutSet, TokenizerSpec], tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]] -): - """ - Decorator for registering prompt functions under a name. - - Example:: - - >>> @registered_prompt_format_fn - ... def my_prompt(cuts, tokenizer): - ... pass - ... - ... prompt_fn = get_prompt_format_fn("my_prompt") - """ - global PROMPT_FORMAT_FNS - - PROMPT_FORMAT_FNS[prompt_fn.__name__] = prompt_fn - return prompt_fn - - -def get_prompt_format_fn( - name: str, -) -> Callable[[CutSet, TokenizerSpec], tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]]: - if name not in PROMPT_FORMAT_FNS: - raise ValueError( - f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}" - ) - return PROMPT_FORMAT_FNS[name] diff --git a/nemo/collections/common/prompts/gemma.py b/nemo/collections/common/prompts/gemma.py index 2570995625ee..205d6f8d4cac 100644 --- a/nemo/collections/common/prompts/gemma.py +++ b/nemo/collections/common/prompts/gemma.py @@ -2,14 +2,10 @@ Implemented following the guide at https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format """ -from collections import defaultdict +from lhotse.cut import Cut, MixedCut -from lhotse import CutSet -from lhotse.cut import MixedCut - -from nemo.collections.common.prompts import registered_prompt_format_fn +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec GEMMA_BOS = "" GEMMA_END_OF_TURN = "" @@ -38,29 +34,19 @@ class GemmaPromptFormatter(PromptFormatter): } -@registered_prompt_format_fn -def gemma1(cuts: CutSet, tokenizer: TokenizerSpec): - prompt = GemmaPromptFormatter(tokenizer) - ans = defaultdict(list) - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut.first_non_padding_cut - if cut.has_custom("context"): - context = cut.context - elif cut.has_custom("question"): - context = cut.question - else: - context = cut.default_context - - turns = [] - if cut.has_custom("system_prompt"): - turns.append({"role": "system_and_user", "slots": {"system": cut.system_prompt, "message": context}}) - else: - turns.append({"role": "user", "slots": {"message": context}}) - if (answer := cut.supervisions[0].text) is not None: - turns.append({"role": "assistant", "slots": {"message": answer}}) +@registered_prompt_format_fn(Cut, GemmaPromptFormatter) +def gemma1(cut: Cut, prompt: GemmaPromptFormatter): + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + if cut.has_custom("context"): + context = cut.context + elif cut.has_custom("question"): + context = cut.question + else: + context = cut.default_context - for k, v in prompt.encode_dialog(turns).items(): - ans[k].append(v) + turns = [{"role": "user", "slots": {"message": context}}] + if (answer := cut.supervisions[0].text) is not None: + turns.append({"role": "assistant", "slots": {"message": answer}}) - return ans + return prompt.encode_dialog(turns) diff --git a/nemo/collections/common/prompts/llama.py b/nemo/collections/common/prompts/llama.py index 7defc49cc61a..e011039e7c34 100644 --- a/nemo/collections/common/prompts/llama.py +++ b/nemo/collections/common/prompts/llama.py @@ -1,12 +1,9 @@ -from collections import defaultdict +import torch +from lhotse.cut import Cut, MixedCut -from lhotse import CutSet -from lhotse.cut import MixedCut -from lhotse.utils import ifnone - -from nemo.collections.common.prompts import registered_prompt_format_fn +from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import BOS_SLOT, EOS_SLOT, Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec class Llama2PromptFormatter(PromptFormatter): @@ -40,32 +37,67 @@ class Llama2PromptFormatter(PromptFormatter): } -@registered_prompt_format_fn -def llama2(cuts: CutSet, tokenizer: TokenizerSpec): - prompt = Llama2PromptFormatter(tokenizer) - ans = defaultdict(list) - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut.first_non_padding_cut - if cut.has_custom("context"): - context = cut.context - elif cut.has_custom("question"): - context = cut.question - else: - context = cut.default_context - - turns = [] - if cut.has_custom("system_prompt"): - turns.append({"role": "system_and_user", "slots": {"system": cut.system_prompt, "message": context}}) - else: - turns.append({"role": "user", "slots": {"message": context}}) - if (answer := cut.supervisions[0].text) is not None: - turns.append({"role": "assistant", "slots": {"message": answer}}) - - for k, v in prompt.encode_dialog(turns).items(): - ans[k].append(v) - - return ans +@registered_prompt_format_fn(Cut, Llama2PromptFormatter) +def llama2(cut: Cut, prompt: Llama2PromptFormatter) -> dict[str, torch.Tensor]: + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + if cut.has_custom("context"): + context = cut.context + elif cut.has_custom("question"): + context = cut.question + else: + context = cut.default_context + + turns = [] + if cut.has_custom("system_prompt"): + turns.append({"role": "system_and_user", "slots": {"system": cut.system_prompt, "message": context}}) + else: + turns.append({"role": "user", "slots": {"message": context}}) + if (answer := cut.supervisions[0].text) is not None: + turns.append({"role": "assistant", "slots": {"message": answer}}) + return prompt.encode_dialog(turns) + + +@registered_prompt_format_fn(SourceTargetTextExample, Llama2PromptFormatter) +def llama2_src_tgt_text_example(example: SourceTargetTextExample, prompt: Llama2PromptFormatter): + if example.question is not None: + user_turn = { + "role": "system_and_user", + "slots": {"system": example.question.text, "message": example.source.text}, + } + else: + user_turn = { + "role": "user", + "slots": {"message": example.source.text}, + } + return prompt.encode_dialog( + [ + user_turn, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}}, + ] + ) + + +@registered_prompt_format_fn(NeMoSFTExample, Llama2PromptFormatter) +def llama2_sft_text_example(example: NeMoSFTExample, prompt: Llama2PromptFormatter): + first_user_turn = example.data["conversations"][0]["value"] + if "system" in example.data and example.data["system"]: + first_turn = { + "role": "system_and_user", + "slots": {"system": example.data["system"], "message": first_user_turn}, + } + else: + first_turn = { + "role": "user", + "slots": {"message": first_user_turn}, + } + return prompt.encode_dialog( + [first_turn] + + [ + {"role": "user" if turn["from"] == "User" else prompt.OUTPUT_ROLE, "slots": {"message": turn["value"]}} + for turn in example.data["conversations"][1:] + ] + ) LLAMA3_BOS = "<|begin_of_text|>" diff --git a/nemo/collections/common/prompts/plain.py b/nemo/collections/common/prompts/plain.py index efd7d989a9e2..de7fbe5a1830 100644 --- a/nemo/collections/common/prompts/plain.py +++ b/nemo/collections/common/prompts/plain.py @@ -1,11 +1,7 @@ -from collections import defaultdict +from lhotse.cut import Cut, MixedCut -from lhotse import CutSet -from lhotse.cut import MixedCut - -from nemo.collections.common.prompts import registered_prompt_format_fn +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn from nemo.collections.common.prompts.formatter import Modality, PromptFormatter -from nemo.collections.common.tokenizers import TokenizerSpec class PlainPromptFormatter(PromptFormatter): @@ -31,20 +27,17 @@ class PlainPromptFormatter(PromptFormatter): } -@registered_prompt_format_fn -def plain(cuts: CutSet, tokenizer: TokenizerSpec): - prompt = PlainPromptFormatter(tokenizer) - ans = defaultdict(list) - for cut in cuts: - if isinstance(cut, MixedCut): - cut = cut.first_non_padding_cut - assert cut.has_custom("context"), f"Missing mandatory metadata key 'context' in {cut=}" - - turns = [{"role": "user", "slots": {"message": cut.context}}] - if (answer := cut.supervisions[0].text) is not None: - turns.append({"role": "assistant", "slots": {"message": answer}}) +@registered_prompt_format_fn(Cut, PlainPromptFormatter) +def plain(cut: Cut, prompt: PlainPromptFormatter): + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + if cut.has_custom("context"): + ctx = cut.context + else: + ctx = "" - for k, v in prompt.encode_dialog(turns).items(): - ans[k].append(v) + turns = [{"role": "user", "slots": {"message": ctx}}] + if (answer := cut.supervisions[0].text) is not None: + turns.append({"role": "assistant", "slots": {"message": answer}}) - return ans + return prompt.encode_dialog(turns) diff --git a/nemo/collections/common/prompts/t5nmt.py b/nemo/collections/common/prompts/t5nmt.py new file mode 100644 index 000000000000..0d89adcdb55a --- /dev/null +++ b/nemo/collections/common/prompts/t5nmt.py @@ -0,0 +1,91 @@ +from collections import defaultdict + +import torch +from lhotse import MonoCut +from lhotse.cut import Cut, MixedCut + +from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample +from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn +from nemo.collections.common.prompts.formatter import Modality, PromptFormatter + + +class T5NMTPromptFormatter(PromptFormatter): + """ + The default prompt format for Megatron T5 based neural machine translation models. + Based on: https://github.com/NVIDIA/NeMo/blob/ad5ef750e351edbb5eeb7eb6df2d0c804819600f/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py#L790 + """ + + NAME = "t5nmt" + OUTPUT_ROLE = "assistant" + TEMPLATE = { + "user": { + "template": f"|target_lang| |message|", + "slots": { + "target_lang": Modality.Text, + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|message|", + "slots": { + "message": Modality.Text, + }, + }, + } + + def encode_turn(self, prompt_template: str, expected_slots: dict, slot_values: dict) -> list[int]: + # Automatically adds "<" and ">" to target lang token for T5 NMT. + # Based on: https://github.com/NVIDIA/NeMo/blob/ad5ef750e351edbb5eeb7eb6df2d0c804819600f/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py#L307 + if (val := slot_values.get("target_lang")) is not None: + if not val.startswith("<") or not val.endswith(">"): + slot_values["target_lang"] = f"<{val}>" + return super().encode_turn( + prompt_template=prompt_template, expected_slots=expected_slots, slot_values=slot_values + ) + + +@registered_prompt_format_fn(Cut, T5NMTPromptFormatter) +def t5nmt(cut: Cut, prompt: T5NMTPromptFormatter) -> dict[str, torch.Tensor]: + ans = defaultdict(list) + if isinstance(cut, MixedCut): + cut = cut._first_non_padding_cut + if not isinstance(cut, MonoCut): + raise TypeError( + f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})" + ) + + if hasattr(cut, "context"): + context = cut.context + elif hasattr(cut, "default_context"): + context = cut.default_context + else: + raise RuntimeError("Missing context/default_context custom field in cut: {cut}") + + turns = [ + dict( + role="user", + # "message" slot is the audio portion of the cut; currently it is populated inside model's forward. + slots={"target_lang": context, "message": ""}, + ), + ] + if len(cut.supervisions) > 0 and cut.supervisions[0].text is not None: + turns.append( + dict( + role=prompt.OUTPUT_ROLE, + slots={"message": cut.supervisions[0].text}, + ) + ) + return prompt.encode_dialog(turns) + + +@registered_prompt_format_fn(SourceTargetTextExample, T5NMTPromptFormatter) +def t5nmt_src_tgt_text_example(example: SourceTargetTextExample, prompt: T5NMTPromptFormatter): + ctx = f"<{example.target.language}>" + if example.has_custom("extra_prompt"): + ctx = f"{ctx} {example.extra_prompt}" + return prompt.encode_dialog( + [ + {"role": "user", "slots": {"message": example.source.text, "target_lang": ctx}}, + {"role": prompt.OUTPUT_ROLE, "slots": {"message": example.target.text}}, + ] + ) diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py index 15fb1a587789..8d64632210a4 100644 --- a/nemo/collections/multimodal/speech_llm/data/build_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -14,6 +14,7 @@ import copy from pathlib import Path +import omegaconf import torch from megatron.core import parallel_state from omegaconf.omegaconf import OmegaConf @@ -110,7 +111,7 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict # for eval, we need to create separate dataset so as to report splitted numbers else: dls = [] - if hasattr(data_cfg, 'manifest_filepath'): + if data_cfg.get('manifest_filepath') is not None: manifest_filepath = data_cfg.manifest_filepath for cur_manifest_filepath in manifest_filepath: conf = copy.deepcopy(data_cfg) @@ -121,6 +122,7 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict global_rank=parallel_state.get_data_parallel_rank(), world_size=parallel_state.get_data_parallel_world_size(), dataset=dataset, + tokenizer=dataset.text_processor.tokenizer, ) ) else: @@ -131,16 +133,25 @@ def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict assert len(input_cfg) == 1, "Only one dataset with multiple manifest paths is supported for eval" data_cfg.input_cfg = input_cfg # for getting names - manifest_filepath = [ic.manifest_filepath for ic in input_cfg[0].input_cfg] + manifest_filepath = [] + for ic in input_cfg[0].input_cfg: + if hasattr(ic, "manifest_filepath"): + manifest_filepath.append(ic.manifest_filepath) + else: + assert ic.type == "txt_pair" + manifest_filepath.append(ic.target_paths) for cur_input_cfg in input_cfg[0].input_cfg: conf = copy.deepcopy(data_cfg) conf.input_cfg[0].input_cfg = [cur_input_cfg] + OmegaConf.set_struct(conf, False) + conf.force_finite = True dls.append( get_lhotse_dataloader_from_config( conf, global_rank=parallel_state.get_data_parallel_rank(), world_size=parallel_state.get_data_parallel_world_size(), dataset=dataset, + tokenizer=dataset.text_processor.tokenizer, ) ) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 7cb0eb7cb9b5..b4cfea49ccc0 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -57,6 +57,7 @@ from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.mixins import adapter_mixins +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import AppState, logging, model_utils from nemo.utils.model_utils import inject_model_parallel_rank @@ -1202,6 +1203,16 @@ def load_state_dict(self, state_dict, strict: bool = True): else: super(MegatronGPTModel, self).load_state_dict(state_dict, strict=strict) + def on_train_epoch_start(self) -> None: + app_state = AppState() + reconfigure_num_microbatches_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.train_ds.global_batch_size, + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + def on_load_checkpoint(self, checkpoint) -> None: """LightningModule hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-load-checkpoint @@ -1263,9 +1274,14 @@ def inference_step(self, dataloader_iter, mode): # Evaluation of multimodal data follows the same pattern as training except predict_step batch, batch_idx, dataloader_idx = next(dataloader_iter) data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds - self._reconfigure_and_process_inference_batch(batch, data_cfg) - # Meta data from dataset - metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + if "tokens" in batch: + self._reconfigure_and_process_inference_batch(batch, data_cfg) + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + else: + batch["tokens"] = batch["text_input_ids"] + self._reconfigure_and_process_inference_batch(batch, data_cfg) + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + batch.pop("tokens") loss = super(MegatronGPTSFTModel, self).validation_step(itertools.chain([batch]), dataloader_idx) # We need _inference_config to get generation params @@ -1278,12 +1294,22 @@ def inference_step(self, dataloader_iter, mode): output = self.predict_step(batch, batch_idx, dataloader_idx) - inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] - labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] - preds_text = [ - self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')]) - for t, l in zip(output['token_ids'], batch['context_lengths']) - ] + audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} + text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} + if audio_batch: + inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in audio_batch['contexts']] + labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in audio_batch['answers']] + preds_text = [ + self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')]) + for t, l in zip(output['token_ids'], audio_batch['context_lengths']) + ] + else: + inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in text_batch['text_context_ids']] + labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in text_batch['text_answer_ids']] + preds_text = [ + self.tokenizer.ids_to_text(t[l.item() :][: data_cfg.get('tokens_to_generate')]) + for t, l in zip(output['token_ids'], text_batch['text_context_lens']) + ] if data_cfg.get("end_string", None): # sometimes data_cfg.end_string != self.tokenizer.ids_to_text(self.tokenizer.text_to_ids(data_cfg.end_string)) @@ -1380,6 +1406,12 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int # for megatron_gpt_eval.py if isinstance(batch, list): inference_config['inputs'] = batch + elif "text_context_ids" in batch: + # Text mini-batch + inference_config['inputs'] = ( + batch['text_context_ids'].cuda(), + batch['text_context_lens'].cuda(), + ) elif 'num_audios' in batch: # peft_eval.py inference_config['inputs'] = ( @@ -1410,7 +1442,8 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int ) # add audio offsets to context lengths for properly decoding only the response - batch['context_lengths'] = batch['context_lengths'].cuda() + response['audio_feat_lens'] + if 'context_lengths' in batch: + batch['context_lengths'] = batch['context_lengths'].cuda() + response['audio_feat_lens'] return response @@ -1726,6 +1759,67 @@ def find_frozen_submodules(model): self.perception = self.trainer.strategy._setup_model(self.perception) self.perception = self.perception.cuda(torch.cuda.current_device()) + def oomptimizer_schema(self, schema: str = "audio") -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. + """ + + if schema == "audio": + return { + "cls": dict, + "inputs": [ + {"name": "audio_signal", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + {"name": "audio_signal_length", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, + { + "name": "tokens", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "tokens_length", + "type": NeuralType(("B",), LengthsType()), + "seq_length": "output", + }, + { + "name": "labels", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "loss_mask", + "type": NeuralType(("B", "T"), MaskType()), + "seq_length": "output", + }, + { + "name": "context_start_idx", + "type": "constant", + "value": 0, + }, + ], + } + elif schema == "text": + return { + "cls": dict, + "inputs": [ + { + "name": "text_input_ids", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "input", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "text_masks", + "type": NeuralType(("B", "T"), MaskType()), + "seq_length": "input", + }, + ], + } + else: + raise RuntimeError(f"Unknown schema type for oomptimizer of class {type(self)}: '{schema}'") + class CrossAttendModularAudioGPTModel(ModularAudioGPTModel): """Modularized speech GPT model.""" diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 0e1bf46ad2f9..e3315e6f0025 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -49,6 +49,7 @@ from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.core.classes.mixins import adapter_mixins +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import AppState, logging, model_utils try: @@ -343,13 +344,8 @@ def prepare_llm_input(self, audio_batch): input_signal = audio_batch['audio_signal'] input_signal_length = audio_batch['audio_signal_length'] - - input_ids, input_length, labels, loss_mask = ( - audio_batch['contexts'], - audio_batch['context_lengths'], - audio_batch['labels'], - audio_batch['loss_mask'], - ) + input_ids = audio_batch['contexts'] + input_length = audio_batch['context_lengths'] # [b, t, c] encoded, encoded_len = self.perception( @@ -367,7 +363,7 @@ def prepare_llm_input(self, audio_batch): def forward( self, - audio_batch, + batch, checkpoint_activations_all_layers, ): """Forward pass of the model. @@ -375,39 +371,60 @@ def forward( We prepend audio embeddings to the instruction and label text tokens as the LLM input. """ - if 'audio_ratio' in audio_batch: - self.log( - 'audio_ratio', audio_batch['audio_ratio'].mean(), prog_bar=True, batch_size=1, rank_zero_only=False + + audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} + text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} + + multimodal_output = {} + + if 'audio_signal' in audio_batch: + encoder_input, attention_mask, enc_mask = self.prepare_llm_input(audio_batch) + # enc_input = speech and text prompt + # dec_input and label = text output label + b = audio_batch['answers'].shape[0] + labels = audio_batch['answers'] + dec_input = torch.cat([torch.full([b, 1], self.bos_id, device=labels.device), labels[:, :-1]], dim=-1) + dec_mask = (dec_input != self.tokenizer.pad_id).long().contiguous() + output = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=encoder_input, ) - self.log( - 'local_batch_size', - audio_batch['audio_ratio'].shape[0], - prog_bar=True, - batch_size=1, - rank_zero_only=False, + loss_mask = dec_mask + multimodal_output['audio_text'] = (output, loss_mask) + + if text_batch: + b = text_batch['text_answer_ids'].shape[0] + encoder_input_ids = text_batch["text_context_ids"] + enc_mask = (encoder_input_ids != self.tokenizer.pad_id).long().contiguous() + decoder_input_ids = torch.cat( + [ + torch.full([b, 1], self.bos_id, device=encoder_input_ids.device), + text_batch["text_answer_ids"][:, :-1], + ], + dim=-1, ) + labels = text_batch["text_answer_ids"] + dec_mask = (decoder_input_ids != self.tokenizer.pad_id).long().contiguous() + loss_mask = dec_mask + output = self.frozen_model.enc_dec_model( + enc_input_ids=encoder_input_ids, + enc_attn_mask=enc_mask, + dec_input_ids=decoder_input_ids, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=None, + ) + multimodal_output['text'] = (output, loss_mask) - encoder_input, attention_mask, enc_mask = self.prepare_llm_input(audio_batch) - # enc_input = speech and text prompt - # dec_input and label = text output label - b = audio_batch['answers'].shape[0] - device = audio_batch['answers'].device - dec_input = audio_batch['masked_answer_ids'] if 'masked_answer_ids' in audio_batch else audio_batch['answers'] - dec_input = torch.cat([torch.full([b, 1], self.bos_id, device=device), dec_input[:, :-1]], dim=-1) - labels = audio_batch['answers'] - dec_mask = (dec_input != self.tokenizer.pad_id).long().contiguous() - output = self.frozen_model.enc_dec_model( - enc_input_ids=None, - enc_attn_mask=enc_mask, - dec_input_ids=dec_input, - dec_attn_mask=dec_mask, - token_type_ids=None, - labels=labels, - output_enc_hidden_only=False, - enc_input=encoder_input, - ) - loss_mask = dec_mask - return output, loss_mask + return multimodal_output def get_forward_output_only_func(self): def fwd_output_only_func(dataloader_iter, model): @@ -449,21 +466,42 @@ def get_forward_output_and_loss_func(self, validation_step=False): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): batch = next(dataloader_iter) batch = {key: val.cuda(non_blocking=True) for key, val in batch.items()} - output_tensor, loss_mask = self.forward( + multimodal_output = self.forward( batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers ) - def loss_func(output_tensor): + def loss_func(multimodal_output): # Loss for a micro-batch (ub) - if 'audio_ratio' in batch: - text_loss_weight = self.cfg.get('text_loss_weight', 1.0) - audio_ratio = batch['audio_ratio'] - scaled_loss_mask = loss_mask * torch.unsqueeze( - (1 * audio_ratio + text_loss_weight * (1 - audio_ratio)), 1 + loss_for_ub = None + + modality_weights = self.cfg.get("modality_loss_weights") + + for key, (output, loss_mask) in multimodal_output.items(): + cur_loss = self.loss_func(loss_mask.contiguous(), output.contiguous()) + if modality_weights is not None: + assert ( + key in modality_weights + ), f"Expected cfg.modality_loss_weights={modality_weights} to contain key {key}" + cur_loss = cur_loss * modality_weights[key] + if loss_for_ub is None: + loss_for_ub = cur_loss + else: + loss_for_ub += cur_loss + self.log( + f'{key}_loss', + cur_loss.mean(), + prog_bar=True, + batch_size=1, + rank_zero_only=False, ) - loss_for_ub = self.loss_func(scaled_loss_mask, output_tensor) - else: - loss_for_ub = self.loss_func(loss_mask, output_tensor) + self.log( + f'{key}_batch_size', + loss_mask.shape[0], + prog_bar=True, + batch_size=1, + rank_zero_only=False, + ) + if validation_step and not self.cfg.data.get('validation_drop_last', True): num_valid_tokens_in_ub = batch['loss_mask'].sum() if loss_for_ub.isnan(): @@ -487,10 +525,20 @@ def loss_func(output_tensor): reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) return loss_for_ub, {'avg': reduced_loss} - return output_tensor, loss_func + return multimodal_output, loss_func return fwd_output_and_loss_func + def on_train_epoch_start(self) -> None: + app_state = AppState() + reconfigure_num_microbatches_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.train_ds.global_batch_size, + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + def _build_dataset(self, data_cfg, is_train=True): return build_speechllm_dataset(self, data_cfg, is_train) @@ -876,9 +924,14 @@ def _validation_step_internal( def inference_step(self, dataloader_iter, mode, dataloader_idx=0): batch, batch_idx, dataloader_idx = next(dataloader_iter) data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds - self._reconfigure_and_process_inference_batch(batch, data_cfg) - # Meta data from dataset - metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + if "tokens" in batch: + self._reconfigure_and_process_inference_batch(batch, data_cfg) + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + else: + batch["tokens"] = batch["text_context_ids"] + self._reconfigure_and_process_inference_batch(batch, data_cfg) + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + batch.pop("tokens") loss = self._validation_step_internal(itertools.chain([batch]), batch_idx, dataloader_idx, result_mode=mode) # We need _inference_config to get generation params @@ -891,8 +944,8 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): output = self.predict_step(batch, batch_idx, dataloader_idx) - inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] - labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] + inputs_text = output["input_text"] + labels_text = output["labels_text"] preds_text = output['preds_text'] if data_cfg.get("log_every_n_steps", None) is not None: if batch_idx % data_cfg.log_every_n_steps == 0: @@ -923,25 +976,42 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): return outputs def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + # the following supports STT (audio-text) inference batch = move_to_device(batch, device=self.device) - encoder_input, attention_mask, enc_mask = self.prepare_llm_input(batch) - # enc_input = speech and text prompt - # dec_input and label = text output label - predicted_token_ids, log_probs = self.frozen_model.decode( - tokens_enc=None, - enc_mask=enc_mask, - num_tokens_to_generate=self._inference_config['tokens_to_generate'], - encoder_input=encoder_input, - tokenizer=self.tokenizer, - bos_id=self.bos_id, - ) + audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} + text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} + assert ( + audio_batch or text_batch and not (audio_batch and text_batch) + ), f"Expecting only text or audio batch, got {len(text_batch)=} and {len(audio_batch)=}" + + if audio_batch: + input_text = audio_batch['contexts'] + labels = audio_batch['answers'] + encoder_input, attention_mask, enc_mask = self.prepare_llm_input(audio_batch) + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=None, + enc_mask=enc_mask, + num_tokens_to_generate=self._inference_config['tokens_to_generate'], + encoder_input=encoder_input, + tokenizer=self.tokenizer, + bos_id=self.bos_id, + ) + if text_batch: + input_text = text_batch['text_context_ids'] + labels = text_batch["text_answer_ids"] + enc_mask = (input_text != self.tokenizer.pad_id).long().contiguous() + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=input_text, + enc_mask=enc_mask, + num_tokens_to_generate=self._inference_config['tokens_to_generate'], + tokenizer=self.tokenizer, + bos_id=self.bos_id, + ) # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. - input_text = batch['contexts'] preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) input_text = MegatronT5SFTModel.ids_to_text(input_text, self.tokenizer) - labels = batch['answers'] if labels is not None: labels_text = MegatronT5SFTModel.ids_to_text(labels, self.tokenizer) @@ -1175,68 +1245,99 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): batch = next(dataloader_iter) # Pass only torch.Tensor to prevent errors when process get_iterator_k_split() batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)} - _, seq_length = batch['tokens'].shape - # handle the case where the batch size from dynamic bucketting is not divisible in lhotse - data_iter = get_iterator_k_split(batch, get_num_microbatches(), enforce_divisible_batch=False) - - # handle asynchronous grad reduction - no_sync_func = None - grad_sync_func = None - param_sync_func = None - if not forward_only and self.with_distributed_adam: - no_sync_func = partial( - self._optimizer.no_sync, - greedy_grad_copy=self.megatron_amp_O2, - ) - grad_sync_func = self.reduce_overlap_gradients - param_sync_func = self.sync_overlap_parameters - - self.model.config.no_sync_func = no_sync_func - self.model.config.grad_sync_func = grad_sync_func - self.model.config.param_sync_func = param_sync_func - - fwd_bwd_function = get_forward_backward_func() - - dec_seq_length = batch['answers'].shape[1] - - losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(), - data_iterator=data_iter, - model=[self.model], - num_microbatches=get_num_microbatches(), - forward_only=forward_only, - seq_length=seq_length, - micro_batch_size=get_micro_batch_size(), - decoder_seq_length=dec_seq_length, - ) - # only the last stages of the pipeline return losses - if losses_reduced_per_micro_batch: - if (not forward_only) or self.cfg.data.get('validation_drop_last', True): - # average loss across micro batches - loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] - loss_tensor = torch.concat(loss_tensors_list) - loss_mean = loss_tensor.mean() + audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")} + text_batch = {k: v for k, v in batch.items() if k.startswith("text_")} + + # Note: We want to perform full fwd+bwd separately for each modality, + # as it allows us to save GPU memory. Otherwise, we'd have to + # hold the activations from one modality in memory while running + # forward for the other. + batch_losses = [] + for batch in (audio_batch, text_batch): + if not batch: + continue + + # Pass only torch.Tensor to prevent errors when process get_iterator_k_split() + batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)} + + # TODO(pzelasko): For the prototype, computing seq_length as a max from both modalities, + # but I feel like this needs larger refactoring + if 'tokens' in batch and 'text_input_ids' in batch: + seq_length = max(batch['tokens'].shape[1], batch['text_input_ids'].shape[1]) + dec_seq_length = max(batch['answers'].shape[1], batch['text_answer_ids'].shape[1]) + elif 'tokens' in batch: + seq_length = batch['tokens'].shape[1] + dec_seq_length = batch['answers'].shape[1] + elif 'text_input_ids' in batch: + seq_length = batch['text_input_ids'].shape[1] + dec_seq_length = batch['text_answer_ids'].shape[1] else: - # Get the total loss since micro batches sizes are not uniform - loss_sum_tensors_list = [ - loss_sum['loss_sum_and_ub_size'] - for loss_sum in losses_reduced_per_micro_batch - if loss_sum['loss_sum_and_ub_size'][1] > 0 - ] - loss_sum = ( - torch.vstack(loss_sum_tensors_list).sum(axis=0) - if len(loss_sum_tensors_list) > 0 - else torch.tensor([0.0, 0.0]).cuda() + seq_length = None # TODO(pzelasko): not sure if it is even needed ??? + dec_seq_length = None + + # handle the case where the batch size from dynamic bucketting is not divisible in lhotse + data_iter = get_iterator_k_split(batch, get_num_microbatches(), enforce_divisible_batch=False) + + # handle asynchronous grad reduction + no_sync_func = None + grad_sync_func = None + param_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, ) - return loss_sum - else: - # we're not on the last pipeline stage so no losses - if forward_only: - loss_mean = [] + grad_sync_func = self.reduce_overlap_gradients + param_sync_func = self.sync_overlap_parameters + + self.model.config.no_sync_func = no_sync_func + self.model.config.grad_sync_func = grad_sync_func + self.model.config.param_sync_func = param_sync_func + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(validation_step=forward_only), + data_iterator=data_iter, + model=[self.model], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=get_micro_batch_size(), + decoder_seq_length=dec_seq_length, + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # Get the total loss since micro batches sizes are not uniform + loss_sum_tensors_list = [ + loss_sum['loss_sum_and_ub_size'] + for loss_sum in losses_reduced_per_micro_batch + if loss_sum['loss_sum_and_ub_size'][1] > 0 + ] + loss_mean = ( + torch.vstack(loss_sum_tensors_list).sum(axis=0) + if len(loss_sum_tensors_list) > 0 + else torch.tensor([0.0, 0.0]).cuda() + ) else: - loss_mean = torch.tensor(0.0).cuda() + # we're not on the last pipeline stage so no losses + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0).cuda() + if loss_mean.ndim == 0: + loss_mean = loss_mean.unsqueeze(0) + batch_losses.append(loss_mean) + loss_mean = torch.cat(batch_losses).mean() return loss_mean def loss_func(self, loss_mask, output_tensor): @@ -1263,7 +1364,12 @@ def test_step(self, dataloader_iter, dataloader_idx=0): return self.inference_step(dataloader_iter, 'test') def training_step(self, dataloader_iter): - batch, batch_idx, dataloader_idx = next(dataloader_iter) + ans = next(dataloader_iter) + if isinstance(ans, tuple) and len(ans) == 3: + batch, batch_idx, dataloader_idx = ans + else: + batch = ans + batch_idx = 0 return super().training_step(itertools.chain([batch]), batch_idx=batch_idx) def setup_mcore_distributed_parallel(self): @@ -1271,6 +1377,63 @@ def setup_mcore_distributed_parallel(self): if self.with_distributed_adam and self.use_mcore_dist_optim: raise ValueError("T5 does not support both distributed adam and mcore distributed data parallel.") + def oomptimizer_schema(self, schema: str = "audio") -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. + """ + + if schema == "audio": + return { + "cls": dict, + "inputs": [ + {"name": "audio_signal", "type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + {"name": "audio_signal_length", "type": NeuralType(("B",), LengthsType()), "seq_length": "input"}, + { + "name": "contexts", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "context_lengths", + "type": NeuralType(("B",), LengthsType()), + "seq_length": "output", + }, + { + "name": "answers", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "loss_mask", + "type": NeuralType(("B", "T"), MaskType()), + "seq_length": "output", + }, + ], + } + elif schema == "text": + return { + "cls": dict, + "inputs": [ + { + "name": "text_context_ids", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "input", + "vocab_size": self.tokenizer.vocab_size, + }, + { + "name": "text_answer_ids", + "type": NeuralType(("B", "T"), LabelsType()), + "seq_length": "output", + "vocab_size": self.tokenizer.vocab_size, + }, + ], + } + else: + raise RuntimeError(f"Unknown schema type for oomptimizer of class {type(self)}: '{schema}'") + class DecoderTextPromptModularizedAudioT5Model(ModularizedAudioT5Model): """Modularized speech GPT model.""" diff --git a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py index 43f08afea4c9..494667c5bfb1 100644 --- a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py +++ b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py @@ -18,7 +18,8 @@ import torch from lhotse.cut import Cut -from nemo.collections.common.prompts import PromptFormatter, get_prompt_format_fn +from nemo.collections.common.data.prompt_fn import get_prompt_format_fn +from nemo.collections.common.prompts import PromptFormatter from nemo.utils import logging, logging_mode @@ -403,7 +404,8 @@ def __init__( audio_locator: Optional[str] = None, max_seq_length: Optional[int] = 8192, ): - self.prompt_format_fn = get_prompt_format_fn(prompt_format) + self.prompt = PromptFormatter.resolve(prompt_format)(tokenizer) + self.prompt_format_fn = get_prompt_format_fn(Cut, self.prompt) self.tokenizer = tokenizer self.audio_locator = audio_locator self.max_seq_length = max_seq_length @@ -418,8 +420,7 @@ def __init__( ) def _process_example(self, cut: Cut): - ans = self.prompt_format_fn([cut], self.tokenizer) - ans = {k: v[0] for k, v in ans.items()} + ans = self.prompt_format_fn(cut, self.prompt) context_start_idx = [0] if self.audio_locator_id is not None: if len(self.audio_locator_id) == 1: # fast case, special "insert audio" token diff --git a/scripts/speech_llm/estimate_token_bins.py b/scripts/speech_llm/estimate_token_bins.py new file mode 100644 index 000000000000..b198158498c1 --- /dev/null +++ b/scripts/speech_llm/estimate_token_bins.py @@ -0,0 +1,316 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import ast +import math +from functools import partial +from itertools import islice +from typing import Callable, Iterable + +import numpy as np +import pandas as pd +from lhotse.cut import Cut +from omegaconf import OmegaConf + +from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper +from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config +from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize, tokenize_with_prompt +from nemo.collections.common.data.lhotse.sampling import ( + MultimodalFixedBucketBatchSizeConstraint2D, + MultimodalSamplingConstraint, + TokenCountFilter, + TokenPerTokenFilter, +) +from nemo.collections.common.prompts.formatter import PromptFormatter +from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Estimate token bins for Lhotse dynamic bucketing using a sample of the input dataset. " + "The dataset is read either from one or more manifest files and supports data weighting. " + "Unlike estimate_duration_bins.py, this script is intended for text data only. " + "It supports 2D bucketing. ", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "input", + help='Path to a data input configuration YAML file. ' + 'This is the only type of input specification supported for text data.', + ) + parser.add_argument( + "-t", + "--tokenizer", + nargs="+", + required=True, + help="Path to one or more SPE tokenizers. More than one means we'll use AggregateTokenizer and --langs argument must also be used. When provided, we'll estimate a 2D distribution for input and output sequence lengths.", + ) + parser.add_argument( + "-a", "--langs", nargs="+", help="Language names for each of AggregateTokenizer sub-tokenizers." + ) + parser.add_argument( + "-b", + "--buckets", + type=int, + default=30, + help="The desired number of buckets (dim0 => covers input sequence length / audio duration).", + ) + parser.add_argument( + "-s", + "--sub-buckets", + type=int, + default=None, + help="The desired number of sub-buckets (dim1 => covers output sequence length / num_tokens). " + "If not provided, we'll only perform 1D bucketing. ", + ) + parser.add_argument( + "-n", + "--num_examples", + type=int, + default=-1, + help="The number of examples (utterances) to estimate the bins. -1 means use all data " + "(be careful: it could be iterated over infinitely).", + ) + parser.add_argument( + "-l", + "--min_tokens", + type=float, + default=-float("inf"), + help="If specified, we'll filter out examples with less tokens than this number.", + ) + parser.add_argument( + "-u", + "--max_tokens", + type=float, + default=float("inf"), + help="If specified, we'll filter out examples with more tokens than this number.", + ) + parser.add_argument( + "--max_tpt", + type=float, + default=float("inf"), + help="If specified, we'll filter out examples with more output tokens per input token than this. ", + ) + parser.add_argument( + "-q", "--quiet", type=bool, default=False, help="When specified, only print the estimated duration bins." + ) + parser.add_argument( + "-f", + "--prompt-format", + type=str, + help="When specified, we'll use a prompt formatter in addition to the tokenizer for the purpose of estimating token count bins. " + "This is useful for accurate 2D bucket estimation with models such as EncDecMultiTaskModel (Canary-1B), " + "or any model where the label sequence consists of a user prompt and a model's response.", + ) + parser.add_argument( + "-p", + "--prompt", + type=str, + help="Prompt slots provided as a Python list of dicts. It is used together with --prompt-format option." + "For example, with Canary-1B you may use: [{'role':'user','slots':{'source_lang':'en','target_lang':'en','task':'asr','pnc':'yes'}]", + ) + return parser.parse_args() + + +def estimate_token_buckets( + cuts: Iterable[Cut], + num_buckets: int, + num_subbuckets: int | None, + quiet: bool, +) -> list[tuple[float, float]]: + """ + This function is based on lhotse.dataset.sampling.dynamic_bucketing.estimate_duration_buckets. + It extends it to a 2D bucketing case. + """ + assert num_buckets > 1 + is_2d = num_subbuckets is not None + + if is_2d: + constraint = MultimodalFixedBucketBatchSizeConstraint2D([(0.0, 0.0)], [0], measure_total_length=False) + else: + constraint = MultimodalSamplingConstraint(measure_total_length=True) + + # Gather the duration and token count statistics for the dataset. + num_input_tokens = [] + num_output_tokens = [] + for c in cuts: + ans = constraint.measure_length(c) + if is_2d: + itoks, otoks = ans + num_input_tokens.append(itoks) + num_output_tokens.append(otoks) + else: + num_input_tokens.append(ans) + num_input_tokens = np.array(num_input_tokens, dtype=np.int32) + if is_2d: + num_output_tokens = np.array(num_output_tokens, dtype=np.int32) + joint = np.rec.fromarrays([num_input_tokens, num_output_tokens]) + joint.sort() + num_input_tokens = joint.f0 + num_output_tokens = joint.f1 + else: + num_input_tokens.sort() + + # We are building buckets with equal duration (empirically leads to more even bucket exhaustion over time). + # We need to determine how much duration to allocate per bucket. + size_per_bucket = num_input_tokens.sum() / num_buckets + + if not quiet: + print("Duration distribution:") + print(pd.Series(num_input_tokens).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) + max_input_tokens = num_input_tokens[-1] + + if is_2d: + tpt = num_output_tokens / num_input_tokens + if not quiet: + print("Output tokens per input token distribution:") + print(pd.Series(tpt).describe(percentiles=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])) + max_tpt = tpt.max() + del tpt + + bins = [] + bin_indexes = [0] + tot = 0.0 + + def _estimate_output_token_buckets(max_bucket_duration): + # Since this is 2D bucketing, apply the same bin creation logic + # for the second dimension (i.e. token count) as for the first dimension (duration). + # That means we aim to have each bucket contain roughly the same number of tokens. + # Note that this estimation is biased towards more padding if you have + # a lot of zero-token examples (e.g. non-speech). + nonlocal bins + num_tokens_bucket = num_output_tokens[bin_indexes[-1] : binidx] + num_tokens_bucket.sort() + tokens_per_subbucket = num_tokens_bucket.sum() / num_subbuckets + tot_toks = 0 + # Iterate over token counts, and whenever we hit tokens_per_subbucket, create a new 2D bucket bin. + for num_toks in num_tokens_bucket: + # Threshold hit: we are creating a new (max_duration, max_num_tokens) bin. + if tot_toks > tokens_per_subbucket: + bins.append((max_bucket_duration, num_toks)) + tot_toks = 0 + tot_toks += num_toks + bins.append((size, math.ceil(size * max_tpt))) + + # Iterate over data, and whenever we hit size_per_bucket, create a new bucket bin. + for binidx, size in enumerate(num_input_tokens): + if tot > size_per_bucket: + # Threshold hit: we are creating a new duration bin (multiplied by number of token bins). + if is_2d: + _estimate_output_token_buckets(max_bucket_duration=size) + else: + bins.append(size) + tot = 0.0 + tot += size + + # Estimate an extra 2D bin set for global max duration. + if num_subbuckets is not None: + _estimate_output_token_buckets(max_bucket_duration=max_input_tokens) + + return bins + + +def load_tokenizer(paths: list[str], langs: list[str] = None) -> TokenizerWrapper: + if len(paths) == 1: + tok = SentencePieceTokenizer(paths[0]) + else: + assert langs is not None and len(paths) == len( + langs + ), f"Cannot create AggregateTokenizer; each tokenizer must have assigned a language via --langs option (we got --tokenizers={paths} and --langs={langs})" + tok = AggregateTokenizer({lang: SentencePieceTokenizer(p) for lang, p in zip(langs, paths)}) + return TokenizerWrapper(tok) + + +def apply_tokenizer(cut, tokenizer=None, prompt: PromptFormatter = None): + if prompt is not None: + cut = tokenize_with_prompt(cut, tokenizer, prompt) + elif tokenizer is not None: + cut = tokenize(cut, tokenizer) + return cut + + +class RejectionsCounter: + def __init__(self, predicate: Callable, message: str): + self.predicate = predicate + self.message = message + self.total = 0 + self.rejected = 0 + + def __call__(self, example) -> bool: + ans = self.predicate(example) + self.total += 1 + if not ans: + self.rejected += 1 + return ans + + def print_report(self) -> None: + if self.rejected: + print(f"{self.message} | Rejected {self.rejected}/{self.total} examples.") + + +def main(): + args = parse_args() + + if not args.quiet: + pd.set_option('display.float_format', lambda x: '%.2f' % x) + + tokenizer = None + prompt = None + if args.tokenizer is not None: + tokenizer = load_tokenizer(args.tokenizer, args.langs) + if args.prompt_format is not None: + prompt_defaults = None + if args.prompt is not None: + prompt_defaults = ast.literal_eval(args.prompt) + prompt = PromptFormatter.resolve(args.prompt_format)(tokenizer._tokenizer, defaults=prompt_defaults) + + assert args.input.endswith(".yaml") + config = OmegaConf.merge( + OmegaConf.structured(LhotseDataLoadingConfig), + OmegaConf.from_dotlist([f"input_cfg={args.input}"]), + ) + cuts, _ = read_cutset_from_config(config) + cuts = cuts.map(partial(apply_tokenizer, tokenizer=tokenizer, prompt=prompt), apply_fn=None) + if hasattr(cuts, "prefetch"): + cuts = cuts.prefetch() # to be released in lhotse 1.27 + token_filter = RejectionsCounter(TokenCountFilter(args.min_tokens, args.max_tokens), "Token count filtering") + cuts = cuts.filter(token_filter) + tpt_filter = RejectionsCounter(TokenPerTokenFilter(-1, args.max_tpt), "Output tokens per input token filtering") + cuts = cuts.filter(tpt_filter) + if (N := args.num_examples) > 0: + cuts = islice(cuts, N) + + token_bins = estimate_token_buckets( + cuts, + num_buckets=args.buckets, + num_subbuckets=args.sub_buckets, + quiet=args.quiet, + ) + if args.sub_buckets is not None: + token_bins = "[" + ','.join(f"[{b:d},{sb:d}]" for b, sb in token_bins) + "]" + else: + token_bins = "[" + ','.join(f"{b:d}" for b in token_bins) + "]" + if args.quiet: + print(token_bins) + return + token_filter.print_report() + tpt_filter.print_report() + print("Use the following options in your config:") + print(f"\tnum_buckets={args.buckets}") + print(f"\tbucket_duration_bins={token_bins}") + + +if __name__ == "__main__": + main() diff --git a/scripts/speech_llm/oomptimizer.py b/scripts/speech_llm/oomptimizer.py new file mode 100755 index 000000000000..63afbe743364 --- /dev/null +++ b/scripts/speech_llm/oomptimizer.py @@ -0,0 +1,578 @@ +#!/usr/bin/env python +import importlib +import math +import sys +from numbers import Number +from typing import Iterable, Literal + +import click +import pytorch_lightning as pl +import torch +from lhotse import compute_num_samples +from omegaconf import OmegaConf + +from nemo.collections.asr.models.asr_model import ASRModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType +from nemo.utils import logging + + +class ProfilingBatchGenerator: + """ + ProfilingBatchGenerator is used to generate artificial mini-batches for model training + and tracking the progress of batch size optimization. + + The high-level usage API is the following:: + + >>> gen = ProfilingBatchGenerator(schema) + ... finished = False + ... while not finished: + ... batch = gen(input_seq_len, output_seq_len) + ... try: + ... training_step(model, batch) + ... oom = False + ... except torch.cuda.OutOfMemoryError: + ... oom = True + ... finished = gen.advance(oom) + ... solution = gen.max_batch_size # The solution of the search problem. + ... gen.reset() # Can re-use for other sequence lengths now. + + The search terminates once the difference between max working batch size and min OOM batch size + divided by the latter is smaller than ``rel_gap_thresh`` that difference amounts to a single element. + For example, a max working batch size is 96 and min OOM batch size is 100 indicates a gap of 0.04, + which would terminate the search with threshold of 0.05. + + In order to generate mini-batches compatible with a given model, the generator: + + * accepts a ``schema`` argument in its constructor, and + + * accepts input/output sequence lengths in each call to generate a mini-batch. + + ``schema`` has the following structure:: + + + >>> { + ... "cls": tuple | MyBatchType, + ... "inputs": [ + ... { + ... "type": NeuralType(...) | Literal["dummy"], + ... "seq_length": Literal["input", "output"], + ... "vocab_size": int, # optional, required only for LabelsType + ... "name": str, # optional, indicates kwarg + ... }, + ... ..., + ... ] + ... } + + ``cls`` indicates how we should construct the mini-batch. Typically you can just use ``tuple`` for most + batch schemas. However, if the model expects a specific, e.g., dataclass, you can tell ``ProfilingBatchGenerator`` + to use it. The mini-batch object will be constructed using the items in ``inputs``. + + Each element of ``inputs`` specifies a NeMo NeuralType which needs to have a defined ``elements_type``. + The supported types are ``AudioSignal``, ``LengthsType`` and ``LabelsType``. + If "type" is not a NeuralType, we interpret that as a placeholder tensor that's not relevant but expected + by the model/batch constructor. In addition, ``"seq_length"`` key is used to determine whether we should apply + input or output sequence length to a given tensor. + + Optional keys: + + * ``vocab_size`` is required for ``LabelsType`` so that we can generate proper label values. + + * ``name`` is required if objects of ``cls`` have to be constructed using keyword arguments. + + A simple schema example for a model using audio/lengths tensor pair (unsupervised/self-supervised):: + + >>> { + ... "cls": tuple, + ... "inputs": [ + ... {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"}, + ... {"type": NeuralType(("B"), LengthsType()), "seq_length": "input"}, + ... ] + ... } + + """ + + def __init__( + self, + schema: dict, + start_batch_size: int = 32, + rel_gap_thresh: float = 0.05, + device: str = "cuda", + ): + self.schema = schema + self.start_batch_size = start_batch_size + self.rel_gap_thresh = rel_gap_thresh + self.device = device + self.reset() + + def __call__(self, input_seq_length: int, output_seq_length: int): + B = self._current + select_seq_length = {"input": input_seq_length, "output": output_seq_length} + batch = [] + names = [] + for item in self.schema["inputs"]: + nt = item["type"] + if isinstance(nt, str) and nt == "constant": + if isinstance(val := item["value"], str) and val == "batch": + tnsr = torch.tensor([B], dtype=torch.long, device=self.device) + else: + tnsr = torch.tensor([val], dtype=torch.long, device=self.device) + elif not isinstance(nt, NeuralType): # placeholder + tnsr = torch.tensor([]) + elif isinstance(nt.elements_type, AudioSignal): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.randn(B, seq_length, dtype=torch.float32, device=self.device) + elif isinstance(nt.elements_type, LengthsType): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.ones(B, dtype=torch.long, device=self.device) * seq_length + elif isinstance(nt.elements_type, MaskType): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.ones(B, seq_length, device=self.device) + elif isinstance(nt.elements_type, LabelsType): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.randint(0, item["vocab_size"], size=(B, seq_length), device=self.device) + else: + raise RuntimeError("Unexpected item in oomptimizer schema: {item}") + batch.append(tnsr) + names.append(item.get("name")) + args = [elem for name, elem in zip(names, batch) if name is None] + kwargs = {name: elem for name, elem in zip(names, batch) if name is not None} + if not kwargs and self.schema["cls"] == tuple: + return tuple(args) + return self.schema["cls"](*args, **kwargs) + + @property + def max_batch_size(self) -> int | None: + """ + Return the solution of the batch size search problem. + It will keep returning None until the search is done. + """ + if ( + self._max_ok is not None + and self._min_err is not None + and (self.current_rel_gap <= self.rel_gap_thresh or self._min_err - self._max_ok <= 1) + ): + return self._max_ok + return None + + @property + def current_rel_gap(self) -> float | None: + """ + Return the current gap between the largest batch that works and the smallest batch that triggers OOM. + The gap is defined as the batch size difference divided by the larger element. + E.g., if the best found batch size is 95 and the smallest that triggers OOM is 100, the gap is 0.05. + """ + if self._min_err is None or self._max_ok is None: + return None + return (self._min_err - self._max_ok) / self._min_err + + def reset(self): + """Reset the generator to prepare it for a new search.""" + self._current = self.start_batch_size + self._max_ok = None # max batch size that works + self._min_err = None # min batch size that doesn't work + + def advance(self, oom: bool) -> bool: + """ + Adjusts the current batch size based on the outcome. + Returns a bool indicating whether the calibration is complete. + """ + if self.max_batch_size is not None: + return True + + if oom: + # Training step failed with OOM. + # Update the minimum known batch size that causes an error. + self._min_err = min(float("inf") if self._min_err is None else self._min_err, self._current) + # Training step failed on OOM + if self._max_ok is None: + # We haven't found a batch size that works yet, keep going 2x down. + self._current = round(self._current / 2) + else: + # Try the middle-point between the known extremes. + self._current = round((self._max_ok + self._min_err) / 2) + else: + # Training step successful. + # Update the maximum known batch size that works. + self._max_ok = max(-1 if self._max_ok is None else self._max_ok, self._current) + if self._min_err is None: + # We haven't found a batch size that causes an error yet, keep going 2x higher + self._current *= 2 + else: + # Try the middle-point between the known extremes. + self._current = round((self._max_ok + self._min_err) / 2) + + if self._current == 0: + raise RuntimeError( + "We diverged and arrived batch_size=0. Perhaps the input is too large for this model and hardware." + ) + + return False + + +class FloatList(click.Option): + """Support passing bucket duration bins as [1.1,2.5,5.6,...]""" + + name = "list[float]" + + def type_cast_value(self, ctx, value): + if isinstance(value, list) and all(isinstance(v, float) for v in value): + return value + try: + import ast + + ans = ast.literal_eval(value) + if isinstance(ans[0], list): + ans = [tuple(item) for item in ans] + return ans + except ValueError: + raise click.BadParameter(value) + + +@click.command(context_settings={'show_default': True}) +@click.option( + "-n", + "--pretrained-name", + type=str, + default=None, + help="Name of a pretrained model to use, e.g. 'nvidia/canary-1b'.", +) +@click.option( + "-m", + "--module-name", + type=str, + default=None, + help="Full path to NeMo's module corresponding to CONFIG_PATH, e.g. 'nemo.collections.asr.models.EncDecMultiTaskModel'.", +) +@click.option( + "-c", "--config-path", type=str, default=None, help="Path to the training configuration file for MODULE_NAME." +) +@click.option( + "--schema", + type=str, + default="audio", + help="Which schema to use (typically used for choosing the modality, i.e., 'audio' / 'text'", +) +@click.option( + "-b", + "--buckets", + cls=FloatList, + default=[5.0, 10.0, 15.0, 20.0, 25.0, 30.0], + help="List of upper-bound bucket bins (i.e. first bucket is [0.0 - item0), second bucket is [item0 - item1), etc.). " + "We also support a nested list for 2D bucketing, e.g. [[2.0, 10],[2.0,20],[4.5,15],[4.5,30],...], " + "where each item is a pair of (max_input_seq_len, max_output_seq_len) for a given bucket.", +) +@click.option( + "-t", + "--threshold", + type=float, + default=0.05, + help="Search stopping criterion in range [0, 1], lower is more precise. Interpret as the uncerainty gap, i.e. (min_oom_batch_size - max_ok_batch_size) / min_oom_batch_size.", +) +@click.option("-s", "--start-batch-size", type=int, default=32, help="Initial batch size to start the search from.") +@click.option( + "-r", + "--ratio", + type=int, + default=12, # conservative estimate towards longer transcripts + help="The output_sequence_length to input_sequence_length ratio for the purpose of determing the maximum output sequence lengths. " + "The interpretation depends on input and output modalities. Examples: for audio->text it's tokens per second. " + "For text->audio it's seconds per token. For audio->audio it's output seconds per input second. " + "For text->text it's output tokens per input token. " + "In general larger ratio means longer output sequences and increased memory consumption. " + "The default value is set adequately for automatic speech recognition. " + "This argument is ignored when 2D buckets are provided to --buckets option. " + "For GPT-style models, use --ratio=1 ", +) +@click.option( + "-f", + "--memory-fraction", + type=float, + default=0.9, + help="Limits the use of CUDA memory for this process to MEMORY_FRACTION of the total device memory. " + "By default we force 5% memory to be unused to account for non-training-loop related CUDA memory usage" + "in actual training scripts.", +) +@click.option( + "-d", + "--device", + default="cuda:0", + help="Device string to be passed to torch.device; due to MEMORY_FRACTION option, " + "it must specify the device index (e.g. cuda:0). " + "You can also leave the default index and select a specific GPU using env var CUDA_VISIBLE_DEVICES=", +) +@click.option( + "-y", + "--dtype", + default="bfloat16", + help="Float precision to use for computation (used together with autocast).", +) +@click.option( + "--ddp/--no-ddp", + type=bool, + default=True, + help="Whether we should simulate DDP GPU RAM usage. Stores an extra copy of the model in GPU memory. Enabled by default.", +) +def oomptimizer( + pretrained_name: str | None, + module_name: str | None, + config_path: str | None, + schema: str, + buckets: list[float], + threshold: float, + start_batch_size: int, + ratio: int, + memory_fraction: float, + device: str, + dtype: str, + ddp: bool, +): + """ + OOMptimizer finds the optimal batch sizes for training your model with bucketing dataloading. + It performs a search over batch sizes until it converges by measuring the GPU memory usage for + a model's training step and optimizer update. + + \b + There are two main usage patterns: for using a pretrained model or an untrained model configuration. + The latter is more flexible but requires the user to provide two separate arguments. Examples: + * python oomptimizer.py --pretrained-name nvidia/canary-1b + * python oomptimizer.py --module-name nemo.collections.asr.models.EncDecMultiTaskModel \ + --config-path examples/asr/conf/speech_multitask/fast-conformer_aed.yaml + + Dynamic bucketing is notoriously difficult to tune as you risk running into CUDA OOM many steps into the training. + In order to simplify finding the optimal settings, OOMptimizer scans each bucket to find the maximum possible + batch size that doesn't trigger a CUDA OOM. + + \b + The suggested workflow is the following: + 1) Run scripts/speech_recognition/estimate_duration_bins.py to get the duration distribution of your data. + (consider running estimate_duration_bins_2d.py for models with a strong dependency on output sequence length + such as attention-encoder-decoder models). + 2) Run OOMptimizer to find the optimal batch sizes for your specific model, optimizer, and GPU. + 3) Use these optimal settings in your actual training script and enjoy optimal GPU utilization OOM-free. + + In the unlikely event that OOMptimizer bucket batch sizes are still leading to OOMs, + please try a lower setting of the MEMORY_FRACTION option, e.g. 0.75 (75% of GPU memory). + This may be required in very complex setups where there are additional GPU RAM loads that can't be anticipated + through the combination of training_step and optimizer update. + """ + if all(opt is None for opt in (pretrained_name, module_name, config_path)): + click.secho( + "You need to provide either PRETRAINED_NAME or the pair of MODULE_NAME and CONFIG_PATH.", fg="yellow" + ) + sys.exit(1) + logging.setLevel(logging.CRITICAL) + torch.cuda.set_per_process_memory_fraction(memory_fraction, device) + + model_clones = [] + for _ in range(2 if ddp else 1): + if pretrained_name is not None: + assert ( + config_path is None and module_name is None + ), "--pretrained-name cannot be used together with --module-name/--config-path" + click.echo(f"Intializing ASR model from pretrained checkpoint {pretrained_name}.") + trainer = pl.Trainer(barebones=True) + trainer.log_every_n_steps = 1000000 + model = ASRModel.from_pretrained(pretrained_name, trainer=trainer).to(device) + else: + assert config_path is not None, "--module-name requires --config-path to be specified as well." + assert module_name is not None, "--config-path requires --module-name to be specified as well." + cfg = OmegaConf.load(config_path) + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + trainer.log_every_n_steps = 1000000 + namespace, name = module_name.rsplit('.', maxsplit=1) + model_cls = getattr(importlib.import_module(namespace), name) + model = model_cls.restore_from_pretrained_models(cfg, trainer=trainer).to(device) + model.log = lambda *args, **kwargs: None + model_clones.append(model) + model = model_clones[-1] + model.init_consumed_samples = 0 + model._compute_consumed_samples_after_training_step = lambda *args, **kwargs: 1 + + from megatron.core.parallel_state import initialize_model_parallel + from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo + + initialize_model_parallel_for_nemo( + world_size=1, global_rank=0, local_rank=0, micro_batch_size=16, global_batch_size=16 + ) + torch.distributed.init_process_group("nccl", world_size=1, rank=0) + initialize_model_parallel() + + if not hasattr(model, "oomptimizer_schema"): + click.secho( + f"We read model of type {type(model)} which doesn't seem to support OOMptimizer " + f"(we could not find the property .oomptimizer_schema).", + fg="red", + ) + sys.exit(1) + + schema = model.oomptimizer_schema(schema) + + click.echo("Setting up the optimizers.") + optimizer = model.configure_optimizers() + if isinstance(optimizer, tuple): + optimizer = optimizer[0][0] + + # warmup - preallocate model/optimizer memory for all modality modules + for sch_ in ("text", "audio"): + gen_ = ProfilingBatchGenerator(model.oomptimizer_schema(sch_), start_batch_size=1) + with torch.autocast("cuda", getattr(torch, dtype)): + if sch_ == "audio": + batch_ = gen_(17519, 13) + else: + batch_ = gen_(9, 7) + optimizer.zero_grad() + out = model.training_step(iter([batch_])) + optimizer.step() + + is_2d_bucketing = all( + isinstance(item, (list, tuple)) and len(item) == 2 and all(isinstance(v, Number) for v in item) + for item in buckets + ) + # Determine modality for input and output. + modalities = [ + ( + "text" + if any( + isinstance(item["type"].elements_type, LabelsType) and item["seq_length"] == direction + for item in schema["inputs"] + if not isinstance(item["type"], str) + ) + else "audio" + ) + for direction in ("input", "output") + ] + + def get_max_seq_lens(buckets): + + def _determine_lens_for_bucket(bin): + if is_2d_bucketing: + input_len, output_len = bin + else: + input_len = bin + output_len = math.ceil(ratio * input_len) + sampling_rate = getattr( + model, "sample_rate", 16000 + ) # TODO: may need to extend schema for broader model coverage + match modalities: + case "audio", "audio": + return ( + compute_num_samples(input_len, sampling_rate=sampling_rate), + compute_num_samples(output_len, sampling_rate=sampling_rate), + ) + case "audio", "text": + return (compute_num_samples(input_len, sampling_rate=sampling_rate), output_len) + case "text", "audio": + return ( + input_len, + compute_num_samples(output_len, sampling_rate=sampling_rate), + ) + case "text", "text": + return input_len, output_len + case _: + raise RuntimeError(f"Unexpected modality combination: {_}") + + return [_determine_lens_for_bucket(bin) for bin in buckets] + + click.echo("Starting profiling.") + max_seq_lens = get_max_seq_lens(buckets) + gen = ProfilingBatchGenerator(schema=schema, start_batch_size=start_batch_size, rel_gap_thresh=threshold) + profile = {} + + # Iterate buckets from the largest to the smallest sequences. This usually ends up creating + # a tiny bit smaller batches, likely due to worse memory fragmentation. + with torch.autocast("cuda", getattr(torch, dtype)): + for bucket, (seq_len_in, seq_len_out) in reversed(list(zip(buckets, max_seq_lens))): + click.echo(f"The current sequence lengths are: input={seq_len_in} output={seq_len_out}.") + gen.reset() + batch_idx = 0 + + def step(): + click.echo( + f"\t[BEGIN step] [CUDA RAM CURRENT: {torch.cuda.memory_allocated() / (1024 * 1024):.1f}MB] [CUDA RAM MAX: {torch.cuda.max_memory_allocated() / (1024*1024):.1f}MB]" + ) + batch = gen(seq_len_in, seq_len_out) + oom = False + try: + click.echo( + f"\tCurrent settings | batch_size={gen._current} | gap: {gen.current_rel_gap}... ", nl=False + ) + optimizer.zero_grad() + # In SpeechLLM training_step performs both forward and backward; no need for manual backward + out = model.training_step(iter([batch])) + optimizer.step() + except torch.cuda.OutOfMemoryError as e: + click.secho(f"OOM!", fg="yellow") + oom = True + except RuntimeError as e: + if "cuFFT error: CUFFT_INTERNAL_ERROR" not in str(e): + raise + click.secho(f"OOM!", fg="yellow") + oom = True + else: + click.secho(f"OK!", fg="green") + finally: + click.echo( + f"\t[END step] [CUDA RAM CURRENT: {torch.cuda.memory_allocated() / (1024 * 1024):.1f}MB] [CUDA RAM MAX: {torch.cuda.max_memory_allocated() / (1024*1024):.1f}MB]" + ) + del batch + # Note: We could call empty_cache() to free up some more memory on the GPU, + # but we have found out empirically that this causes a mismatched condition + # between OOMptimizer and the actual training. During training, there is some + # degree of memory fragmentation and it's better to simulate that in OOMptimizer. + # torch.cuda.memory.empty_cache() + torch.cuda.reset_max_memory_allocated() + return oom + + oom = step() + while not (finished := gen.advance(oom)): + click.echo("\t" + "=" * 80) + oom = step() + + click.secho( + f"=> Optimal setting for bucket={bucket} (input={seq_len_in} output={seq_len_out}) is max_batch_size={gen.max_batch_size}", + fg="green", + ) + profile[(bucket, seq_len_in, seq_len_out)] = gen.max_batch_size + gen.start_batch_size = gen.max_batch_size * 2 + + # Reverse the profile to be ascendingly sorted again. + profile = dict(reversed(list(profile.items()))) + + click.echo("The 1st stage profile is:") + for (bucket, seq_len_in, seq_len_out), bs in profile.items(): + click.echo(f"Bucket={bucket} (input={seq_len_in} output={seq_len_out}) => max_batch_size={bs}") + + if is_2d_bucketing: + # 2D bucketing doesn't support bucket merging. + final_profile = [["[" + ",".join(map(str, b)) + "]", bs] for (b, _, __), bs in profile.items()] + max_input_len, max_output_len = buckets[-1] + ratio = max_output_len / max_input_len + else: + click.echo("Bucket merging stage...") + final_profile = [] + for idx, ((bucket, seq_len_in, seq_len_out), bs) in enumerate(profile.items()): + if idx == 0: + final_profile.append([bucket, bs]) + continue + if bs == final_profile[-1][1]: + click.echo(f"Merging bucket {idx} with bucket {idx-1} due to identical batch sizes.") + final_profile[-1][0] = bucket + continue + final_profile.append([bucket, bs]) + max_input_len = final_profile[-1][0] + + click.secho(f"The profile was created with the following settings:") + click.secho(f"* using {memory_fraction:.1%} of available GPU RAM.") + click.secho(f"* {'' if ddp else 'not '}simulating DDP memory overhead.") + click.secho(f"* using AMP with dtype={dtype}.") + click.secho("The final profile is:", bold=True) + click.secho("\tbucket_duration_bins=[" + ",".join(str(seqlen) for seqlen, bs in final_profile) + "]", bold=True) + click.secho("\tbucket_batch_size=[" + ",".join(str(bs) for seqlen, bs in final_profile) + "]", bold=True) + click.secho("\t(The following flags are suitable for ASR/speech-to-text models):") + click.secho(f"\tmax_tps={ratio}", bold=True) + click.secho(f"\tmax_duration={max_input_len}", bold=True) + + +if __name__ == "__main__": + oomptimizer() diff --git a/scripts/speech_recognition/estimate_duration_bins_2d.py b/scripts/speech_recognition/estimate_duration_bins_2d.py index 52d5b3620a2a..0f4a021e09cc 100644 --- a/scripts/speech_recognition/estimate_duration_bins_2d.py +++ b/scripts/speech_recognition/estimate_duration_bins_2d.py @@ -27,12 +27,11 @@ from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config -from nemo.collections.common.data.lhotse.dataloader import ( +from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig, tokenize +from nemo.collections.common.data.lhotse.sampling import ( DurationFilter, FixedBucketBatchSizeConstraint2D, - LhotseDataLoadingConfig, TokenPerSecondFilter, - tokenize, ) from nemo.collections.common.prompts.formatter import PromptFormatter from nemo.collections.common.tokenizers import AggregateTokenizer, SentencePieceTokenizer diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index 165ac5ac692d..22b08a9cebee 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -12,7 +12,7 @@ from omegaconf import OmegaConf from nemo.collections.asr.models.asr_model import ASRModel -from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import logging @@ -111,7 +111,12 @@ def __call__(self, input_seq_length: int, output_seq_length: int): names = [] for item in self.schema["inputs"]: nt = item["type"] - if not isinstance(nt, NeuralType): # placeholder + if isinstance(nt, str) and nt == "constant": + if isinstance(val := item["value"], str) and val == "batch": + tnsr = torch.tensor([B], dtype=torch.long, device=self.device) + else: + tnsr = torch.tensor([val], dtype=torch.long, device=self.device) + elif not isinstance(nt, NeuralType): # placeholder tnsr = torch.tensor([]) elif isinstance(nt.elements_type, AudioSignal): seq_length = select_seq_length[item["seq_length"]] @@ -122,6 +127,9 @@ def __call__(self, input_seq_length: int, output_seq_length: int): elif isinstance(nt.elements_type, LabelsType): seq_length = select_seq_length[item["seq_length"]] tnsr = torch.randint(0, item["vocab_size"], size=(B, seq_length), device=self.device) + elif isinstance(nt.elements_type, MaskType): + seq_length = select_seq_length[item["seq_length"]] + tnsr = torch.ones(B, seq_length, device=self.device) else: raise RuntimeError("Unexpected item in oomptimizer schema: {item}") batch.append(tnsr) diff --git a/tests/collections/asr/test_asr_multitask_model_bpe.py b/tests/collections/asr/test_asr_multitask_model_bpe.py index 3b3268423812..df91ad4f5e2f 100644 --- a/tests/collections/asr/test_asr_multitask_model_bpe.py +++ b/tests/collections/asr/test_asr_multitask_model_bpe.py @@ -404,7 +404,9 @@ def test_predict_step(self, asr_model, test_data_dir): c.target_lang = "en" c.task = "asr" c.pnc = "no" - dataset = PromptedAudioToTextLhotseDataset(tokenizer=asr_model.tokenizer, prompt_format_fn=canary) + dataset = PromptedAudioToTextLhotseDataset( + tokenizer=asr_model.tokenizer, prompt=CanaryPromptFormatter(asr_model.tokenizer) + ) batch = dataset[cuts] # Numpy array test @@ -434,7 +436,9 @@ def test_FrameBatchMultiTaskAED(self, asr_model, test_data_dir): @pytest.mark.unit def test_prompted_dataset(asr_model): - dataset = PromptedAudioToTextLhotseDataset(tokenizer=asr_model.tokenizer, prompt_format_fn=canary) + dataset = PromptedAudioToTextLhotseDataset( + tokenizer=asr_model.tokenizer, prompt=CanaryPromptFormatter(asr_model.tokenizer) + ) cuts = DummyManifest(CutSet, begin_id=0, end_id=3, with_data=True) diff --git a/tests/collections/common/test_2d_bucketing_constraint.py b/tests/collections/common/test_2d_bucketing_constraint.py index ba67d2e1fabb..fa771eb75f85 100644 --- a/tests/collections/common/test_2d_bucketing_constraint.py +++ b/tests/collections/common/test_2d_bucketing_constraint.py @@ -3,7 +3,7 @@ from lhotse import CutSet, Seconds, SupervisionSegment from lhotse.dataset import DynamicBucketingSampler from lhotse.testing.dummies import DummyManifest, dummy_cut -from nemo.collections.common.data.lhotse.dataloader import FixedBucketBatchSizeConstraint2D +from nemo.collections.common.data.lhotse.sampling import FixedBucketBatchSizeConstraint2D @pytest.fixture diff --git a/tests/collections/common/test_lhotse_dataloading.py b/tests/collections/common/test_lhotse_dataloading.py index adcb80ec3e55..b5eb1017f1e2 100644 --- a/tests/collections/common/test_lhotse_dataloading.py +++ b/tests/collections/common/test_lhotse_dataloading.py @@ -21,14 +21,14 @@ import numpy as np import pytest import torch -from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording, SupervisionSegment, compute_num_samples +from lhotse import CutSet, MonoCut, NumpyFilesWriter, Recording, compute_num_samples from lhotse.audio import AudioLoadingError from lhotse.cut import Cut, MixedCut from lhotse.dataset import RoundRobinSampler, ZipSampler from lhotse.testing.dummies import dummy_recording +from lhotse.testing.random import deterministic_rng from omegaconf import OmegaConf -from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample, TextExample from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model @@ -1352,18 +1352,18 @@ def test_text_file_pairs_shards_input(txt_pair_paths_shards: tuple[str, str], qu @pytest.fixture(scope="session") -def en_es_tokenizer(tmp_path_factory, txt_en_path, txt_es_path) -> TokenizerWrapper: +def en_es_tokenizer(tmp_path_factory, txt_en_path, txt_es_path) -> SentencePieceTokenizer: tmpdir = tmp_path_factory.mktemp("en_es_tokenizer") text_path = tmpdir / "text.txt" text_path.write_text(txt_en_path.read_text() + "\n" + txt_es_path.read_text()) create_spt_model(text_path, vocab_size=128, sample_size=-1, do_lower_case=False, output_dir=str(tmpdir)) - return TokenizerWrapper(SentencePieceTokenizer(str(tmpdir / "tokenizer.model"))) + return SentencePieceTokenizer(str(tmpdir / "tokenizer.model")) def test_multimodal_text_audio_dataloading( txt_pair_paths_shards: tuple[str, str], nemo_tarred_manifest_path_multi: tuple[str, str], - en_es_tokenizer: TokenizerWrapper, + en_es_tokenizer: SentencePieceTokenizer, questions_path: str, ): en_paths, es_paths = txt_pair_paths_shards @@ -1396,6 +1396,7 @@ def test_multimodal_text_audio_dataloading( "shuffle": True, "num_workers": 0, "use_multimodal_sampling": True, + "prompt_format": "plain", "batch_tokens": BT, # How to set token equivalent duration in actual training? # assuming fbank frames: 0.01 is the base due to frame shift; @@ -1437,16 +1438,16 @@ def test_multimodal_text_audio_dataloading( assert isinstance(ex.source.text, str) assert isinstance(ex.target.text, str) assert isinstance(ex.question.text, str) - assert isinstance(ex.input_ids, np.ndarray) - assert isinstance(ex.context_ids, np.ndarray) - assert isinstance(ex.answer_ids, np.ndarray) - assert isinstance(ex.mask, np.ndarray) + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) def test_multimodal_text_audio_dataloading_zip_strategy( txt_pair_paths_shards: tuple[str, str], nemo_tarred_manifest_path_multi: tuple[str, str], - en_es_tokenizer: TokenizerWrapper, + en_es_tokenizer: SentencePieceTokenizer, questions_path: str, ): en_paths, es_paths = txt_pair_paths_shards @@ -1455,10 +1456,12 @@ def test_multimodal_text_audio_dataloading_zip_strategy( config = OmegaConf.create( { "multi_config": True, + "sampler_fusion": "zip", # <---- !!! this option is being tested here !!! + "seed": 0, + "shard_seed": 0, + "shuffle": True, + "num_workers": 0, "audio": { - "sampler_fusion": "zip", # <---- !!! this option is being tested here !!! - "seed": 0, - "shard_seed": 0, "input_cfg": [ { "type": "nemo_tarred", @@ -1469,8 +1472,7 @@ def test_multimodal_text_audio_dataloading_zip_strategy( }, }, ], - "shuffle": True, - "num_workers": 0, + "prompt_format": "plain", "use_multimodal_sampling": True, "batch_tokens": BT, # How to set token equivalent duration in actual training? @@ -1497,9 +1499,8 @@ def test_multimodal_text_audio_dataloading_zip_strategy( }, }, ], - "shuffle": True, - "num_workers": 0, "use_multimodal_sampling": True, + "prompt_format": "plain", "batch_tokens": 64, # How to set token equivalent duration in actual training? # assuming fbank frames: 0.01 is the base due to frame shift; @@ -1543,10 +1544,10 @@ def test_multimodal_text_audio_dataloading_zip_strategy( assert ex.modality == "text" assert ex.source.language == "en" assert ex.target.language == "es" - assert isinstance(ex.input_ids, np.ndarray) - assert isinstance(ex.context_ids, np.ndarray) - assert isinstance(ex.answer_ids, np.ndarray) - assert isinstance(ex.mask, np.ndarray) + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) b = batches[1] assert isinstance(b, lhotse.CutSet) @@ -1565,16 +1566,16 @@ def test_multimodal_text_audio_dataloading_zip_strategy( assert ex.modality == "text" assert ex.source.language == "en" assert ex.target.language == "es" - assert isinstance(ex.input_ids, np.ndarray) - assert isinstance(ex.context_ids, np.ndarray) - assert isinstance(ex.answer_ids, np.ndarray) - assert isinstance(ex.mask, np.ndarray) + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) def test_multimodal_text_audio_dataloading_round_robin_strategy( txt_pair_paths_shards: tuple[str, str], nemo_tarred_manifest_path_multi: tuple[str, str], - en_es_tokenizer: TokenizerWrapper, + en_es_tokenizer: SentencePieceTokenizer, questions_path: str, ): en_paths, es_paths = txt_pair_paths_shards @@ -1583,10 +1584,12 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( config = OmegaConf.create( { "multi_config": True, + "sampler_fusion": "round_robin", # <---- !!! this option is being tested here !!! + "seed": 0, + "shard_seed": 0, + "shuffle": True, + "num_workers": 0, "audio": { - "sampler_fusion": "round_robin", # <---- !!! this option is being tested here !!! - "seed": 0, - "shard_seed": 0, "input_cfg": [ { "type": "nemo_tarred", @@ -1597,9 +1600,8 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( }, }, ], - "shuffle": True, - "num_workers": 0, "use_multimodal_sampling": True, + "prompt_format": "plain", "batch_tokens": BT, # How to set token equivalent duration in actual training? # assuming fbank frames: 0.01 is the base due to frame shift; @@ -1625,8 +1627,7 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( }, }, ], - "shuffle": True, - "num_workers": 0, + "prompt_format": "plain", "use_multimodal_sampling": True, "batch_tokens": BT, # How to set token equivalent duration in actual training? @@ -1677,10 +1678,156 @@ def test_multimodal_text_audio_dataloading_round_robin_strategy( assert ex.modality == "text" assert ex.source.language == "en" assert ex.target.language == "es" - assert isinstance(ex.input_ids, np.ndarray) - assert isinstance(ex.context_ids, np.ndarray) - assert isinstance(ex.answer_ids, np.ndarray) - assert isinstance(ex.mask, np.ndarray) + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) + + +def test_multimodal_text_audio_dataloading_randomized_round_robin_strategy( + deterministic_rng, + txt_pair_paths_shards: tuple[str, str], + nemo_tarred_manifest_path_multi: tuple[str, str], + en_es_tokenizer: SentencePieceTokenizer, + questions_path: str, +): + en_paths, es_paths = txt_pair_paths_shards + manifest_filepath, tarred_audio_filepaths = nemo_tarred_manifest_path_multi + QF, BT = 50, 64 + config = OmegaConf.create( + { + "multi_config": True, + "sampler_fusion": "randomized_round_robin", # <---- !!! this option is being tested here !!! + "sampler_weights": { + "audio": 0.5, + "text": 0.5, + }, + "seed": 0, + "shard_seed": 0, + "shuffle": True, + "num_workers": 0, + "audio": { + "input_cfg": [ + { + "type": "nemo_tarred", + "manifest_filepath": manifest_filepath, + "tarred_audio_filepaths": tarred_audio_filepaths, + "tags": { + "modality": "audio", + }, + }, + ], + "use_multimodal_sampling": True, + "prompt_format": "plain", + "batch_tokens": BT, + # How to set token equivalent duration in actual training? + # assuming fbank frames: 0.01 is the base due to frame shift; + # + subsampling x8 gives us 0.08 + # assuming discrete audio tokens, with frame rate 50Hz, + # we'd get 0.02 + # in this test we'll just use 0.1 for simplicity + "token_equivalent_duration": 0.1, + "quadratic_factor": QF, + }, + "text": { + "input_cfg": [ + { + "type": "txt_pair", + "source_paths": en_paths, + "target_paths": es_paths, + "source_language": "en", + "target_language": "es", + "questions_path": questions_path, + "questions_language": "en", + "tags": { + "modality": "text", + }, + }, + ], + "prompt_format": "plain", + "use_multimodal_sampling": True, + "batch_tokens": BT, + # How to set token equivalent duration in actual training? + # assuming fbank frames: 0.01 is the base due to frame shift; + # + subsampling x8 gives us 0.08 + # assuming discrete audio tokens, with frame rate 50Hz, + # we'd get 0.02 + # in this test we'll just use 0.1 for simplicity + "token_equivalent_duration": 0.1, + "quadratic_factor": QF, + }, + } + ) + + dl = get_lhotse_dataloader_from_config( + config=config, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=en_es_tokenizer, + ) + + assert isinstance(dl.dataset.sampler, RoundRobinSampler) + + # Note: we use islice here because the dataloader will be infinite. + batches = [batch for batch in islice(dl, 2)] + + # Batch 0 is audio-only + b = batches[0] + assert isinstance(b, lhotse.CutSet) + assert len(b) + assert all(isinstance(ex, Cut) for ex in b) + # Batch tokens is not exceeded after applying the quadratic factor correction + assert sum(ex.num_tokens**2 / QF for ex in b) <= BT + for ex in b: + assert ex.modality == "audio" + assert isinstance(ex.load_audio(), np.ndarray) + assert isinstance(ex.supervisions[0].text, str) + + # Batch 1 is text-only + b = batches[1] + assert isinstance(b, lhotse.CutSet) + assert len(b) + assert all(isinstance(ex, SourceTargetTextExample) for ex in b) + # Batch tokens is not exceeded after applying the quadratic factor correction + assert sum(ex.num_tokens**2 / QF for ex in b) <= BT + for ex in b: + assert ex.modality == "text" + assert ex.source.language == "en" + assert ex.target.language == "es" + assert torch.is_tensor(ex.input_ids) + assert torch.is_tensor(ex.context_ids) + assert torch.is_tensor(ex.answer_ids) + assert torch.is_tensor(ex.mask) + + +def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): + config = OmegaConf.create( + { + "cuts_path": str(cutset_path), + "noise_path": str(nemo_manifest_path), + "noise_mix_prob": 1.0, + "noise_snr": [-5.0, 5.0], + "batch_size": 2, + "seed": 0, + "shard_seed": 0, + } + ) + dl = get_lhotse_dataloader_from_config( + config=config, + global_rank=0, + world_size=1, + dataset=Identity(), + ) + batch = next(iter(dl)) + assert isinstance(batch, CutSet) + assert len(batch) == 2 + cut = batch[0] + assert isinstance(cut, MixedCut) + assert -5.0 < cut.tracks[1].snr < 5.0 + cut = batch[1] + assert isinstance(cut, MixedCut) + assert -5.0 < cut.tracks[1].snr < 5.0 def test_dataloader_with_noise_nemo_json(cutset_path: Path, nemo_manifest_path: Path): diff --git a/tests/collections/common/test_lhotse_multimodal_dataloading.py b/tests/collections/common/test_lhotse_multimodal_dataloading.py index f53a45a72971..9870ad14b25e 100644 --- a/tests/collections/common/test_lhotse_multimodal_dataloading.py +++ b/tests/collections/common/test_lhotse_multimodal_dataloading.py @@ -1,14 +1,15 @@ -import json -from itertools import islice - import lhotse import numpy as np import pytest import torch -from lhotse.testing.dummies import dummy_cut, dummy_recording +from lhotse.testing.dummies import dummy_recording from omegaconf import OmegaConf from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.data.lhotse.sampling import ( + MultimodalFixedBucketBatchSizeConstraint2D, + MultimodalSamplingConstraint, +) from nemo.collections.common.data.lhotse.text_adapters import ( AudioTurn, NeMoMultimodalConversation, @@ -16,7 +17,7 @@ NeMoMultimodalConversationTarWriter, TextTurn, ) -from nemo.collections.common.tokenizers.aggregate_tokenizer import TokenizerWrapper +from nemo.collections.common.prompts import Llama2PromptFormatter from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model @@ -134,8 +135,6 @@ def test_multimodal_conversation_input(multimodal_conversations_path): assert isinstance(t, TextTurn) assert t.role == "assistant" assert t.value == "Of course!" - for key in ("input_ids", "context_ids", "answer_ids", "mask"): - assert getattr(ex, key) is None # not tokenized/prompted @pytest.fixture @@ -221,6 +220,171 @@ def test_multimodal_conversation_input_with_prompt(multimodal_conversations_path assert (ex.mask[95:] == True).all() # assistant turn +def test_text_only_conversation_length_measurement(tokenizer): + convo = NeMoMultimodalConversation( + id="textonly-1", + turns=[ + TextTurn("hello", "user"), + TextTurn("hi", "assistant"), + ], + ) + convo = convo.apply_prompt_format(Llama2PromptFormatter(tokenizer)) + assert tokenizer.ids_to_text(convo.input_ids) == "[INST] hello [/INST] hi" + assert tokenizer.ids_to_text(convo.context_ids) == "[INST] hello [/INST]" + assert tokenizer.ids_to_text(convo.answer_ids) == "hi" + + assert convo.input_length == len(convo.context_ids) == 10 + assert convo.output_length == len(convo.answer_ids) == 4 + assert convo.total_length == len(convo.input_ids) == 14 + + constr = MultimodalSamplingConstraint(measure_total_length=False) + assert constr.measure_length(convo) == 10 + + constr = MultimodalSamplingConstraint(measure_total_length=True) + assert constr.measure_length(convo) == 14 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[5, 10, 15], batch_sizes=[3, 2, 1], measure_total_length=True + ) + assert constr.measure_length(convo) == 14 + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 2 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[(5, 2), (5, 5), (15, 3), (15, 6), (15, 10)], + batch_sizes=[5, 4, 3, 2, 1], + measure_total_length=False, + ) + assert constr.measure_length(convo) == (10, 4) + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 3 + + +def test_audio_only_conversation_length_measurement(tokenizer, tmp_path_factory): + audio_dir = tmp_path_factory.mktemp("audio") + c1 = dummy_recording(0, duration=7.16, with_data=True).to_cut().save_audio(audio_dir / "1.wav") + c2 = dummy_recording(1, duration=15.96, with_data=True).to_cut().save_audio(audio_dir / "2.wav") + convo = NeMoMultimodalConversation( + id="audioonly-1", + turns=[ + AudioTurn(c1, "user", "[audio]"), + AudioTurn(c2, "assistant", "[audio]"), + ], + token_equivalent_duration=0.1, # 10ms frame_shift * 10x subsampling for easy testing + ) + convo = convo.apply_prompt_format(Llama2PromptFormatter(tokenizer)) + assert tokenizer.ids_to_text(convo.input_ids) == "[INST] [audio] [/INST] [audio]" + assert tokenizer.ids_to_text(convo.context_ids) == "[INST] [audio] [/INST]" + assert tokenizer.ids_to_text(convo.answer_ids) == "[audio]" + + # NOTE: Unlike text-only, len(context_ids) != convo.input_length! The same is true for answer and input ids. + # 7.16s with 100ms frame is 72 tokens, we have 7 context tokens, but replace 1 audio locator tag. + assert len(convo.context_ids) == 7 + assert convo.input_length == 78 + + # 15.96s with 100ms frame is 160 tokens, we have 3 answer tokens, but replace 1 audio locator tag. + assert len(convo.answer_ids) == 3 + assert convo.output_length == 162 + + assert len(convo.input_ids) == 10 + assert convo.total_length == 162 + 78 + + constr = MultimodalSamplingConstraint(measure_total_length=False) + assert constr.measure_length(convo) == 78 + + constr = MultimodalSamplingConstraint(measure_total_length=True) + assert constr.measure_length(convo) == 162 + 78 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[100, 200, 300, 400], batch_sizes=[3, 2, 1, 1], measure_total_length=True + ) + assert constr.measure_length(convo) == 162 + 78 + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 2 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[ + (50, 50), + (50, 100), + (50, 200), + (100, 50), + (100, 150), + (100, 200), + (100, 300), + (400, 400), + ], + batch_sizes=[8, 7, 6, 5, 4, 3, 2, 1], + measure_total_length=False, + ) + assert constr.measure_length(convo) == (78, 162) + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 5 + + +def test_multimodal_conversation_length_measurement(tokenizer, tmp_path_factory): + audio_dir = tmp_path_factory.mktemp("audio") + c1 = dummy_recording(0, duration=7.16, with_data=True).to_cut().save_audio(audio_dir / "1.wav") + c2 = dummy_recording(1, duration=15.96, with_data=True).to_cut().save_audio(audio_dir / "2.wav") + convo = NeMoMultimodalConversation( + id="multimodal-1", + turns=[ + TextTurn("listen to this and tell me your opinion", "user"), + AudioTurn(c1, "user", "[audio]"), + TextTurn("its fine", "assistant"), + TextTurn("remove the noise", "user"), + TextTurn("sure", "assistant"), + AudioTurn(c2, "assistant", "[audio]"), + ], + token_equivalent_duration=0.1, # 10ms frame_shift * 10x subsampling for easy testing + ) + convo = convo.apply_prompt_format(Llama2PromptFormatter(tokenizer)) + print(convo) + assert ( + tokenizer.ids_to_text(convo.input_ids) + == "[INST] listen to this and tell me your opinion [audio] [/INST] its fine [INST] remove the noise [/INST] sure [audio]" + ) + assert ( + tokenizer.ids_to_text(convo.context_ids) + == "[INST] listen to this and tell me your opinion [audio] [/INST] its fine [INST] remove the noise [/INST]" + ) + assert tokenizer.ids_to_text(convo.answer_ids) == "sure [audio]" + + assert len(convo.context_ids) == 66 + assert convo.input_length == 66 + 72 - 1 == 137 + + # 15.96s with 100ms frame is 160 tokens, we have 3 answer tokens, but replace 1 audio locator tag. + assert len(convo.answer_ids) == 7 + assert convo.output_length == 7 + 160 - 1 == 166 + + assert len(convo.input_ids) == 73 + assert convo.total_length == 137 + 166 == 303 + + constr = MultimodalSamplingConstraint(measure_total_length=False) + assert constr.measure_length(convo) == 137 + + constr = MultimodalSamplingConstraint(measure_total_length=True) + assert constr.measure_length(convo) == 303 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[100, 200, 300, 400], batch_sizes=[3, 2, 1, 1], measure_total_length=True + ) + assert constr.measure_length(convo) == 303 + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 3 + + constr = MultimodalFixedBucketBatchSizeConstraint2D( + max_seq_len_buckets=[ + (50, 50), + (50, 100), + (50, 200), + (100, 50), + (100, 150), + (100, 200), + (100, 300), + (400, 400), + ], + batch_sizes=[8, 7, 6, 5, 4, 3, 2, 1], + measure_total_length=False, + ) + assert constr.measure_length(convo) == (137, 166) + assert constr.select_bucket(constr.max_seq_len_buckets, convo) == 7 + + def test_multimodal_conversation_tarred_format(multimodal_conversations_path, tmp_path_factory): (conversation,) = list(NeMoMultimodalConversationJsonlAdapter(multimodal_conversations_path, "[audio]")) tar_dir = tmp_path_factory.mktemp("multi_convo_tarred") diff --git a/tests/collections/common/test_lhotse_prompt_format_data_types.py b/tests/collections/common/test_lhotse_prompt_format_data_types.py new file mode 100644 index 000000000000..4347c467a4ae --- /dev/null +++ b/tests/collections/common/test_lhotse_prompt_format_data_types.py @@ -0,0 +1,283 @@ +import lhotse.serialization +import pytest +from lhotse import CutSet, SupervisionSegment +from lhotse.cut import Cut +from lhotse.testing.dummies import dummy_cut + +from nemo.collections.common.data import ( + NeMoSFTExample, + SourceTargetTextExample, + TextExample, + get_lhotse_dataloader_from_config, +) +from nemo.collections.common.tokenizers import SentencePieceTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import create_spt_model + + +@pytest.fixture +def tokenizer(tmp_path_factory): + tmpdir = tmp_path_factory.mktemp("tok") + text_path = tmpdir / "text.txt" + text_path.write_text("\n".join(chr(i) for i in range(256))) + create_spt_model( + text_path, + vocab_size=512, + sample_size=-1, + do_lower_case=False, + output_dir=str(tmpdir), + bos=True, + eos=True, + user_defined_symbols=[ + "[INST]", + "[/INST]", + "<>", + "<>", + "[audio]", + "", + "", + ], + ) + return SentencePieceTokenizer(str(tmpdir / "tokenizer.model")) + + +@pytest.fixture +def cuts_path(tmp_path_factory): + tmp_path = tmp_path_factory.getbasetemp() / "cuts.jsonl" + c = dummy_cut(0, duration=1.0, supervisions=[SupervisionSegment("", "", 0, 1.0, text="dummy text")]) + c.context = "dummy context" + CutSet([c]).to_file(tmp_path) + return tmp_path + + +@pytest.fixture +def src_tgt_example(tmp_path_factory): + d = tmp_path_factory.mktemp("src_tgt") + (d / "src.txt").write_text("an example") + (d / "tgt.txt").write_text("elpmaxe na") + return (d / "src.txt"), (d / "tgt.txt") + + +@pytest.fixture +def nemo_sft_example(tmp_path_factory): + tmp_path = tmp_path_factory.getbasetemp() / "nemo_sft.jsonl" + lhotse.serialization.save_to_jsonl( + [ + { + "system": "", + "mask": "User", + "dataset": "", + "conversations": [ + { + "from": "User", + "value": "Hi, how are you?", + }, + { + "from": "Assistant", + "value": "Good day, I'm a useful assistant.", + }, + ], + } + ], + tmp_path, + ) + return tmp_path + + +class Identity: + def __getitem__(self, item): + return item + + +def test_prompt_format_cut(cuts_path, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "batch_size": 1, + "prompt_format": "llama2", + "min_duration": 0, + "max_duration": 10, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, Cut) + assert tokenizer.ids_to_text(ex.input_ids) == "[INST] dummy context [/INST] dummy text" + assert tokenizer.ids_to_text(ex.context_ids) == "[INST] dummy context [/INST]" + assert tokenizer.ids_to_text(ex.answer_ids) == "dummy text" + + +def test_prompt_format_cut_filtered_out(cuts_path, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "batch_size": 1, + "prompt_format": "llama2", + "min_duration": 0, + "max_duration": 0.5, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + with pytest.raises(StopIteration): + next(iter(dl)) + + +def test_prompt_format_cut_max_tokens_has_no_filtering_effect(cuts_path, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "cuts_path": cuts_path, + "batch_size": 1, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "token_equivalent_duration": 0.1, + "min_tokens": 1, + "max_tokens": 2, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, Cut) + + +def test_prompt_format_src_tgt(src_tgt_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [ + {"type": "txt_pair", "source_paths": src_tgt_example[0], "target_paths": src_tgt_example[1]} + ], + "batch_size": 1, + "force_finite": True, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 50, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, SourceTargetTextExample) + assert tokenizer.ids_to_text(ex.input_ids) == "[INST] an example [/INST] elpmaxe na" + assert tokenizer.ids_to_text(ex.context_ids) == "[INST] an example [/INST]" + assert tokenizer.ids_to_text(ex.answer_ids) == "elpmaxe na" + + +def test_prompt_format_src_tgt_filtered_out(src_tgt_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [ + {"type": "txt_pair", "source_paths": src_tgt_example[0], "target_paths": src_tgt_example[1]} + ], + "batch_size": 1, + "force_finite": True, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 10, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + with pytest.raises(StopIteration): + batch = next(iter(dl)) + + +def test_prompt_format_src_tgt_2d(src_tgt_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [ + { + "type": "txt_pair", + "source_paths": src_tgt_example[0], + "target_paths": src_tgt_example[1], + "target_language": "reversed", + } + ], + "batch_size": 1, + "force_finite": True, + "prompt_format": "t5nmt", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 50, + "use_total_length": False, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, SourceTargetTextExample) + assert tokenizer.ids_to_text(ex.input_ids) == " an example elpmaxe na" + assert tokenizer.ids_to_text(ex.context_ids) == " an example" + assert tokenizer.ids_to_text(ex.answer_ids) == "elpmaxe na" + + +def test_prompt_format_nemo_sft(nemo_sft_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [{"type": "nemo_sft_jsonl", "paths": nemo_sft_example}], + "batch_size": 1, + "force_finite": True, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 100, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + + batch = next(iter(dl)) + ex = batch[0] + assert isinstance(ex, NeMoSFTExample) + assert tokenizer.ids_to_text(ex.input_ids) == "[INST] Hi, how are you? [/INST] Good day, I'm a useful assistant." + assert tokenizer.ids_to_text(ex.context_ids) == "[INST] Hi, how are you? [/INST]" + assert tokenizer.ids_to_text(ex.answer_ids) == "Good day, I'm a useful assistant." + + +def test_prompt_format_nemo_sft_filtered_out(nemo_sft_example, tokenizer): + dl = get_lhotse_dataloader_from_config( + { + "input_cfg": [{"type": "nemo_sft_jsonl", "paths": nemo_sft_example}], + "batch_size": 1, + "force_finite": True, + "prompt_format": "llama2", + "use_multimodal_dataloading": True, + "min_tokens": 1, + "max_tokens": 5, + "use_total_length": True, + }, + global_rank=0, + world_size=1, + dataset=Identity(), + tokenizer=tokenizer, + ) + with pytest.raises(StopIteration): + batch = next(iter(dl)) diff --git a/tests/collections/common/test_lhotse_seqlen_filters.py b/tests/collections/common/test_lhotse_seqlen_filters.py new file mode 100644 index 000000000000..ba77b235c6e5 --- /dev/null +++ b/tests/collections/common/test_lhotse_seqlen_filters.py @@ -0,0 +1,171 @@ +from copy import deepcopy + +import numpy as np +import pytest +from lhotse import SupervisionSegment +from lhotse.testing.dummies import dummy_cut + +from nemo.collections.common.data.lhotse.sampling import ( + DurationFilter, + TokenCountFilter, + TokenPerSecondFilter, + TokenPerTokenFilter, +) +from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample, SourceTargetTextExample, TextExample + + +@pytest.fixture +def cut(): + c = dummy_cut(0, duration=1.0, supervisions=[SupervisionSegment("", "", 0, 1.0, text="dummy")]) + c.supervisions[0].tokens = [1, 37, 12, 2] + return c + + +def test_cut_duration_filter(cut): + f = DurationFilter(0, 10) + assert f(cut) == True + + f = DurationFilter(0, 0.5) + assert f(cut) == False + + f = DurationFilter(1.5, 2.0) + assert f(cut) == False + + +def test_cut_token_per_second_filter(cut): + f = TokenPerSecondFilter(tps_min=0.0, tps_max=5.0) + assert f(cut) == True + + f = TokenPerSecondFilter(tps_min=0.0, tps_max=1.0) + assert f(cut) == False + + f = TokenPerSecondFilter(tps_min=10.0, tps_max=12.0) + assert f(cut) == False + + +def test_cut_passes_by_token_count_and_tpt_filter(cut): + assert TokenCountFilter(1, 10, measure_total_length=True)(cut) == True + assert TokenPerTokenFilter(1, 10)(cut) == True + + +def test_cut_passes_by_token_count_and_tpt_filter(cut): + assert TokenCountFilter(1, 10, measure_total_length=True)(cut) == True + assert TokenPerTokenFilter(1, 10)(cut) == True + + +@pytest.fixture +def src_tgt_example(): + return SourceTargetTextExample( + source=TextExample("", tokens=np.array([1, 37, 12, 2])), + target=TextExample("", tokens=np.array([1, 1823, 1245, 2446, 1038, 2])), + ) + + +def test_src_tgt_token_filter_requires_prompt_formatting(src_tgt_example): + with pytest.raises(RuntimeError): + TokenCountFilter(0, 1, True)(src_tgt_example) + + +def test_src_tgt_passes_by_duration_filter(src_tgt_example): + assert DurationFilter(1, 10)(src_tgt_example) == True + assert TokenPerSecondFilter(1, 10)(src_tgt_example) == True + + +def test_src_tgt_token_filter(src_tgt_example): + example = deepcopy(src_tgt_example) + example.input_ids = np.concatenate((example.source.tokens, example.target.tokens)) + example.context_ids = example.source.tokens + example.answer_ids = example.target.tokens + + """ + Input length measurement / encoder-decoder models / 2D bucketing + """ + f = TokenCountFilter(1, 5, measure_total_length=False) + assert f(example) == True + + f = TokenCountFilter(1, 3, measure_total_length=False) + assert f(example) == False + + f = TokenCountFilter(10, 30, measure_total_length=False) + assert f(example) == False + + """ + Total length measurement / decoder-only models / 1D bucketing + """ + f = TokenCountFilter(1, 5, measure_total_length=True) + assert f(example) == False + + f = TokenCountFilter(1, 20, measure_total_length=True) + assert f(example) == True + + f = TokenCountFilter(1, 3, measure_total_length=True) + assert f(example) == False + + f = TokenCountFilter(20, 30, measure_total_length=True) + assert f(example) == False + + +@pytest.fixture +def nemo_sft_example(): + example = NeMoSFTExample( + data={ + "system": "", + "mask": "User", + "dataset": "", + "conversations": [ + { + "from": "User", + "value": "Hi, how are you?", + }, + { + "from": "Assistant", + "value": "Good day, I'm a useful assistant.", + }, + ], + }, + ) + return example + + +def test_nemo_sft_token_filter_requires_prompt_formatting(nemo_sft_example): + with pytest.raises(RuntimeError): + TokenCountFilter(0, 1, True)(nemo_sft_example) + + +def test_nemo_sft_passes_by_duration_filter(nemo_sft_example): + assert DurationFilter(1, 10)(nemo_sft_example) == True + assert TokenPerSecondFilter(1, 10)(nemo_sft_example) == True + + +def test_nemo_sft_token_filter(nemo_sft_example): + example = deepcopy(nemo_sft_example) + example.input_ids = np.array([1, 123, 3425, 123, 2345, 324, 54, 2]) + example.context_ids = np.array([1, 123, 3425]) + example.answer_ids = np.array([123, 2345, 324, 54, 2]) + + """ + Input length measurement / encoder-decoder models / 2D bucketing + """ + f = TokenCountFilter(1, 5, measure_total_length=False) + assert f(example) == True + + f = TokenCountFilter(1, 2, measure_total_length=False) + assert f(example) == False + + f = TokenCountFilter(10, 30, measure_total_length=False) + assert f(example) == False + + """ + Total length measurement / decoder-only models / 1D bucketing + """ + f = TokenCountFilter(1, 5, measure_total_length=True) + assert f(example) == False + + f = TokenCountFilter(1, 20, measure_total_length=True) + assert f(example) == True + + f = TokenCountFilter(1, 3, measure_total_length=True) + assert f(example) == False + + f = TokenCountFilter(10, 30, measure_total_length=True) + assert f(example) == False diff --git a/tests/collections/multimodal/test_emmett.py b/tests/collections/multimodal/test_emmett.py new file mode 100644 index 000000000000..553343b8a711 --- /dev/null +++ b/tests/collections/multimodal/test_emmett.py @@ -0,0 +1,239 @@ +import pytest +import torch +from lhotse import CutSet, MonoCut, SupervisionSegment +from lhotse.testing.dummies import dummy_recording +from omegaconf import OmegaConf + +from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config +from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample, TextExample +from nemo.collections.common.tokenizers import SentencePieceTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import create_spt_model +from nemo.collections.multimodal.speech_llm.data.lhotse_dataset import LhotseAudioQuestionAnswerDataset +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import PromptFormatterTextProcessing + + +class Identity(torch.utils.data.Dataset): + def __getitem__(self, cuts): + return cuts + + +@pytest.fixture +def tokenizer(capsys, tmp_path_factory): + TOKENIZER_TRAIN_TEXT = """ + Example system message. + Example user message. + Example assistant message. + TEST + [INST] + [/INST] + + + <> + <> + User: Assistant: + user model + Instruct Output + \n\n + + <| + |> + <|en|> <|de|> <|fr|> <|es|> <|transcribe|> <|translate|> <|pnc|> <|nopnc|> <|startoftranscript|> <|endoftext|> + Feel free to add new tokens for your own tests!? + But know that if you do so, you may need to update the token IDs in the existing tests! + So, it might be a good idea to create a new tokenizer instead when adding new prompt formats. + """ + tmpdir = tmp_path_factory.mktemp("bpe_tokenizer") + text_path = tmpdir / "text.txt" + text_path.write_text(TOKENIZER_TRAIN_TEXT) + with capsys.disabled(): + create_spt_model(str(text_path), vocab_size=512, sample_size=-1, do_lower_case=False, output_dir=str(tmpdir)) + return SentencePieceTokenizer(str(tmpdir / "tokenizer.model")) + + +""" +TEST FOR AUDIO DATALOADING WITH EMMETT +""" + + +@pytest.fixture +def cuts(): + return CutSet( + [ + MonoCut( + id="ex0", + start=0, + duration=5.0, + channel=0, + supervisions=[ + SupervisionSegment( + id="ex0", + recording_id="dummy-recording-0000", + start=0, + duration=5.0, + text="some transcription", + language="en", + ) + ], + recording=dummy_recording(0, duration=5.0, with_data=True), + custom={ + "context": "", + "answer": "some desired answer", + }, + ), + ] + ) + + +@pytest.fixture +def cuts_path(tmp_path_factory, cuts): + tmp_path = tmp_path_factory.mktemp("data") + p = tmp_path / "cuts.jsonl.gz" + pa = tmp_path / "audio" + cuts.save_audios(pa).to_file(p) + return p + + +def test_audio_example_with_prompt_emmett_t5(cuts_path, tokenizer): + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "lhotse", + "cuts_path": cuts_path, + }, + ], + "prompt_format": "t5nmt", + "force_finite": True, + "shuffle": True, + "num_workers": 0, + "batch_size": 1, + "seed": 0, + "shard_seed": 0, + } + ) + + # First test that sampling is correct and tokenizer + prompt formatter is applied there + + dl = get_lhotse_dataloader_from_config( + config=config, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer + ) + batches = [batch for batch in dl] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, CutSet) + assert len(b) == 1 + ex = b[0] + assert isinstance(ex, MonoCut) + + assert ex.has_custom("context_ids") + assert torch.is_tensor(ex.context_ids) + assert tokenizer.ids_to_text(ex.context_ids) == "" + + assert ex.has_custom("answer_ids") + assert torch.is_tensor(ex.answer_ids) + assert tokenizer.ids_to_text(ex.answer_ids) == "some transcription" + + assert ex.has_custom("input_ids") + assert torch.is_tensor(ex.input_ids) + assert tokenizer.ids_to_text(ex.input_ids) == " some transcription" + + # Test that speechlm dataset processes the example correctly + + text_processor = PromptFormatterTextProcessing(tokenizer=tokenizer, prompt_format="t5nmt") + dataset = LhotseAudioQuestionAnswerDataset( + text_processor=text_processor, + default_context="", + tokens_to_generate=0, + pad_to_max_length=False, + max_seq_length=64, + ) + + batch = dataset[batches[0]] + assert tokenizer.ids_to_text(batch["tokens"][0]) == " some transcriptio" + assert tokenizer.ids_to_text(batch["labels"][0]) == "en> some transcription" + assert tokenizer.ids_to_text(batch["contexts"][0]) == "" + assert tokenizer.ids_to_text(batch["answers"][0]) == "some transcription" + + +""" +TEST FOR TEXT DATALOADING WITH EMMETT +""" + + +@pytest.fixture +def nmt_paths(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("nmtdata") + src = tmp_path / "src.txt" + tgt = tmp_path / "tgt.txt" + q = tmp_path / "q.txt" + src.write_text("fake german") + tgt.write_text("real english") + q.write_text("") + return src, tgt, q + + +def test_text_example_with_prompt_emmett_t5(nmt_paths, tokenizer): + src, tgt, q = nmt_paths + config = OmegaConf.create( + { + "input_cfg": [ + { + "type": "txt_pair", + "source_paths": src, + "target_paths": tgt, + "source_language": "de", + "target_language": "en", + "questions_path": q, + "questions_language": "en", + }, + ], + "prompt_format": "t5nmt", + "force_finite": True, + "shuffle": True, + "num_workers": 0, + "batch_size": 1, + "seed": 0, + "shard_seed": 0, + } + ) + + # First test that sampling is correct and tokenizer + prompt formatter is applied there + + dl = get_lhotse_dataloader_from_config( + config=config, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer + ) + batches = [batch for batch in dl] + assert len(batches) == 1 + + b = batches[0] + assert isinstance(b, CutSet) + assert len(b) == 1 + ex = b[0] + assert isinstance(ex, SourceTargetTextExample) + + assert torch.is_tensor(ex.context_ids) + assert tokenizer.ids_to_text(ex.context_ids) == " fake german" + + assert torch.is_tensor(ex.answer_ids) + assert tokenizer.ids_to_text(ex.answer_ids) == "real english" + + assert torch.is_tensor(ex.input_ids) + assert tokenizer.ids_to_text(ex.input_ids) == " fake german real english" + + # Test that speechlm dataset processes the example correctly + + text_processor = PromptFormatterTextProcessing(tokenizer=tokenizer, prompt_format="t5nmt") + dataset = LhotseAudioQuestionAnswerDataset( + text_processor=text_processor, + default_context="", + tokens_to_generate=0, + pad_to_max_length=False, + max_seq_length=64, + ) + + batch = dataset[batches[0]] + + assert tokenizer.ids_to_text(batch["text_input_ids"][0]) == " fake german real english" + assert tokenizer.ids_to_text(batch["text_context_ids"][0]) == " fake german" + assert tokenizer.ids_to_text(batch["text_answer_ids"][0]) == "real english" diff --git a/tests/collections/multimodal/test_speechllm_dataset.py b/tests/collections/multimodal/test_speechllm_dataset.py index de554a219ca4..b4c51c4fc978 100644 --- a/tests/collections/multimodal/test_speechllm_dataset.py +++ b/tests/collections/multimodal/test_speechllm_dataset.py @@ -84,7 +84,6 @@ def test_speechllm_dataset(tokenizer, cuts): ) batch = dataset[cuts] - print(batch) expected_keys = { "sample_ids", @@ -368,8 +367,8 @@ def test_speechllm_dataset_tokens_to_generate_increases_seq_len(llama_tokenizer, max_seq_length=512, ) batch = dataset[cuts] - assert batch["tokens"].shape == (1, 347) # was 351 before padding optimization - assert batch["labels"].shape == (1, 347) # was 351 before padding optimization - assert batch["contexts"].shape == (1, 337) # was 352 before padding optimization - assert batch["answers"].shape == (1, 267) # was 352 before padding optimization - assert batch["position_ids"].shape == (1, 348) # was 352 before padding optimization + assert batch["tokens"].shape == (1, 91) + assert batch["labels"].shape == (1, 91) + assert batch["contexts"].shape == (1, 337) + assert batch["answers"].shape == (1, 11) + assert batch["position_ids"].shape == (1, 92) diff --git a/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb b/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb new file mode 100644 index 000000000000..b9ddf350cdca --- /dev/null +++ b/tutorials/multimodal/Multimodal Lhotse Dataloading.ipynb @@ -0,0 +1,1014 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e930b0c5f0cffbce", + "metadata": {}, + "source": [ + "# Multimodal Lhotse Dataloading\n", + "\n", + "This tutorial explains how NeMo uses Lhotse for multimodal dataloading.\n", + "The modalities supported as of the time of writing are audio and text.\n", + "The intended audience of this tutorial are NeMo developers and persons who build/modify NeMo models.\n", + "After finishing this tutorial, you should have an understanding how to use various Lhotse building blocks in NeMo for designing the kind of model you want.\n", + "\n", + "We cover the following topics:\n", + "* What are data types?\n", + "* What data types are availabe in NeMo?\n", + "* How do we read them from files?\n", + "* How to apply prompt formatting to various data types?\n", + "* How to create tensors for training with these examples?\n", + "* How to optimize the training by stratifying data sampling on sequence lengths, and how these lengths are measured for different examples and models. \n", + "* How to train on multiple data types together?" + ] + }, + { + "cell_type": "markdown", + "id": "72bd180c65992eba", + "metadata": {}, + "source": [ + "## Data types\n", + "\n", + "A data type represents examples of your training data: speech recordings, text sentences, text sentence pairs, conversations, etc.\n", + "\n", + "A data type consists of:\n", + "* a class that represents a single sample\n", + " * includes properties allowing sequence length measurement for sampling purposes\n", + "* a parser class that's initialized with a config (e.g. paths to data) and acts as an iterator of examples\n", + "* extension functions that define how to apply prompt formatting to a given data type\n", + "\n", + "NeMo uses Lhotse Cuts as a basic data type for audio, and defines several data types for text. We'll go over them below.\n", + "\n", + "External references:\n", + "* [Lhotse documentation](https://lhotse.readthedocs.io/en/latest/getting-started.html)\n", + "* [Lhotse in NeMo documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/datasets.html#lhotse-dataloading)" + ] + }, + { + "cell_type": "markdown", + "id": "cf32bf3ea5a9cb17", + "metadata": {}, + "source": [ + "### Audio examples (Lhotse cuts)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d2d747f6b32d5942", + "metadata": { + "jupyter": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from lhotse import MonoCut, Recording, SupervisionSegment, AudioSource\n", + "from lhotse.testing.dummies import dummy_cut\n", + "\n", + "\n", + "# A basic audio example: recording with transcription\n", + "cut = MonoCut(\n", + " id=\"utt-0\",\n", + " start=0.0,\n", + " duration=10.0,\n", + " channel=0,\n", + " supervisions=[SupervisionSegment(id=\"utt-0\", recording_id=\"rec-0\", start=0.0, duration=10.0, text=\"Welcome to Lhotse!\")],\n", + " recording=Recording(\n", + " id=\"rec-0\",\n", + " sources=[AudioSource(type=\"file\", channels=[0], source=\"/path/to/recording.wav\")],\n", + " sampling_rate=16000,\n", + " duration=10.0,\n", + " num_samples=160000,\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9b121afd920bdab2", + "metadata": {}, + "source": [ + "## Single text examples " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "41b0c148e0d7ac1c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TextExample(text='This is a single sentence, which may be used in language modeling.', language='en', tokens=None, custom=None)\n" + ] + } + ], + "source": [ + "from nemo.collections.common.data.lhotse.text_adapters import TextExample\n", + "\n", + "# A basic text example: single line of text.\n", + "text = TextExample(\n", + " text=\"This is a single sentence, which may be used in language modeling.\",\n", + " language=\"en\"\n", + ")\n", + "print(text)" + ] + }, + { + "cell_type": "markdown", + "id": "2abb821b69f71a91", + "metadata": {}, + "source": [ + "## Pairs of text examples" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "282560cc3df9174a", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.common.data.lhotse.text_adapters import SourceTargetTextExample\n", + "\n", + "# A pair of text examples, usable e.g. in machine translation.\n", + "text_pair = SourceTargetTextExample(\n", + " source=TextExample(\n", + " text=\"Some machine translation example.\",\n", + " language=\"en\",\n", + " ),\n", + " target=TextExample(\n", + " text=\"Algunos ejemplos de traducción automática.\",\n", + " language=\"es\",\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "858d6cb6abb1ccd6", + "metadata": {}, + "source": [ + "## Conversations: text, audio, and multimodal" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e5bd8caee40100b1", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.common.data.lhotse.text_adapters import NeMoMultimodalConversation, TextTurn, AudioTurn\n", + "\n", + "# A text-only conversation, useful for chat LLM training.\n", + "text_conversation = NeMoMultimodalConversation(\n", + " id=\"convo-text-0\",\n", + " turns=[\n", + " TextTurn(value=\"Is this a text-only conversation?\", role=\"user\"),\n", + " TextTurn(value=\"Yes, but we can do more than that.\", role=\"assistant\"),\n", + " TextTurn(value=\"Tell me more.\", role=\"user\"),\n", + " TextTurn(value=\"Of course! Let's move on to the next example.\", role=\"assistant\"),\n", + " ]\n", + ")\n", + "\n", + "# An audio-only conversation, useful for chat speech LLM training.\n", + "# We'll explain [audio] tag and token_equivalent_duration later in this tutorial.\n", + "audio_conversation = NeMoMultimodalConversation(\n", + " id=\"convo-audio-0\",\n", + " turns=[\n", + " AudioTurn(cut=dummy_cut(0, duration=7.18, with_data=True), role=\"user\", audio_locator_tag=\"[audio]\"),\n", + " AudioTurn(cut=dummy_cut(0, duration=21.64, with_data=True), role=\"assistant\", audio_locator_tag=\"[audio]\"),\n", + " ],\n", + " token_equivalent_duration=0.08,\n", + ")\n", + "\n", + "# A multimodal conversation.\n", + "multimodal_conversation = NeMoMultimodalConversation(\n", + " id=\"convo-multimodal-0\",\n", + " turns=[\n", + " TextTurn(value=\"Is this a text-only conversation?\", role=\"user\"),\n", + " TextTurn(value=\"No, feel free to speak to me.\", role=\"assistant\"),\n", + " AudioTurn(cut=dummy_cut(0, duration=5.87, with_data=True), role=\"user\", audio_locator_tag=\"[audio]\"),\n", + " TextTurn(value=\"Should I respond in voice too?\", role=\"assistant\"),\n", + " TextTurn(value=\"Yes\", role=\"user\"),\n", + " TextTurn(value=\"Certainly!\", role=\"assistant\"),\n", + " AudioTurn(cut=dummy_cut(0, duration=14.62, with_data=True), role=\"assistant\", audio_locator_tag=\"[audio]\"),\n", + " ],\n", + " token_equivalent_duration=0.08,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b21e0e5e84904d89", + "metadata": {}, + "source": [ + "As you can see, these data structures serve as a complete description of training examples of different types, \n", + "as they contain both the data (audio) and various metadata." + ] + }, + { + "cell_type": "markdown", + "id": "9198210580be10bf", + "metadata": {}, + "source": [ + "## Parsing data types from files\n", + "\n", + "Related: for an overview of NeMo data configuration format, please see these docs: \n", + "* [Extended multi-dataset configuration format](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/datasets.html#extended-multi-dataset-configuration-format)\n", + "* [Configuring multi-modal dataloading](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/datasets.html#configuring-multi-modal-dataloading)\n", + "\n", + "The goal of data type parser is to read a configuration specifying where the data is located / how to read it,\n", + "create an iterable over the corresponding data type, and wrap it into a Lhotse CutSet.\n", + "\n", + "Adding support for a new data type parser requires two components:\n", + "* An adapter/iterator class dedicated to your data type.\n", + "* A function that instantiates this adapter/iterator, registered with a `@data_type_parser(\"name\")` decorator to make it auto-detectable by NeMo.\n", + "\n", + "We'll take a deeper look at how source-target text example pairs parsing is implemented. We'll implement a custom parser for `SourceTargetTextExample` that reads them from JSON files." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f0e35b53c7ac77b4", + "metadata": {}, + "outputs": [], + "source": [ + "from lhotse.serialization import load_jsonl\n", + "import random\n", + "from typing import Literal, Iterator\n", + "from dataclasses import dataclass\n", + "\n", + "from lhotse import CutSet\n", + "from lhotse.dataset.dataloading import resolve_seed\n", + "from omegaconf import DictConfig\n", + "from nemo.collections.common.data.lhotse.nemo_adapters import expand_sharded_filepaths\n", + "from nemo.collections.common.data.lhotse.cutset import data_type_parser\n", + "\n", + "\n", + "@dataclass\n", + "class LhotseTextPairAdapterFromJsonl:\n", + " manifest_path: str | list[str]\n", + " shuffle_shards: bool = False\n", + " shard_seed: int | Literal[\"trng\", \"randomized\"] = \"trng\"\n", + "\n", + " def __post_init__(self):\n", + " self.manifest_path = expand_sharded_filepaths(self.manifest_path)\n", + "\n", + " def __iter__(self) -> Iterator[SourceTargetTextExample]:\n", + " seed = resolve_seed(self.shard_seed)\n", + " rng = random.Random(seed)\n", + " paths = self.manifest_path\n", + " if self.shuffle_shards:\n", + " rng.shuffle(paths)\n", + " for p in paths:\n", + " for item in load_jsonl(p):\n", + " yield SourceTargetTextExample(\n", + " source=TextExample(item[\"source\"], item.get(\"source_lang\")),\n", + " target=TextExample(item[\"target\"], item.get(\"target_lang\")),\n", + " question=(\n", + " TextExample(item[\"prompt\"], language=item(\"prompt_lang\"))\n", + " if \"prompt\" in item\n", + " else None\n", + " ),\n", + " )\n", + "\n", + "\n", + "@data_type_parser(\"txt_pair_jsonl\")\n", + "def read_txt_pair_paths(config: DictConfig) -> tuple[CutSet, bool]:\n", + " cuts = CutSet(\n", + " LhotseTextPairAdapterFromJsonl(\n", + " manifest_path=config.manifest_path,\n", + " shuffle_shards=config.shuffle,\n", + " shard_seed=config.shard_seed,\n", + " )\n", + " )\n", + " if not config.get(\"force_finite\", False):\n", + " cuts = cuts.repeat()\n", + " return cuts, True" + ] + }, + { + "cell_type": "markdown", + "id": "64367e6596754ee6", + "metadata": {}, + "source": [ + "Note that there is a bit of boilerplate (`expand_sharded_filepaths`, `force_finite`, `shuffle_shards`, `shard_seed`) - we might reduce the amount of necessary boilerplate in the future, but for now it is required.\n", + "\n", + "Let's test that it works. We'll first create two JSONL files (shards) with one entry each, and later use NeMo's path expansion mechanism to provide them as the input configuration.\n", + "\n", + "Then, we'll read it using the high-level API `read_cutset_from_config` that's actually used by NeMo+Lhotse dataloader to show that the auto-registration mechanism works as expected." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7987fce8db39b008", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-10-18 14:12:16 nemo_logging:349] /Users/pzelasko/miniforge3/envs/nemo/lib/python3.10/site-packages/pydub/utils.py:170: RuntimeWarning: Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\n", + " warn(\"Couldn't find ffmpeg or avconv - defaulting to ffmpeg, but may not work\", RuntimeWarning)\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom=None)\n" + ] + } + ], + "source": [ + "!echo '{\"source\": \"A\", \"target\": \"B\"}' >> _tutorial_nmt_0.jsonl\n", + "!echo '{\"source\": \"C\", \"target\": \"D\"}' >> _tutorial_nmt_1.jsonl\n", + "\n", + "from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config\n", + "\n", + "data, use_iterable_dataset = read_cutset_from_config(\n", + " {\n", + " \"input_cfg\": [\n", + " {\n", + " \"type\": \"txt_pair_jsonl\", \n", + " \"manifest_path\": \"_tutorial_nmt__OP_0..1_CL_.jsonl\", \n", + " }\n", + " ]\n", + " }\n", + ")\n", + "\n", + "example = next(iter(data))\n", + "assert isinstance(example, SourceTargetTextExample)\n", + "assert example.source.text == \"A\"\n", + "assert example.target.text == \"B\"\n", + "print(example)" + ] + }, + { + "cell_type": "markdown", + "id": "be48872625d1a2e0", + "metadata": {}, + "source": [ + "## Prompt formatting and conversion of data types to tensors\n", + "\n", + "Since we now understand how data types are read, let's see how to convert them to actual training examples.\n", + "Because this tutorial is focused on multimodal LLM / speech LLM training, we'll be using prompt templates adequate for various LLMs to prepare the training data. In this example, we'll use Llama2 prompt template to format each data type.\n", + "\n", + " We'll need to initialize a prompt formatter and a tokenizer; we'll just train a dummy BPE tokenizer for the purpose of the tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6e1d296be0d363d", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-10-18 14:12:19 sentencepiece_tokenizer:333] tokenizer model _tutorial_spt/tokenizer.model already exists\n" + ] + } + ], + "source": [ + "import string\n", + "import shlex\n", + "from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model\n", + "from nemo.collections.common.prompts.formatter import PromptFormatter\n", + "\n", + "!echo {shlex.quote(' '.join(string.printable))} > _tutorial_train_text.txt\n", + "\n", + "tok_path, vocab_path = create_spt_model(\n", + " data_file=\"_tutorial_train_text.txt\", \n", + " output_dir=\"_tutorial_spt\",\n", + " vocab_size=512, \n", + " sample_size=-1, \n", + " do_lower_case=False, \n", + " bos=True, \n", + " eos=True, \n", + " pad=True, \n", + " user_defined_symbols=[\"[INST]\", \"[/INST]\", \"<>\", \"<>\", \"[audio]\"]\n", + ")\n", + "\n", + "tokenizer = SentencePieceTokenizer(tok_path)\n", + "prompt = PromptFormatter.resolve(\"llama2\")(tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "6988777c9dc1653b", + "metadata": {}, + "source": [ + "Now, we'll convert the data types to a training/inference friendly format. Specifically, we want to have 4 tensors:\n", + "* `context_ids`: token IDs that serve as the input for LLM (e.g. user query, conversation history, etc.)\n", + "* `answer_ids`: token IDs that serve as the answer for LLM (assistant response)\n", + "* `input_ids`: concatenated `context_ids` and `answer_ids`\n", + "* `mask`: loss mask that's only set to `True` for each token belonging to each of assistant's turns. Same length as `input_ids`.\n", + "\n", + "Let's first go through Cut, SourceTargetTextExample, and NeMoMultimodalConversation to see what happens with them." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "5f8c0a54189e443d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cut:\n", + "\t* input_ids [INST] Repeat after me: [/INST] Welcome to Lhotse!\n", + "\t* context_ids [INST] Repeat after me: [/INST]\n", + "\t* answer_ids Welcome to Lhotse!\n", + "loss mask tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True])\n", + "\n", + "SourceTargetTextExample:\n", + "\t* input_ids [INST] Some machine translation example. [/INST] Algunos ejemplos de traducci ⁇ n autom ⁇ tica.\n", + "\t* context_ids [INST] Some machine translation example. [/INST]\n", + "\t* answer_ids Algunos ejemplos de traducci ⁇ n autom ⁇ tica.\n", + "loss mask tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True])\n", + "\n", + "NeMoMultimodalConversation:\n", + "\t* input_ids [INST] Is this a text-only conversation? [/INST] No, feel free to speak to me. [INST] [audio] [/INST] Should I respond in voice too? [INST] Yes [/INST] Certainly! [audio]\n", + "\t* context_ids [INST] Is this a text-only conversation? [/INST] No, feel free to speak to me. [INST] [audio] [/INST] Should I respond in voice too? [INST] Yes [/INST]\n", + "\t* answer_ids Certainly! [audio]\n", + "loss mask tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " False, False, False, False, False, False, False, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, False,\n", + " False, False, False, False, False, False, False, False, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True])\n", + "\n" + ] + } + ], + "source": [ + "from nemo.collections.common.data.prompt_fn import apply_prompt_format_fn\n", + "\n", + "cut.context = \"Repeat after me:\"\n", + "print(\"Cut:\")\n", + "formatted = apply_prompt_format_fn(cut, prompt)\n", + "for name in [\"input_ids\", \"context_ids\", \"answer_ids\"]:\n", + " print(\"\\t*\", name, tokenizer.ids_to_text(formatted[name]))\n", + "print(\"loss mask\", formatted[\"mask\"])\n", + "print()\n", + "\n", + "print(\"SourceTargetTextExample:\")\n", + "formatted = apply_prompt_format_fn(text_pair, prompt)\n", + "for name in [\"input_ids\", \"context_ids\", \"answer_ids\"]:\n", + " print(\"\\t*\", name, tokenizer.ids_to_text(formatted[name]))\n", + "print(\"loss mask\", formatted[\"mask\"])\n", + "print()\n", + "\n", + "print(\"NeMoMultimodalConversation:\")\n", + "formatted = apply_prompt_format_fn(multimodal_conversation, prompt)\n", + "for name in [\"input_ids\", \"context_ids\", \"answer_ids\"]:\n", + " print(\"\\t*\", name, tokenizer.ids_to_text(formatted[name]))\n", + "print(\"loss mask\", formatted[\"mask\"])\n", + "print()" + ] + }, + { + "cell_type": "markdown", + "id": "e1b50937e5f75d10", + "metadata": {}, + "source": [ + "Note how each example got converted into the same prompt format. \n", + "\n", + "For multimodal conversation we have a special mechanism that replaces audio turns with an `audio_locator_tag`. \n", + "We expect that the tokenizer contains this tag as a special token.\n", + "The user will later replace these special tokens with audio representations (tokenized, or not) in the training step of the model. \n", + "\n", + "If you create a new prompt format, or a new data type, or want to specialize how a given data type is formatted with a given prompt, it is easily customizable by defining a single function with `@registered_prompt_format_fn(DataType, PromptFormatterType)` decorator. For example, if we created a new data type called `TextTriplet`, and added a default prompt format function, and another one specialized for Llama2:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "108b3593a5f16444", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids tensor([ 1, 9, 4, 9, 6, 9, 42, 9, 7, 9, 43, 9, 5, 9, 44, 2])\n", + "context_ids tensor([ 1, 9, 4, 9, 6, 9, 42, 9, 7, 9, 43, 9, 5])\n", + "answer_ids tensor([ 9, 44, 2])\n", + "mask tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, True, True, True])\n" + ] + } + ], + "source": [ + "from nemo.collections.common.prompts import Llama2PromptFormatter\n", + "from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn\n", + "from nemo.collections.common.data.lhotse.text_adapters import Formattable, CustomFieldMixin\n", + "\n", + "\n", + "@dataclass\n", + "class TextTriplet(Formattable, CustomFieldMixin):\n", + " # Note: we will explain Formattable and CustomFieldMixin in the next sections.\n", + " text1: str\n", + " text2: str\n", + " text3: str\n", + "\n", + "\n", + "@registered_prompt_format_fn(TextTriplet)\n", + "def text_triplets_generic(example: TextTriplet, prompt: PromptFormatter):\n", + " return prompt.encode_dialog(turns=[\n", + " {\"role\": \"user\", \"slots\": {\"message\": f\"{example.text1} {example.text2}\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": f\"{example.text3}\"}},\n", + " ])\n", + "\n", + " \n", + "@registered_prompt_format_fn(TextTriplet, Llama2PromptFormatter)\n", + "def text_triplets_llama2(example: TextTriplet, prompt: Llama2PromptFormatter):\n", + " return prompt.encode_dialog(turns=[\n", + " {\"role\": \"system_and_user\", \"slots\": {\"system\": example.text1 , \"message\": example.text2}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": example.text3}},\n", + " ])\n", + "\n", + "\n", + "formatted = apply_prompt_format_fn(TextTriplet(\"A\", \"B\", \"C\"), prompt)\n", + "for k, v in formatted.items():\n", + " print(k, v)" + ] + }, + { + "cell_type": "markdown", + "id": "9565bef14a863465", + "metadata": {}, + "source": [ + "If we also created a data type parser for `TextTriplet` like we did for `SourceTargetTextExample` in the section before, we have a complete new data type support for dataloading. " + ] + }, + { + "cell_type": "markdown", + "id": "6ac39c8fcbcf5860", + "metadata": {}, + "source": [ + "## Support for sequence length stratification / dynamic bucketing\n", + "\n", + "References: \n", + "* [EMMeTT: Efficient Multimodal Machine Translation Training](https://arxiv.org/abs/2409.13523) \n", + "\n", + "We found that by using dynamic bucketing with [OOMptimizer](https://github.com/NVIDIA/NeMo/blob/main/docs/source/asr/datasets.rst#pushing-gpu-utilization-to-the-limits-with-bucketing-and-oomptimizer) can significantly accelerate multimodal LLM training. \n", + "In order to ensure that all data types can benefit from this acceleration, we introduced the `Formattable` concept.\n", + "It indicates that a given data type supports prompt formatting and provides properties to measure input and output sequence length.\n", + "\n", + "Let's see this in action with the previously formatted data types:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f5ca38ea137f8210", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SourceTargetTextPair:\n", + "\t* input_length 39\n", + "\t* output_length 44\n", + "\t* total_length 83\n", + "\t* len(context_ids) 39\n", + "\t* len(answer_ids) 44\n", + "\t* len(input_ids) 83\n", + "NeMoMultimodalConversation\n", + "\t* input_length 191\n", + "\t* output_length 196\n", + "\t* total_length 387\n", + "\t* len(context_ids) 118\n", + "\t* len(answer_ids) 14\n", + "\t* len(input_ids) 132\n" + ] + } + ], + "source": [ + "print(\"SourceTargetTextPair:\")\n", + "text_pair = text_pair.apply_prompt_format(prompt)\n", + "print(\"\\t*\", \"input_length\", text_pair.input_length)\n", + "print(\"\\t*\", \"output_length\", text_pair.output_length)\n", + "print(\"\\t*\", \"total_length\", text_pair.total_length)\n", + "print(\"\\t*\", \"len(context_ids)\", len(text_pair.context_ids))\n", + "print(\"\\t*\", \"len(answer_ids)\", len(text_pair.answer_ids))\n", + "print(\"\\t*\", \"len(input_ids)\", len(text_pair.input_ids))\n", + "\n", + "print(\"NeMoMultimodalConversation\")\n", + "text_pair = multimodal_conversation.apply_prompt_format(prompt)\n", + "print(\"\\t*\", \"input_length\", multimodal_conversation.input_length)\n", + "print(\"\\t*\", \"output_length\", multimodal_conversation.output_length)\n", + "print(\"\\t*\", \"total_length\", multimodal_conversation.total_length)\n", + "print(\"\\t*\", \"len(context_ids)\", len(multimodal_conversation.context_ids))\n", + "print(\"\\t*\", \"len(answer_ids)\", len(multimodal_conversation.answer_ids))\n", + "print(\"\\t*\", \"len(input_ids)\", len(multimodal_conversation.input_ids))\n" + ] + }, + { + "cell_type": "markdown", + "id": "ecca372c2a0cad6e", + "metadata": {}, + "source": [ + "Note that for `NeMoMultimodalConversation` the length is much greater that the number of text tokens. \n", + "This is where `token_equivalent_duration` comes in: we want to factor in the audio turns into sequence lengths.\n", + "Since we know what is the duration of audio, we only need to know how much duration should be covered by each audio \"token\" or \"frame\".\n", + "A typical setup would be with NeMo FastConformer as an audio encoder, which uses 10ms frames at the input and subsamples them by a factor of 8 in the output. \n", + "The resulting `token_equivalent_duration` is therefore `0.08`, i.e., a single token created from audio is worth 80ms of duration. \n", + "For length computation, we sum the number of text tokens and the equivalent number of audio tokens.\n", + "\n", + "We can see that lhotse's `DynamicBucketingSampler` is able to process this data using NeMo multimodal sampling strategies:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6e295cfbfe8ff69b", + "metadata": {}, + "outputs": [], + "source": [ + "from lhotse.dataset import DynamicBucketingSampler\n", + "from nemo.collections.common.data.lhotse.sampling import MultimodalFixedBucketBatchSizeConstraint2D\n", + "\n", + "cuts = CutSet([multimodal_conversation]).repeat() # repeat makes iterable infinite\n", + "sampler = DynamicBucketingSampler(\n", + " cuts, \n", + " constraint=MultimodalFixedBucketBatchSizeConstraint2D(\n", + " max_seq_len_buckets=[32, 64, 128, 256, 512, 1024, 1536, 2048],\n", + " batch_sizes=[8, 7, 6, 5, 4, 3, 2, 1],\n", + " token_equivalent_duration=0.08, \n", + " measure_total_length=True,\n", + " ),\n", + " buffer_size=10,\n", + ")\n", + "\n", + "batch = next(iter(sampler))\n", + "assert len(batch) == 4 \n", + "# Our conversation example fell into bucket number 4 (min: 256, max: 512) with an assigned batch size of 4" + ] + }, + { + "cell_type": "markdown", + "id": "4ff5baae-0771-4ac9-aa68-c3faee5aa261", + "metadata": {}, + "source": [ + "## Putting it all together to configure joint audio, text, and conversation dataloading\n", + "\n", + "We'll showcase some higher level APIs here. First, we'll create data examples on disk for three distinct types: audio to text, text to text, and multimodal conversations." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "5a0e5433-3e63-4ab2-9290-001159a9b8e0", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "from lhotse.serialization import save_to_jsonl\n", + "from lhotse.testing.dummies import dummy_recording\n", + "\n", + "# Prepare dummy ASR data\n", + "d = Path(\"_tutorial_data\")\n", + "!mkdir -p {d}/asr_shar\n", + "cut = dummy_recording(0, duration=17.11, with_data=True).to_cut()\n", + "cut.supervisions = [SupervisionSegment(id=cut.id, recording_id=cut.id, start=0.0, duration=cut.duration, text=\"Welcome to Lhotse!\")]\n", + "cut.context = \"Repeat after me\"\n", + "CutSet([cut.save_audio(d / \"rec.flac\")]).to_shar(d / \"asr_shar\", fields={\"recording\": \"flac\"})\n", + "\n", + "# Prepare dummy translation data\n", + "(d / \"src.txt\").write_text(\"A\")\n", + "(d / \"tgt.txt\").write_text(\"B\")\n", + "\n", + "# Prepare dummy multimodal conversation\n", + "save_to_jsonl(\n", + " [\n", + " {\n", + " \"id\": \"convo-1\",\n", + " \"conversations\": [\n", + " {\"from\": \"user\", \"value\": \"tell me what you hear\", \"type\": \"text\"},\n", + " {\"from\": \"user\", \"value\": str(d / \"rec.flac\"), \"duration\": cut.duration, \"type\": \"audio\"},\n", + " {\"from\": \"assistant\", \"value\": \"somebody just welcomed me to a himalayan mountain\", \"type\": \"text\"},\n", + " ]\n", + " }\n", + " ],\n", + " d / \"conv.jsonl\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3a4d669b-f816-4522-a491-ba31bfbf689c", + "metadata": {}, + "source": [ + "Now we'll configure a Lhotse dataloader to yield mini-batches with different data types in a round-robin fashion." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c4a7364e-c00f-4f60-9d72-9e7d228121cb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-10-18 14:12:19 dataloader:481] Creating a Lhotse DynamicBucketingSampler (max_batch_duration=None max_batch_size=None)\n", + "[NeMo I 2024-10-18 14:12:19 dataloader:481] Creating a Lhotse DynamicBucketingSampler (max_batch_duration=None max_batch_size=None)\n", + "[NeMo I 2024-10-18 14:12:19 dataloader:481] Creating a Lhotse DynamicBucketingSampler (max_batch_duration=None max_batch_size=None)\n" + ] + } + ], + "source": [ + "import torch\n", + "from omegaconf import OmegaConf\n", + "from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config\n", + "\n", + "# This configuration is typically present in NeMo training configs under `model.train_ds` key.\n", + "cfg = OmegaConf.create({\n", + " # Note that we have several sampler groups under keys: \"asr\", \"nmt\", and \"chat\".\n", + " # Each group has its own data source and sampling settings, i.e., you can define\n", + " # completely different batch sizes, sequence length filters, etc. for each type of data.\n", + " # To enable this behaviour, set multi_config to True.\n", + " \"multi_config\": True,\n", + " \n", + " # The following fields are shared by all groups.\n", + " # sampler_fusion key determines how to yield batches from different samplers:\n", + " # * \"round_robin\" will just yield one type at a time\n", + " # * \"zip\" will sample a batch for each type and concatenate them, yielding a larger multimodal batch\n", + " # * \"randomized_round_robin\" expects an extra \"sampler_weights\" option which will define sampling probs for each group.:\n", + " \"sampler_fusion\": \"round_robin\",\n", + " \"shuffle\": True,\n", + " \"num_workers\": 0,\n", + " \"seed\": 0,\n", + " \"shard_seed\": \"trng\",\n", + " \n", + " \"asr\": {\n", + " \"input_cfg\": [\n", + " {\n", + " \"type\": \"lhotse_shar\", \n", + " \"shar_path\": d / \"asr_shar\"\n", + " }\n", + " ],\n", + " \"min_duration\": 0.5,\n", + " \"max_duration\": 40,\n", + " \"use_bucketing\": True,\n", + " \"bucket_duration_bins\": [5, 10, 20, 40],\n", + " \"bucket_batch_size\": [4, 3, 2, 1],\n", + " \"prompt_format\": \"llama2\",\n", + "\n", + " # Simplified settings for quick tutorial running (don't use those in real applciations).\n", + " \"concurrent_bucketing\": False,\n", + " \"bucket_buffer_size\": 50,\n", + " \"shuffle_buffer_size\": 50,\n", + " },\n", + "\n", + " \"nmt\": {\n", + " \"input_cfg\": [\n", + " {\n", + " \"type\": \"txt_pair\", \n", + " \"source_paths\": d / \"src.txt\", \n", + " \"target_paths\": d / \"tgt.txt\"\n", + " }\n", + " ],\n", + " \"use_multimodal_sampling\": True, # will count tokens instead of seconds\n", + " \"min_tokens\": 1,\n", + " \"max_tokens\": 32,\n", + " \"measure_total_length\": False, # filters by input length instead of total length\n", + " \"use_bucketing\": True,\n", + " \"bucket_duration_bins\": [[16, 16], [16, 32], [32, 16], [32, 32]], # 2D buckets\n", + " \"bucket_batch_size\": [4, 3, 2, 1],\n", + " \"prompt_format\": \"llama2\",\n", + " \n", + " # Simplified settings for quick tutorial running (don't use those in real applciations).\n", + " \"concurrent_bucketing\": False,\n", + " \"bucket_buffer_size\": 50,\n", + " \"shuffle_buffer_size\": 50,\n", + " },\n", + "\n", + " \"chat\": {\n", + " \"input_cfg\": [\n", + " {\n", + " \"type\": \"multimodal_conversation\", \n", + " \"manifest_filepath\": d / \"conv.jsonl\", \n", + " \"audio_locator_tag\": \"[audio]\"\n", + " }\n", + " ],\n", + " \"use_multimodal_sampling\": True, # will count tokens instead of seconds\n", + " \"min_tokens\": 1,\n", + " \"max_tokens\": 1024,\n", + " \"measure_total_length\": True,\n", + " \"token_equivalent_duration\": 0.08,\n", + " \"use_bucketing\": True,\n", + " \"bucket_duration_bins\": [128, 256, 512, 1024],\n", + " \"bucket_batch_size\": [4, 3, 2, 1],\n", + " \"prompt_format\": \"llama2\",\n", + "\n", + " # Simplified settings for quick tutorial running (don't use those in real applciations).\n", + " \"concurrent_bucketing\": False,\n", + " \"bucket_buffer_size\": 50,\n", + " \"shuffle_buffer_size\": 50,\n", + " },\n", + "})\n", + "\n", + "\n", + "# A no-op PyTorch Dataset class that will just return the data structures.\n", + "# In a real training setup, you'll want to implement conversion of a list of examples to a tensor mini-batch\n", + "# that is adequate for your model. \n", + "# Note that you can handle multiple types of examples to create appropriate mini-batch schema for each.\n", + "class Identity(torch.utils.data.Dataset):\n", + " def __getitem__(self, examples: CutSet):\n", + " return examples\n", + "\n", + "dloader = get_lhotse_dataloader_from_config(cfg, global_rank=0, world_size=1, dataset=Identity(), tokenizer=tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e8768e28-663b-4d69-bb31-fbd6b80c0389", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 0. Examples:\n", + "\t* MonoCut(id='dummy-recording-0000_repeat10', start=0, duration=17.11, channel=0, supervisions=[SupervisionSegment(id='dummy-recording-0000', recording_id='dummy-recording-0000', start=0.0, duration=17.11, channel=0, text='Welcome to Lhotse!', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='rec', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom={'context': 'Repeat after me', 'shard_origin': PosixPath('_tutorial_data/asr_shar/cuts.000000.jsonl.gz'), 'shar_epoch': 10, 'input_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5, 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88,\n", + " 93, 92, 78, 10, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5]), 'answer_ids': tensor([ 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88, 93, 92, 78,\n", + " 10, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* MonoCut(id='dummy-recording-0000_repeat41', start=0, duration=17.11, channel=0, supervisions=[SupervisionSegment(id='dummy-recording-0000', recording_id='dummy-recording-0000', start=0.0, duration=17.11, channel=0, text='Welcome to Lhotse!', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='rec', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom={'context': 'Repeat after me', 'shard_origin': PosixPath('_tutorial_data/asr_shar/cuts.000000.jsonl.gz'), 'shar_epoch': 41, 'input_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5, 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88,\n", + " 93, 92, 78, 10, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5]), 'answer_ids': tensor([ 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88, 93, 92, 78,\n", + " 10, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n", + "Step 1. Examples:\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n", + "Step 2. Examples:\n", + "\t* NeMoMultimodalConversation(id='convo-1_repeat0', turns=[TextTurn(value='tell me what you hear', role='user'), AudioTurn(cut=MonoCut(id='rec', start=0.0, duration=17.11, channel=0, supervisions=[], features=None, recording=Recording(id='rec', sources=[AudioSource(type='file', channels=[0], source='_tutorial_data/rec.flac')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom=None), role='user', audio_locator_tag='[audio]'), TextTurn(value='somebody just welcomed me to a himalayan mountain', role='assistant')], token_equivalent_duration=0.08, custom={'input_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5, 9, 92, 88, 86, 78, 75, 88,\n", + " 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85, 76, 88, 86, 78, 77, 9, 86,\n", + " 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74, 85, 74, 98, 74, 87, 9, 86,\n", + " 88, 94, 87, 93, 74, 82, 87, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5]), 'answer_ids': tensor([ 9, 92, 88, 86, 78, 75, 88, 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85,\n", + " 76, 88, 86, 78, 77, 9, 86, 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74,\n", + " 85, 74, 98, 74, 87, 9, 86, 88, 94, 87, 93, 74, 82, 87, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* NeMoMultimodalConversation(id='convo-1_repeat1', turns=[TextTurn(value='tell me what you hear', role='user'), AudioTurn(cut=MonoCut(id='rec', start=0.0, duration=17.11, channel=0, supervisions=[], features=None, recording=Recording(id='rec', sources=[AudioSource(type='file', channels=[0], source='_tutorial_data/rec.flac')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom=None), role='user', audio_locator_tag='[audio]'), TextTurn(value='somebody just welcomed me to a himalayan mountain', role='assistant')], token_equivalent_duration=0.08, custom={'input_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5, 9, 92, 88, 86, 78, 75, 88,\n", + " 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85, 76, 88, 86, 78, 77, 9, 86,\n", + " 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74, 85, 74, 98, 74, 87, 9, 86,\n", + " 88, 94, 87, 93, 74, 82, 87, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5]), 'answer_ids': tensor([ 9, 92, 88, 86, 78, 75, 88, 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85,\n", + " 76, 88, 86, 78, 77, 9, 86, 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74,\n", + " 85, 74, 98, 74, 87, 9, 86, 88, 94, 87, 93, 74, 82, 87, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* NeMoMultimodalConversation(id='convo-1_repeat2', turns=[TextTurn(value='tell me what you hear', role='user'), AudioTurn(cut=MonoCut(id='rec', start=0.0, duration=17.11, channel=0, supervisions=[], features=None, recording=Recording(id='rec', sources=[AudioSource(type='file', channels=[0], source='_tutorial_data/rec.flac')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom=None), role='user', audio_locator_tag='[audio]'), TextTurn(value='somebody just welcomed me to a himalayan mountain', role='assistant')], token_equivalent_duration=0.08, custom={'input_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5, 9, 92, 88, 86, 78, 75, 88,\n", + " 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85, 76, 88, 86, 78, 77, 9, 86,\n", + " 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74, 85, 74, 98, 74, 87, 9, 86,\n", + " 88, 94, 87, 93, 74, 82, 87, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 93, 78, 85, 85, 9, 86, 78, 9, 96, 81, 74, 93, 9, 98,\n", + " 88, 94, 9, 81, 78, 74, 91, 9, 8, 9, 5]), 'answer_ids': tensor([ 9, 92, 88, 86, 78, 75, 88, 77, 98, 9, 83, 94, 92, 93, 9, 96, 78, 85,\n", + " 76, 88, 86, 78, 77, 9, 86, 78, 9, 93, 88, 9, 74, 9, 81, 82, 86, 74,\n", + " 85, 74, 98, 74, 87, 9, 86, 88, 94, 87, 93, 74, 82, 87, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n", + "Step 3. Examples:\n", + "\t* MonoCut(id='dummy-recording-0000_repeat67', start=0, duration=17.11, channel=0, supervisions=[SupervisionSegment(id='dummy-recording-0000', recording_id='dummy-recording-0000', start=0.0, duration=17.11, channel=0, text='Welcome to Lhotse!', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='rec', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom={'context': 'Repeat after me', 'shard_origin': PosixPath('_tutorial_data/asr_shar/cuts.000000.jsonl.gz'), 'shar_epoch': 67, 'input_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5, 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88,\n", + " 93, 92, 78, 10, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5]), 'answer_ids': tensor([ 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88, 93, 92, 78,\n", + " 10, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* MonoCut(id='dummy-recording-0000_repeat16', start=0, duration=17.11, channel=0, supervisions=[SupervisionSegment(id='dummy-recording-0000', recording_id='dummy-recording-0000', start=0.0, duration=17.11, channel=0, text='Welcome to Lhotse!', language=None, speaker=None, gender=None, custom=None, alignment=None)], features=None, recording=Recording(id='rec', sources=[AudioSource(type='memory', channels=[0], source='')], sampling_rate=16000, num_samples=273760, duration=17.11, channel_ids=[0], transforms=None), custom={'context': 'Repeat after me', 'shard_origin': PosixPath('_tutorial_data/asr_shar/cuts.000000.jsonl.gz'), 'shar_epoch': 16, 'input_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5, 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88,\n", + " 93, 92, 78, 10, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 59, 78, 89, 78, 74, 93, 9, 74, 79, 93, 78, 91, 9, 86,\n", + " 78, 9, 5]), 'answer_ids': tensor([ 9, 64, 78, 85, 76, 88, 86, 78, 9, 93, 88, 9, 53, 81, 88, 93, 92, 78,\n", + " 10, 2]), 'mask': tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n", + "Step 4. Examples:\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\t* SourceTargetTextExample(source=TextExample(text='A', language=None, tokens=None, custom=None), target=TextExample(text='B', language=None, tokens=None, custom=None), question=None, custom={'input_ids': tensor([ 1, 9, 4, 9, 42, 9, 5, 9, 43, 2]), 'context_ids': tensor([ 1, 9, 4, 9, 42, 9, 5]), 'answer_ids': tensor([ 9, 43, 2]), 'mask': tensor([False, False, False, False, False, False, False, True, True, True]), 'dataloading_info': {'rank': 0, 'world_size': 1, 'worker_id': None}})\n", + "\n" + ] + } + ], + "source": [ + "for idx, batch in enumerate(dloader):\n", + " if idx == 5:\n", + " break\n", + " print(f\"Step {idx}. Examples:\")\n", + " for item in batch:\n", + " print(\"\\t*\", item)\n", + " print()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "704c44f5-bcce-4b4f-828b-fa1e18de8d71", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/multimodal/Prompt Formatter Tutorial.ipynb b/tutorials/multimodal/Prompt Formatter Tutorial.ipynb new file mode 100644 index 000000000000..85f220115e13 --- /dev/null +++ b/tutorials/multimodal/Prompt Formatter Tutorial.ipynb @@ -0,0 +1,458 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cd408a7a-d4b6-4f33-83d3-c607dbc5f580", + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "source": [ + "# Prompt Formatter Tutorial\n", + "\n", + "This tutorial introduces NeMo's PromptFormatter API available in module `nemo.collections.common.prompts`.\n", + "After finishing this tutorial you will be familiar with the existing prompt formatters, how to use them, and how to build your own.\n", + "\n", + "We cover the following topics:\n", + "\n", + "* Using existing prompt formatters with Llama2 as an example.\n", + "\n", + "* Defining your own prompt formatter.\n", + "\n", + "We also support applying prompt formatters for multimodal data and Lhotse-compatible data types. To learn more, see our other tutorial: [Multimodal Lhotse Dataloading](./Multimodal Lhotse Dataloading.ipynb)" + ] + }, + { + "cell_type": "markdown", + "id": "3f87f30c-79c0-41e8-b126-283ff5436465", + "metadata": {}, + "source": [ + "### Pre-requsite: building a dummy tokenizer\n", + "\n", + "We're going to need a tokenizer to work with prompt formatters - we'll just build a dummy one for the purpose of this tutorial." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e91ebef5-9a25-4eb1-8211-d0f5990f7c37", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/pzelasko/miniforge3/envs/nemo/lib/python3.10/site-packages/transformers/utils/generic.py:441: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.\n", + " _torch_pytree._register_pytree_node(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-10-23 11:26:41 sentencepiece_tokenizer:333] tokenizer model _tutorial_spt/tokenizer.model already exists\n" + ] + } + ], + "source": [ + "import string\n", + "import shlex\n", + "from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model\n", + "\n", + "!echo {shlex.quote(' '.join(string.printable))} > _tutorial_train_text.txt\n", + "\n", + "tok_path, vocab_path = create_spt_model(\n", + " data_file=\"_tutorial_train_text.txt\", \n", + " output_dir=\"_tutorial_spt\",\n", + " vocab_size=512, \n", + " sample_size=-1, \n", + " do_lower_case=False, \n", + " bos=True, \n", + " eos=True, \n", + " pad=True, \n", + " user_defined_symbols=[\"[INST]\", \"[/INST]\", \"<>\", \"<>\", \"[audio]\"]\n", + ")\n", + "\n", + "tokenizer = SentencePieceTokenizer(tok_path)\n", + "\n", + "def display(encoded_chat, with_mask=False):\n", + " \"\"\"Utility for printing prompt formatted chats.\"\"\"\n", + " for key, val in encoded_chat.items():\n", + " if key.endswith(\"_ids\"):\n", + " print(key, '--', tokenizer.ids_to_text(val), '\\n')\n", + " if key == \"mask\" and with_mask:\n", + " print(key, '--', val)" + ] + }, + { + "cell_type": "markdown", + "id": "4c5c6c88-c882-4305-8757-585fec3eab46", + "metadata": {}, + "source": [ + "## Using an existing PromptFormatter: Llama2\n", + "\n", + "\n", + "**Instanting the prompt formatter.** Let's start with a simple example of Llama2 prompt format use." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c77a993e-453f-474e-8912-fd35c7fc39ba", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.common.prompts.llama import Llama2PromptFormatter\n", + "from pprint import pprint\n", + "\n", + "prompt = Llama2PromptFormatter(tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "92054a0f-5b97-4178-94b8-a27e62acf97b", + "metadata": {}, + "source": [ + "**Chat example.** We'll define a multi-turn conversation between the user and assistant below:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c5eabe5e-4160-41d7-ad85-a4df596de38b", + "metadata": {}, + "outputs": [], + "source": [ + "chat = [\n", + " {\"role\": \"user\", \"slots\": {\"message\": \"Do you know something about electronics?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", + " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "eff61b98-c7be-4345-ac97-15573d1a9533", + "metadata": {}, + "source": [ + "**Prompt formatter outputs.** Now, we apply prompt formatter to that conversation to obtain four tensors useful for training:\n", + "* `context_ids` encode the whole dialog history up to the last response of the assistant;\n", + "* `answer_ids` encode the last response of the assistant;\n", + "* `input_ids` encode the full conversation;\n", + "* `mask` is a boolean training loss mask that's set to `True` for every token belonging to assistant's turns.\n", + "\n", + "Since the token IDs are meaningless, we'll apply reverse tokenizer for displaying the prompt formatted example." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a10216b3-2bbe-4a2f-8ca8-557c3b9056be", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n", + "\n", + "context_ids -- [INST] Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", + "\n", + "answer_ids -- In order to build your own audio amplifier, start with ... \n", + "\n", + "mask -- tensor([False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True])\n" + ] + } + ], + "source": [ + "encoded = prompt.encode_dialog(chat)\n", + "display(encoded, with_mask=True)" + ] + }, + { + "cell_type": "markdown", + "id": "e181618e-6df8-44b2-b986-15660133e486", + "metadata": {}, + "source": [ + "**System prompt.** We also support the system prompt. Since it affects the prompt format in a non-trivial way, it is defined as a separate role `\"system_and_user\"`, which has two slots `\"system\"` and `\"message\"`. We'll omit printing the mask for brevity." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2c3476a4-b301-4f35-9520-90d4b919363d", + "metadata": {}, + "outputs": [], + "source": [ + "chat_with_system = [\n", + " {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", + " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"In order to build your own audio amplifier, start with ...\"}},\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5c8c329d-f8b3-48cb-b664-baed0fcd90ab", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] In order to build your own audio amplifier, start with ... \n", + "\n", + "context_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", + "\n", + "answer_ids -- In order to build your own audio amplifier, start with ... \n", + "\n" + ] + } + ], + "source": [ + "encoded = prompt.encode_dialog(chat_with_system)\n", + "display(encoded)" + ] + }, + { + "cell_type": "markdown", + "id": "a453345a-6456-43ed-a663-0554c459fddb", + "metadata": {}, + "source": [ + "**Constructing inference-time prompts.** During inference, we don't know what's the last turn of the assistant - we only want to construct the ``context_ids`` tensor. In those cases, just omit the last assistant's turn. The prompt formatter will return the ``context_ids`` tensor (with ``input_ids`` alias for it too)." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4ede7100-9d28-4cf0-ab75-bfede9936218", + "metadata": {}, + "outputs": [], + "source": [ + "inference_chat = [\n", + " {\"role\": \"system_and_user\", \"slots\": {\"system\": \"You are a sales rep in an electronics store.\", \"message\": \"Do you know something about electronics?\"}},\n", + " {\"role\": \"assistant\", \"slots\": {\"message\": \"Sure, ask away.\"}},\n", + " {\"role\": \"user\", \"slots\": {\"message\": \"How to build my own audio amplifier?\"}},\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "61bf8e77-0630-4a84-bd30-ca4c27f8d898", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", + "\n", + "context_ids -- [INST] <> You are a sales rep in an electronics store. <> Do you know something about electronics? [/INST] Sure, ask away. [INST] How to build my own audio amplifier? [/INST] \n", + "\n" + ] + } + ], + "source": [ + "encoded = prompt.encode_dialog(inference_chat)\n", + "display(encoded)" + ] + }, + { + "cell_type": "markdown", + "id": "a334e00a-9530-4333-98de-5cb8fb08eb47", + "metadata": {}, + "source": [ + "### How is Llama2 PromptFormatter built\n", + "\n", + "`Llama2PromptFormatter` is a small class with prompt definition that inherits `PromptFormatter`, which implements the logic for applying prompt format and tokenization to multi-turn conversations. \n", + "\n", + "Let's take a look at `Llama2PromptFormatter` definition:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f29fbf2f-3caa-4b27-86ca-5012d9fc6ba5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class Llama2PromptFormatter(PromptFormatter):\n", + " \"\"\"\n", + " This template has been validated to provide identical tokenized results to the official code\n", + " in https://github.com/meta-llama/llama/blob/main/llama/generation.py\n", + " \"\"\"\n", + "\n", + " NAME = \"llama2\"\n", + " OUTPUT_ROLE = \"assistant\"\n", + " TEMPLATE = {\n", + " \"system_and_user\": {\n", + " \"template\": f\"{BOS_SLOT}[INST] <>\\n|system|\\n<>\\n\\n|message| [/INST]\",\n", + " \"slots\": {\n", + " \"system\": Modality.Text,\n", + " \"message\": Modality.Text,\n", + " },\n", + " },\n", + " \"user\": {\n", + " \"template\": f\"{BOS_SLOT}[INST] |message| [/INST]\",\n", + " \"slots\": {\n", + " \"message\": Modality.Text,\n", + " },\n", + " },\n", + " OUTPUT_ROLE: {\n", + " \"template\": f\"|message| {EOS_SLOT}\",\n", + " \"slots\": {\n", + " \"message\": Modality.Text,\n", + " },\n", + " },\n", + " }\n", + "\n" + ] + } + ], + "source": [ + "import inspect\n", + "print(inspect.getsource(Llama2PromptFormatter))" + ] + }, + { + "cell_type": "markdown", + "id": "b24e9310-b8ed-4e35-9dda-d24aa62cfb6a", + "metadata": {}, + "source": [ + "As you can see, the definition consist of the following key components:\n", + "* Derives `PromptFormatter` parent class.\n", + "* Specifies `NAME`, which is used for dynamic resolution of string to class via `cls = PromptFormatter.resolve(name)`.\n", + "* Specifies `OUTPUT_ROLE`, which is the name for the role with assistant's responses (typically `\"assistant\"`).\n", + "* Specifies `TEMPLATE` which defines the dialog structure and how user-provided values (slots) are applied to prompts. Notably:\n", + " * The slots are wrapped into pipe operators `\"|\"` in the prompt template definition, and substituted with user provided values before tokenization.\n", + " * `\"system_and_user`\" role has two slots, `\"system\"` and `\"message\"`, and a template that wraps them with Llama2 special tokens.\n", + " * We use `BOS_SLOT` and `EOS_SLOT` to insert sentencepiece tokenizer's `bos_id` and `eos_id` in the right places (remember that sentencepiece won't tokenize them from text, they need to be inserted programmatically).\n", + " * The slots have a type, currently supported types are `Modality.Text` and `Modality.TextLiteral(value1, value2, ...)` that allows to restrict the set of slots values." + ] + }, + { + "cell_type": "markdown", + "id": "8cbdca6c-6c0f-42a9-a4a7-b936684c6e12", + "metadata": {}, + "source": [ + "## Defining your own prompt formatter" + ] + }, + { + "cell_type": "markdown", + "id": "25a9b6d2-d004-4f7f-8b24-4fd6d4eae244", + "metadata": {}, + "source": [ + "Generally you can follow the definition of existing prompt formatters to define your own. \n", + "We have several prompt formats implemented for Llama, Gemma, Phi, etc. \n", + "\n", + "We'll define a custom simple prompt format that has no system prompt below as an illustration:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b69f6532-24d8-4419-b1da-42184c3d72de", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.common.prompts.formatter import PromptFormatter, Modality\n", + "\n", + "class MyPrompt(PromptFormatter):\n", + " NAME = \"myprompt\"\n", + " OUTPUT_ROLE = \"assistant\"\n", + " TEMPLATE = {\n", + " \"user\": {\n", + " \"template\": \"User: |message|\\n\",\n", + " \"slots\": {\"message\": Modality.Text},\n", + " },\n", + " \"assistant\": {\n", + " \"template\": \"Assistant: |message|\\n\",\n", + " \"slots\": {\"message\": Modality.Text},\n", + " },\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a97c6589-1303-446c-952f-d2b4007ca7e9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "input_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? Assistant: In order to build your own audio amplifier, start with ... \n", + "\n", + "context_ids -- User: Do you know something about electronics? Assistant: Sure, ask away. User: How to build my own audio amplifier? \n", + "\n", + "answer_ids -- Assistant: In order to build your own audio amplifier, start with ... \n", + "\n" + ] + } + ], + "source": [ + "my_prompt_cls = PromptFormatter.resolve(\"myprompt\") # it is auto-registered\n", + "my_prompt = my_prompt_cls(tokenizer)\n", + "display(my_prompt.encode_dialog(chat))" + ] + }, + { + "cell_type": "markdown", + "id": "30f9c96a-6cf8-4cd3-b0e8-6b461c86100f", + "metadata": {}, + "source": [ + "## Applying prompt formatter to multimodal data\n", + "\n", + "We refer the reader to our other tutorial, [Multimodal Lhotse Dataloading](./Multimodal Lhotse Dataloading.ipynb), where this is discussed in detail." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}