From 6051bad2f98e2ed80f55c311118dd8a1b8c27691 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Mon, 21 Jul 2025 07:06:24 -0700 Subject: [PATCH 01/21] Added ConvertToTarredAudioDataset processor Signed-off-by: Sasha Meister --- .../convert_to_tarred_audio_dataset.py | 160 +++ .../utils/convert_to_tarred_audio_dataset.py | 1021 +++++++++++++++++ .../utils/create_dali_tarred_dataset_index.py | 95 ++ 3 files changed, 1276 insertions(+) create mode 100644 sdp/processors/manage_files/convert_to_tarred_audio_dataset.py create mode 100644 sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py create mode 100644 sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py diff --git a/sdp/processors/manage_files/convert_to_tarred_audio_dataset.py b/sdp/processors/manage_files/convert_to_tarred_audio_dataset.py new file mode 100644 index 00000000..93f4a970 --- /dev/null +++ b/sdp/processors/manage_files/convert_to_tarred_audio_dataset.py @@ -0,0 +1,160 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +from dataclasses import dataclass +from typing import Optional +from copy import deepcopy +from tqdm import tqdm +import shutil + +from sdp.processors.base_processor import BaseProcessor +from sdp.processors.manage_files.utils.convert_to_tarred_audio_dataset import create_tar_datasets + + +@dataclass +class ConvertToTarredAudioDatasetConfig: + """ + Configuration class for ConvertToTarredAudioDataset. + + Attributes: + max_duration (float): Maximum allowed duration for audio samples. + min_duration (Optional[float]): Minimum allowed duration for audio samples. + concat_manifest_paths (Optional[str]): Path to a manifest file containing multiple manifest paths to concatenate. + target_dir (Optional[str]): Output directory to save tarred dataset. + metadata_path (Optional[str]): Path to write metadata about the tarred dataset. + num_shards (int): Number of shards to create. If -1, it will be determined automatically. + shuffle (bool): Whether to shuffle the input manifest before processing. + keep_files_together (bool): If True, all segments from the same source file are kept in the same shard. + sort_in_shards (bool): If True, samples inside each shard will be sorted by duration. + buckets_num (int): Number of duration-based buckets to split data into. + dynamic_buckets_num (int): Number of dynamic buckets for load balancing. + shuffle_seed (Optional[int]): Random seed used for shuffling. + write_metadata (bool): Whether to write metadata JSON files during processing. + no_shard_manifests (bool): If True, disables writing per-shard manifest files. + force_codec (Optional[str]): Audio codec to use when re-encoding audio files. + workers (int): Number of worker processes for parallel audio re-encoding. + slice_with_offset (bool): If True, audio slices will use offset and duration fields. + only_manifests (bool): If True, only manifests will be generated without audio re-encoding. 
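+
+    These options are passed through as keyword arguments to
+    ``ConvertToTarredAudioDataset``. A minimal illustrative example (paths and
+    values below are placeholders):
+
+    .. code-block:: python
+
+        processor = ConvertToTarredAudioDataset(
+            input_manifest_file="manifest.json",
+            output_manifest_file="tarred/manifest.json",
+            target_dir="tarred",
+            num_shards=8,
+            min_duration=1.0,
+            max_duration=40.0,
+            workers=-1,
+        )
+        processor.process()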
+ """ + max_duration: float + min_duration: Optional[float] = None + concat_manifest_paths: Optional[str] = None + target_dir: Optional[str] = None + metadata_path: Optional[str] = None + num_shards: int = -1 + shuffle: bool = False + keep_files_together: bool = False + sort_in_shards: bool = False + buckets_num: int = 1 + dynamic_buckets_num: int = 30 + shuffle_seed: Optional[int] = None + write_metadata: bool = False + no_shard_manifests: bool = False + force_codec: Optional[str] = None + workers: int = 1 + slice_with_offset: bool = False + only_manifests: bool = False + + +class ConvertToTarredAudioDataset(BaseProcessor): + """ + A processor for converting audio manifests into tarred audio datasets. + + This processor optionally splits data into duration-based buckets, and calls the + `create_tar_datasets` utility to convert and shard audio data into tar files, + with accompanying manifest files. + + Args: + output_manifest_file (str): Path to the final output manifest. + input_manifest_file (str): Path to the input manifest to be tarred. + **cfg_kwargs: Additional keyword arguments passed to the configuration dataclass. + + Returns: + Writes a tarred and sharded audio dataset to disk. + + - The dataset consists of multiple `.tar` archives with audio files. + - A final manifest (JSON lines format) is written to ``output_manifest_file``, + referencing each sample, its path inside the tar, and other metadata. + - If ``buckets_num > 1``, each sample will include an additional ``bucket_id`` field. + + .. note:: + If `buckets_num > 1`, the input manifest is split into multiple duration buckets, + and each bucket is processed independently. A `bucket_id` is added to each sample. + + You may need to install the extra dependencies of Lhotse and NeMo for this processor to work correctly: + ``pip install lhotse "nemo-toolkit[common]"`` + + """ + + def __init__( + self, + output_manifest_file: str, + input_manifest_file: str = None, + **cfg_kwargs, + ): + super().__init__( + input_manifest_file=input_manifest_file, + output_manifest_file=output_manifest_file + ) + self.cfg = ConvertToTarredAudioDatasetConfig(**cfg_kwargs) + + def process(self): + # If bucketing is enabled, divide the data based on duration ranges. 
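+        # For example (illustrative values): min_duration=1.0, max_duration=40.0 and
+        # buckets_num=3 produce the buckets [1.0, 14.0), [14.0, 27.0) and
+        # [27.0, 40.0 + 1e-5); the epsilon keeps samples lasting exactly max_duration.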
+ if self.cfg.buckets_num > 1: + with open(self.output_manifest_file, 'w', encoding='utf8') as fout: + bucket_length = (self.cfg.max_duration - self.cfg.min_duration) / float(self.cfg.buckets_num) + + for i_bucket in range(self.cfg.buckets_num): + # Create a config for the current bucket + bucket_config = deepcopy(self.cfg) + bucket_config.min_duration = self.cfg.min_duration + i_bucket * bucket_length + bucket_config.max_duration = bucket_config.min_duration + bucket_length + if i_bucket == self.cfg.buckets_num - 1: + # Ensure final bucket includes edge cases + bucket_config.max_duration += 1e-5 + + bucket_config.target_dir = os.path.join(self.cfg.target_dir, f"bucket{i_bucket+1}") + + print(f"Creating bucket {i_bucket+1} with min_duration={bucket_config.min_duration} and max_duration={bucket_config.max_duration} ...") + print(f"Results are being saved at: {bucket_config.target_dir}.") + + # Create tarred dataset for the current bucket + create_tar_datasets( + manifest_path=self.input_manifest_file, + **vars(bucket_config) + ) + + # Read and modify the output manifest from this bucket + bucket_manifest_path = os.path.join(bucket_config.target_dir, 'tarred_audio_manifest.json') + with open(bucket_manifest_path, 'r', encoding='utf8') as bin_f: + for line in tqdm(bin_f, desc="Writing output manifest.."): + entry = json.loads(line) + entry['bucket_id'] = i_bucket + line = json.dumps(entry) + fout.writelines(f'{line}\n') + + print(f"Bucket {i_bucket+1} is created.") + + else: + # No bucketing — create single tarred dataset + create_tar_datasets( + manifest_path=self.input_manifest_file, + **vars(self.cfg) + ) + + # Copy the generated manifest to the target location + tarred_audio_manifest = os.path.join(self.cfg.target_dir, 'tarred_audio_manifest.json') + shutil.copy(tarred_audio_manifest, self.output_manifest_file) \ No newline at end of file diff --git a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py new file mode 100644 index 00000000..c06462a9 --- /dev/null +++ b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py @@ -0,0 +1,1021 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +# This script converts an existing audio dataset with a manifest to +# a tarred and sharded audio dataset that can be read by the +# TarredAudioToTextDataLayer. + +# Please make sure your audio_filepath DOES NOT CONTAIN '-sub'! +# Because we will use it to handle files which have duplicate filenames but with different offsets +# (see function create_shard for details) + + +# Bucketing can help to improve the training speed. You may use --buckets_num to specify the number of buckets. +# It creates multiple tarred datasets, one per bucket, based on the audio durations. +# The range of [min_duration, max_duration) is split into equal sized buckets. 
+# Recommend to use --sort_in_shards to speedup the training by reducing the paddings in the batches +# More info on how to use bucketing feature: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/datasets.html + +# If valid NVIDIA DALI version is installed, will also generate the corresponding DALI index files that need to be +# supplied to the config in order to utilize webdataset for efficient large dataset handling. +# NOTE: DALI + Webdataset is NOT compatible with Bucketing support ! + +# Usage: +1) Creating a new tarfile dataset + +python convert_to_tarred_audio_dataset.py \ + --manifest_path= \ + --target_dir= \ + --num_shards= \ + --max_duration= \ + --min_duration= \ + --shuffle --shuffle_seed=1 \ + --sort_in_shards \ + --force_codec=flac \ + --workers=-1 + + +2) Concatenating more tarfiles to a pre-existing tarred dataset + +python convert_to_tarred_audio_dataset.py \ + --manifest_path= \ + --metadata_path= \ + --target_dir= \ + --max_duration= \ + --min_duration= \ + --shuffle --shuffle_seed=1 \ + --sort_in_shards \ + --workers=-1 \ + --concat_manifest_paths + + +3) Writing an empty metadata file + +python convert_to_tarred_audio_dataset.py \ + --target_dir= \ + # any other optional argument + --num_shards=8 \ + --max_duration=16.7 \ + --min_duration=0.01 \ + --shuffle \ + --workers=-1 \ + --sort_in_shards \ + --shuffle_seed=1 \ + --write_metadata + +""" +import argparse +import copy +import json +import os +import random +import tarfile +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime +from io import BytesIO +from typing import Any, List, Optional + +import numpy as np +import soundfile as sf +from joblib import Parallel, delayed +from omegaconf import DictConfig, OmegaConf, open_dict +from tabulate import tabulate +from tqdm import tqdm + +try: + import create_dali_tarred_dataset_index as dali_index + + DALI_INDEX_SCRIPT_AVAILABLE = True +except (ImportError, ModuleNotFoundError, FileNotFoundError): + DALI_INDEX_SCRIPT_AVAILABLE = False + + +@dataclass +class ASRTarredDatasetConfig: + num_shards: int = -1 + shuffle: bool = False + max_duration: Optional[float] = None + min_duration: Optional[float] = None + shuffle_seed: Optional[int] = None + sort_in_shards: bool = True + slice_with_offset: bool = True + shard_manifests: bool = True + keep_files_together: bool = False + force_codec: Optional[str] = None + use_lhotse: bool = False + use_bucketing: bool = False + num_buckets: Optional[int] = None + bucket_duration_bins: Optional[list[float]] = None + + +@dataclass +class ASRTarredDatasetMetadata: + created_datetime: Optional[str] = None + version: int = 0 + num_samples_per_shard: Optional[int] = None + is_concatenated_manifest: bool = False + + dataset_config: Optional[ASRTarredDatasetConfig] = field(default_factory=lambda: ASRTarredDatasetConfig()) + history: Optional[List[Any]] = field(default_factory=lambda: []) + + def __post_init__(self): + self.created_datetime = self.get_current_datetime() + + def get_current_datetime(self): + return datetime.now().strftime("%m-%d-%Y %H-%M-%S") + + @classmethod + def from_config(cls, config: DictConfig): + obj = cls() + obj.__dict__.update(**config) + return obj + + @classmethod + def from_file(cls, filepath: str): + config = OmegaConf.load(filepath) + return ASRTarredDatasetMetadata.from_config(config=config) + + +class ASRTarredDatasetBuilder: + """ + Helper class that constructs a tarred dataset from scratch, or concatenates tarred datasets + together and 
constructs manifests for them.
+    """
+
+    def __init__(self):
+        self.config = None
+
+    def configure(self, config: ASRTarredDatasetConfig):
+        """
+        Sets the config generated from command line overrides.
+
+        Args:
+            config: ASRTarredDatasetConfig dataclass object.
+        """
+        self.config = config  # type: ASRTarredDatasetConfig
+
+        if self.config.num_shards < 0:
+            raise ValueError("`num_shards` must be > 0. Please fill in the metadata information correctly.")
+
+    def create_new_dataset(
+        self,
+        manifest_path: str,
+        target_dir: str = "./tarred/",
+        num_workers: int = 0,
+        buckets_num: int = 1,
+        dynamic_buckets_num: int = 30,
+        only_manifests: bool = False,
+        dry_run: bool = False,
+    ):
+        """
+        Creates a new tarred dataset from a given manifest file.
+
+        Args:
+            manifest_path (str): Path to the original ASR manifest file.
+            target_dir (str, optional): Output directory where tarred files and manifests will be saved. Defaults to "./tarred/".
+            num_workers (int, optional): Number of parallel worker processes for writing tar files. Defaults to 0 (sequential processing).
+            buckets_num (int, optional): Number of buckets for static bucketing. Defaults to 1 (no bucketing).
+            dynamic_buckets_num (int, optional): Number of buckets to estimate for dynamic bucketing. Defaults to 30.
+            only_manifests (bool, optional): If True, writes manifests only and skips creating the actual tar files. Defaults to False.
+            dry_run (bool, optional): If True, only prints the expected shard statistics without writing any output. Defaults to False.
+
+        Raises:
+            ValueError: If the configuration has not been set.
+            FileNotFoundError: If the manifest file does not exist.
+
+        Output:
+            - Creates tar files and a tarred dataset compatible manifest file in the specified `target_dir`.
+            - Preserves a record of the metadata used to construct the tarred dataset in `metadata.yaml`.
+            - Optionally creates shard manifests if `config.shard_manifests` is enabled.
+
+        Notes:
+            - The function reads the manifest, applies filtering and shuffling if specified, and creates shards of tar files.
+            - It generates shard manifests and the main tarred dataset manifest.
+            - Metadata is updated and saved based on the tarred dataset configuration.
+        """
+        if self.config is None:
+            raise ValueError("Config has not been set. Please call `configure(config: ASRTarredDatasetConfig)`")
+
+        if manifest_path is None:
+            raise FileNotFoundError("Manifest filepath cannot be None !")
+
+        config = self.config  # type: ASRTarredDatasetConfig
+
+        if not os.path.exists(target_dir):
+            os.makedirs(target_dir)
+
+        # Read the existing manifest
+        entries, total_duration, filtered_entries, filtered_duration = self._read_manifest(manifest_path, config)
+
+        header = [
+            "Min.\nduration",
+            "Max.\nduration",
+            "Entries\nafter filtering",
+            "Total duration\nafter filtering",
+            "Shards\namount",
+            "Entries\nper shard",
+            "Remaining\nentries",
+        ]
+
+        entries_amount = f'{len(entries)} / {len(entries) + len(filtered_entries)}'
+        entries_duration = f'{total_duration:.2f} / {total_duration + filtered_duration:.2f} s'
+        entries_per_shard = len(entries) // config.num_shards
+        remainder = len(entries) % config.num_shards
+
+        data = [
+            [
+                f"{config.min_duration} s",
+                f"{config.max_duration} s",
+                f"{entries_amount}",
+                f"{entries_duration}",
+                f"{config.num_shards}",
+                f"{entries_per_shard}",
+                f"{remainder}",
+            ]
+        ]
+
+        print('\n' + tabulate(data, headers=header, tablefmt="grid", colalign=["center"] * len(header)))
+        if dry_run:
+            return
+
+        if len(entries) == 0:
+            print("No tarred dataset was created as there were 0 valid samples after filtering!")
+            return
+        if config.shuffle:
+            random.seed(config.shuffle_seed)
+            print(f"Shuffling (seed: {config.shuffle_seed})...")
+            if config.keep_files_together:
+                filename_entries = defaultdict(list)
+                for ent in entries:
+                    filename_entries[ent["audio_filepath"]].append(ent)
+                filenames = list(filename_entries.keys())
+                random.shuffle(filenames)
+                shuffled_entries = []
+                for filename in filenames:
+                    shuffled_entries += filename_entries[filename]
+                entries = shuffled_entries
+            else:
+                random.shuffle(entries)
+
+        start_indices = []
+        end_indices = []
+        # Build indices
+        for i in range(config.num_shards):
+            start_idx = (len(entries) // config.num_shards) * i
+            end_idx = start_idx + (len(entries) // config.num_shards)
+            print(f"Shard {i} has entries {start_idx} ~ {end_idx}")
+            files = set()
+            for ent_id in range(start_idx, end_idx):
+                files.add(entries[ent_id]["audio_filepath"])
+            print(f"Shard {i} contains {len(files)} files")
+            if i == config.num_shards - 1:
+                # We discard in order to have the same number of entries per shard.
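+                # Example (illustrative): 103 entries across 8 shards gives
+                # 103 // 8 = 12 entries per shard, so the trailing
+                # 103 - 8 * 12 = 7 entries are dropped.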
+ print(f"Have {len(entries) - end_idx} entries left over that will be discarded.") + + start_indices.append(start_idx) + end_indices.append(end_idx) + + manifest_folder, _ = os.path.split(manifest_path) + + with Parallel(n_jobs=num_workers, verbose=config.num_shards) as parallel: + # Call parallel tarfile construction + new_entries_list = parallel( + delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder, only_manifests) + for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices)) + ) + + if config.shard_manifests: + sharded_manifests_dir = target_dir + '/sharded_manifests' + if not os.path.exists(sharded_manifests_dir): + os.makedirs(sharded_manifests_dir) + + for manifest in new_entries_list: + shard_id = manifest[0]['shard_id'] + new_manifest_shard_path = os.path.join(sharded_manifests_dir, f'manifest_{shard_id}.json') + with open(new_manifest_shard_path, 'w', encoding='utf-8') as m2: + for entry in manifest: + json.dump(entry, m2, ensure_ascii=False) + m2.write('\n') + + # Flatten the list of list of entries to a list of entries + new_entries = [sample for manifest in new_entries_list for sample in manifest] + del new_entries_list + + print("Total number of entries in manifest :", len(new_entries)) + + # Write manifest + new_manifest_path = os.path.join(target_dir, 'tarred_audio_manifest.json') + with open(new_manifest_path, 'w', encoding='utf-8') as m2: + for entry in new_entries: + json.dump(entry, m2, ensure_ascii=False) + m2.write('\n') + + # Write metadata (default metadata for new datasets) + new_metadata_path = os.path.join(target_dir, 'metadata.yaml') + metadata = ASRTarredDatasetMetadata() + + # Update metadata + metadata.dataset_config = config + metadata.num_samples_per_shard = len(new_entries) // config.num_shards + + if buckets_num <= 1: + # Estimate and update dynamic bucketing args + bucketing_kwargs = self.estimate_dynamic_bucketing_duration_bins( + new_manifest_path, num_buckets=dynamic_buckets_num + ) + for k, v in bucketing_kwargs.items(): + setattr(metadata.dataset_config, k, v) + + # Write metadata + metadata_yaml = OmegaConf.structured(metadata) + OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) + + def estimate_dynamic_bucketing_duration_bins(self, manifest_path: str, num_buckets: int = 30) -> dict: + from lhotse import CutSet + from lhotse.dataset.sampling.dynamic_bucketing import estimate_duration_buckets + + from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator + + cuts = CutSet(LazyNeMoIterator(manifest_path, metadata_only=True)) + bins = estimate_duration_buckets(cuts, num_buckets=num_buckets) + print( + f"Note: we estimated the optimal bucketing duration bins for {num_buckets} buckets. 
" + "You can enable dynamic bucketing by setting the following options in your training script:\n" + " use_lhotse=true\n" + " use_bucketing=true\n" + f" num_buckets={num_buckets}\n" + f" bucket_duration_bins=[{','.join(map(str, bins))}]\n" + " batch_duration=\n" + "If you'd like to use a different number of buckets, re-estimate this option manually using " + "scripts/speech_recognition/estimate_duration_bins.py" + ) + return dict( + use_lhotse=True, + use_bucketing=True, + num_buckets=num_buckets, + bucket_duration_bins=list(map(float, bins)), # np.float -> float for YAML serialization + ) + + def create_concatenated_dataset( + self, + base_manifest_path: str, + manifest_paths: List[str], + metadata: ASRTarredDatasetMetadata, + target_dir: str = "./tarred_concatenated/", + num_workers: int = 1, + only_manifests: bool = False, + dry_run: bool = False, + ): + """ + Creates a concatenated tarred dataset from the base manifest and additional manifest files. + + Args: + base_manifest_path (str): Path to the base manifest file that contains information for the original + tarred dataset (with flattened paths). + manifest_paths (List[str]): List of paths to additional manifest files that will be concatenated with + the base tarred dataset. + metadata (ASRTarredDatasetMetadata): Metadata instance containing configuration and overrides. + target_dir (str, optional): Output directory where tarred files and manifests will be saved. Defaults to "./tarred_concatenated/". + num_workers (int, optional): Number of parallel worker processes for creating tar files. Defaults to 1. + only_manifests (bool, optional): If True, performs a dry run without creating actual tar files. Defaults to False. + + Raises: + FileNotFoundError: If the base manifest file or any of the additional manifest files does not exist. + + Output: + - Creates tar files and a concatenated tarred dataset compatible manifest file in the specified `target_dir`. + - Updates metadata to reflect the concatenated dataset, including the version and historical data. + + Notes: + - The function reads the base manifest and additional manifests, filters and shuffles entries as needed, + and creates new shards of tar files. + - It generates a new concatenated dataset manifest and updates metadata with versioning and historical context. + - If `metadata` is provided, the function updates its version and includes historical data in the new metadata. 
+        """
+        if not os.path.exists(target_dir):
+            os.makedirs(target_dir)
+
+        if base_manifest_path is None:
+            raise FileNotFoundError("Base manifest filepath cannot be None !")
+
+        if manifest_paths is None or len(manifest_paths) == 0:
+            raise FileNotFoundError("List of additional manifest filepaths cannot be None !")
+
+        config = ASRTarredDatasetConfig(**(metadata.dataset_config))
+
+        # Read the existing manifest (no filtering here)
+        base_entries, _, _, _ = self._read_manifest(base_manifest_path, config)
+        print(f"Read base manifest containing {len(base_entries)} samples.")
+
+        # Precompute number of samples per shard
+        if metadata.num_samples_per_shard is None:
+            num_samples_per_shard = len(base_entries) // config.num_shards
+        else:
+            num_samples_per_shard = metadata.num_samples_per_shard
+
+        print("Number of samples per shard :", num_samples_per_shard)
+
+        # Report the duration limits in effect (taken from the metadata config)
+        print(f"Selected max duration : {config.max_duration}")
+        print(f"Selected min duration : {config.min_duration}")
+
+        entries = []
+        for new_manifest_idx in range(len(manifest_paths)):
+            new_entries, total_duration, filtered_new_entries, filtered_duration = self._read_manifest(
+                manifest_paths[new_manifest_idx], config
+            )
+
+            if len(filtered_new_entries) > 0:
+                print(
+                    f"Filtered {len(filtered_new_entries)} files which amounts to {filtered_duration:0.2f}"
+                    f" seconds of audio from manifest {manifest_paths[new_manifest_idx]}."
+                )
+            print(
+                f"After filtering, manifest has {len(new_entries)} files which amounts to {total_duration} seconds of audio."
+            )
+
+            entries.extend(new_entries)
+
+        if len(entries) == 0:
+            print("No tarred dataset was created as there were 0 valid samples after filtering!")
+            return
+
+        if config.shuffle:
+            random.seed(config.shuffle_seed)
+            print(f"Shuffling (seed: {config.shuffle_seed})...")
+            random.shuffle(entries)
+
+        # Drop the trailing samples that cannot fill a complete shard
+        drop_count = len(entries) % num_samples_per_shard
+        total_new_entries = len(entries)
+        if drop_count > 0:
+            # Guard the slice: entries[:-0] would empty the list when drop_count is 0.
+            entries = entries[:-drop_count]
+
+        print(
+            f"Dropping {drop_count} samples from total new samples {total_new_entries} since they cannot "
+            f"be added into a uniformly sized chunk."
+ ) + + # Create shards and updated manifest entries + num_added_shards = len(entries) // num_samples_per_shard + + print(f"Number of samples in base dataset : {len(base_entries)}") + print(f"Number of samples in additional datasets : {len(entries)}") + print(f"Number of added shards : {num_added_shards}") + print(f"Remainder: {len(entries) % num_samples_per_shard}") + + if dry_run: + return + + start_indices = [] + end_indices = [] + shard_indices = [] + for i in range(num_added_shards): + start_idx = (len(entries) // num_added_shards) * i + end_idx = start_idx + (len(entries) // num_added_shards) + shard_idx = i + config.num_shards + print(f"Shard {shard_idx} has entries {start_idx + len(base_entries)} ~ {end_idx + len(base_entries)}") + + start_indices.append(start_idx) + end_indices.append(end_idx) + shard_indices.append(shard_idx) + + manifest_folder, _ = os.path.split(base_manifest_path) + + with Parallel(n_jobs=num_workers, verbose=num_added_shards) as parallel: + # Call parallel tarfile construction + new_entries_list = parallel( + delayed(self._create_shard)( + entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder, only_manifests + ) + for i, (start_idx, end_idx, shard_idx) in enumerate(zip(start_indices, end_indices, shard_indices)) + ) + + if config.shard_manifests: + sharded_manifests_dir = target_dir + '/sharded_manifests' + if not os.path.exists(sharded_manifests_dir): + os.makedirs(sharded_manifests_dir) + + for manifest in new_entries_list: + shard_id = manifest[0]['shard_id'] + new_manifest_shard_path = os.path.join(sharded_manifests_dir, f'manifest_{shard_id}.json') + with open(new_manifest_shard_path, 'w', encoding='utf-8') as m2: + for entry in manifest: + json.dump(entry, m2, ensure_ascii=False) + m2.write('\n') + + # Flatten the list of list of entries to a list of entries + new_entries = [sample for manifest in new_entries_list for sample in manifest] + del new_entries_list + + # Write manifest + if metadata is None: + new_version = 1 # start with `1`, where `0` indicates the base manifest + dataset + else: + new_version = metadata.version + 1 + + print("Total number of entries in manifest :", len(base_entries) + len(new_entries)) + + new_manifest_path = os.path.join(target_dir, f'tarred_audio_manifest_version_{new_version}.json') + with open(new_manifest_path, 'w', encoding='utf-8') as m2: + # First write all the entries of base manifest + for entry in base_entries: + json.dump(entry, m2, ensure_ascii=False) + m2.write('\n') + + # Finally write the new entries + for entry in new_entries: + json.dump(entry, m2, ensure_ascii=False) + m2.write('\n') + + # Preserve historical metadata + base_metadata = metadata + + # Write metadata (updated metadata for concatenated datasets) + new_metadata_path = os.path.join(target_dir, f'metadata_version_{new_version}.yaml') + metadata = ASRTarredDatasetMetadata() + + # Update config + config.num_shards = config.num_shards + num_added_shards + + # Update metadata + metadata.version = new_version + metadata.dataset_config = config + metadata.num_samples_per_shard = num_samples_per_shard + metadata.is_concatenated_manifest = True + metadata.created_datetime = metadata.get_current_datetime() + + # Attach history + current_metadata = OmegaConf.structured(base_metadata.history) + metadata.history = current_metadata + + # Write metadata + metadata_yaml = OmegaConf.structured(metadata) + OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) + + def _read_manifest(self, manifest_path: str, config: 
ASRTarredDatasetConfig):
+        """Reads and filters data from the manifest."""
+        # Read the existing manifest
+        entries = []
+        total_duration = 0.0
+        filtered_entries = []
+        filtered_duration = 0.0
+        with open(manifest_path, 'r', encoding='utf-8') as m:
+            for line in m:
+                entry = json.loads(line)
+                audio_key = "audio_filepath" if "audio_filepath" in entry else "audio_file"
+                if config.slice_with_offset and "offset" not in entry:
+                    raise KeyError(
+                        f"Manifest entry does not contain 'offset' field, but '--slice_with_offset' is enabled: {entry}"
+                    )
+                if audio_key not in entry:
+                    raise KeyError(f"Manifest entry does not contain 'audio_filepath' or 'audio_file' key: {entry}")
+                audio_filepath = entry[audio_key]
+                if not os.path.isfile(audio_filepath) and not os.path.isabs(audio_filepath):
+                    audio_filepath_abs = os.path.join(os.path.dirname(manifest_path), audio_filepath)
+                    if not os.path.isfile(audio_filepath_abs):
+                        raise FileNotFoundError(f"Could not find {audio_filepath} or {audio_filepath_abs}!")
+                    entry[audio_key] = audio_filepath_abs
+                if (config.max_duration is None or entry['duration'] < config.max_duration) and (
+                    config.min_duration is None or entry['duration'] >= config.min_duration
+                ):
+                    entries.append(entry)
+                    total_duration += entry["duration"]
+                else:
+                    filtered_entries.append(entry)
+                    filtered_duration += entry['duration']
+
+        return entries, total_duration, filtered_entries, filtered_duration
+
+    def _write_to_tar(
+        self, tar, audio_filepath: str, squashed_filename: str, duration: float = None, offset: float = 0
+    ) -> None:
+        codec = self.config.force_codec
+        to_transcode = not (codec is None or audio_filepath.endswith(f".{codec}"))
+        to_crop = not (duration is None and offset == 0)
+
+        if not to_crop and not to_transcode:
+            # Add existing file without transcoding, trimming, or re-encoding.
+            tar.add(audio_filepath, arcname=squashed_filename)
+            return
+
+        # Standard processing: read, trim, and transcode the audio file
+        with sf.SoundFile(audio_filepath) as f:
+            sampling_rate = f.samplerate
+
+        # Trim audio based on offset and duration.
+        start_sample = int(offset * sampling_rate)
+        num_frames = int(duration * sampling_rate) if duration else -1
+        audio, sampling_rate = sf.read(audio_filepath, start=start_sample, frames=num_frames)
+
+        # Determine codec parameters.
+        if codec is not None:
+            if codec == "opus":
+                kwargs = {"format": "ogg", "subtype": "opus"}
+            else:
+                kwargs = {"format": codec}
+        else:
+            codec = sf.info(audio_filepath).format.lower()
+            kwargs = {"format": codec}
+
+        # Transcode and write audio to tar.
+        encoded_audio = BytesIO()
+        sf.write(encoded_audio, audio, sampling_rate, closefd=False, **kwargs)
+
+        # Generate filename with the appropriate extension.
+        encoded_squashed_filename = f"{squashed_filename.split('.')[0]}.{codec}"
+
+        # Add the in-memory audio file to the tar archive.
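+        # tarfile.addfile() reads exactly `size` bytes from the supplied file object,
+        # so the TarInfo size must be set explicitly and the BytesIO cursor rewound
+        # to the start before the archive member is written.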
+ ti = tarfile.TarInfo(encoded_squashed_filename) + encoded_audio.seek(0) + ti.size = len(encoded_audio.getvalue()) + tar.addfile(ti, encoded_audio) + + def _create_shard(self, entries, target_dir, shard_id, manifest_folder: str = None, only_manifests: bool = False): + """Creates a tarball containing the audio files from `entries`.""" + if self.config.sort_in_shards: + entries.sort(key=lambda x: x["duration"], reverse=False) + + new_entries = [] + + tar_filepath = os.path.join(target_dir, f'audio_{shard_id}.tar') + if not only_manifests: + tar = tarfile.open(tar_filepath, mode='w', dereference=True) + + count = dict() + for entry in tqdm(entries, desc="Creating shard.."): + # We squash the filename since we do not preserve directory structure of audio files in the tarball. + if os.path.exists(entry["audio_filepath"]) or only_manifests: + audio_filepath = entry["audio_filepath"] + else: + if not manifest_folder: + raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!") + + audio_filepath = os.path.join(manifest_folder, entry["audio_filepath"]) + if not os.path.exists(audio_filepath): + raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!") + + base, ext = os.path.splitext(audio_filepath) + base = base.replace('/', '_') + # Need the following replacement as long as WebDataset splits on first period + base = base.replace('.', '_') + squashed_filename = f'{base}{ext}' + + if self.config.slice_with_offset: + if squashed_filename not in count: + count[squashed_filename] = 1 + + entry_offset = str(entry['offset']).split('.') + if len(entry_offset) == 1: + # Example: offset = 12 -> becomes 12_0 + entry_offset.append('0') + elif len(entry_offset) == 2: + # Example: offset = 12.34 -> becomes 12_34 + pass + else: + raise ValueError( + f"The offset for the entry with audio_filepath '{entry['audio_filepath']}' is incorrectly provided ({entry['offset']}). " + "Expected a float-like value (e.g., 12 or 12.34)." + ) + entry_offset = "_".join(entry_offset) + + entry_duration = str(entry['duration']).split('.') + if len(entry_duration) == 1: + entry_duration.append('0') + elif len(entry_duration) > 2: + raise ValueError( + f"The duration for the entry with audio_filepath '{entry['audio_filepath']}' is incorrectly provided ({entry['duration']})." 
+ ) + entry_duration = "_".join(entry_duration) + + to_write = base + "_" + entry_offset + "_" + entry_duration + ext + if not only_manifests: + self._write_to_tar( + tar, audio_filepath, to_write, duration=entry['duration'], offset=entry['offset'] + ) + count[squashed_filename] += 1 + + entry['source_audio_offset'] = entry['offset'] + del entry['offset'] + else: + if squashed_filename not in count: + if not only_manifests: + self._write_to_tar(tar, audio_filepath, squashed_filename) + to_write = squashed_filename + count[squashed_filename] = 1 + else: + to_write = base + "-sub" + str(count[squashed_filename]) + ext + count[squashed_filename] += 1 + + if only_manifests: + entry['abs_audio_filepath'] = audio_filepath + + # Carry over every key in the entry, override audio_filepath and shard_id + new_entry = { + **entry, + 'audio_filepath': to_write, + 'shard_id': shard_id, # Keep shard ID for recordkeeping + } + new_entries.append(new_entry) + + if not only_manifests: + tar.close() + return new_entries + + @classmethod + def setup_history(cls, base_metadata: ASRTarredDatasetMetadata, history: List[Any]): + if 'history' in base_metadata.keys(): + for history_val in base_metadata.history: + cls.setup_history(history_val, history) + + if base_metadata is not None: + metadata_copy = copy.deepcopy(base_metadata) + with open_dict(metadata_copy): + metadata_copy.pop('history', None) + history.append(metadata_copy) + + +def main(args): + if args.buckets_num > 1: + bucket_length = (args.max_duration - args.min_duration) / float(args.buckets_num) + for i_bucket in range(args.buckets_num): + bucket_config = copy.deepcopy(args) + bucket_config.min_duration = args.min_duration + i_bucket * bucket_length + bucket_config.max_duration = bucket_config.min_duration + bucket_length + if i_bucket == args.buckets_num - 1: + # add a small number to cover the samples with exactly duration of max_duration in the last bucket. + bucket_config.max_duration += 1e-5 + bucket_config.target_dir = os.path.join(args.target_dir, f"bucket{i_bucket+1}") + print( + f"Creating bucket {i_bucket+1} with min_duration={bucket_config.min_duration} and max_duration={bucket_config.max_duration} ..." 
+ ) + print(f"Results are being saved at: {bucket_config.target_dir}.") + create_tar_datasets(**vars(bucket_config)) + if not args.dry_run: + print(f"Bucket {i_bucket+1} is created.") + else: + create_tar_datasets(**vars(args)) + + +def create_tar_datasets( + manifest_path: str = None, + concat_manifest_paths: str = None, + target_dir: str = None, + metadata_path: str = None, + num_shards: int = -1, + max_duration: float = None, + min_duration: float = None, + shuffle: bool = False, + keep_files_together: bool = False, + sort_in_shards: bool = False, + buckets_num: int = 1, + dynamic_buckets_num: int = 30, + shuffle_seed: int = None, + write_metadata: bool = False, + no_shard_manifests: bool = False, + force_codec: str = None, + workers: int = 1, + slice_with_offset: bool = False, + only_manifests: bool = False, + dry_run: bool = False, +): + builder = ASRTarredDatasetBuilder() + + shard_manifests = False if no_shard_manifests else True + + if write_metadata: + metadata = ASRTarredDatasetMetadata() + dataset_cfg = ASRTarredDatasetConfig( + num_shards=num_shards, + shuffle=shuffle, + max_duration=max_duration, + min_duration=min_duration, + shuffle_seed=shuffle_seed, + sort_in_shards=sort_in_shards, + shard_manifests=shard_manifests, + keep_files_together=keep_files_together, + force_codec=force_codec, + slice_with_offset=slice_with_offset, + ) + metadata.dataset_config = dataset_cfg + + output_path = os.path.join(target_dir, 'default_metadata.yaml') + OmegaConf.save(metadata, output_path, resolve=True) + print(f"Default metadata written to {output_path}") + exit(0) + + if concat_manifest_paths is None or len(concat_manifest_paths) == 0: + # Create a tarred dataset from scratch + config = ASRTarredDatasetConfig( + num_shards=num_shards, + shuffle=shuffle, + max_duration=max_duration, + min_duration=min_duration, + shuffle_seed=shuffle_seed, + sort_in_shards=sort_in_shards, + shard_manifests=shard_manifests, + keep_files_together=keep_files_together, + force_codec=force_codec, + slice_with_offset=slice_with_offset, + ) + builder.configure(config) + builder.create_new_dataset( + manifest_path=manifest_path, + target_dir=target_dir, + num_workers=workers, + buckets_num=buckets_num, + dynamic_buckets_num=dynamic_buckets_num, + only_manifests=only_manifests, + dry_run=dry_run, + ) + + else: + if buckets_num > 1: + raise ValueError("Concatenation feature does not support buckets_num > 1.") + print("Concatenating multiple tarred datasets ...") + + # Implicitly update config from base details + if metadata_path is not None: + metadata = ASRTarredDatasetMetadata.from_file(metadata_path) + else: + raise ValueError("`metadata` yaml file path must be provided!") + + # Preserve history + history = [] + builder.setup_history(OmegaConf.structured(metadata), history) + metadata.history = history + + # Add command line overrides (everything other than num_shards) + metadata.dataset_config.max_duration = max_duration + metadata.dataset_config.min_duration = min_duration + metadata.dataset_config.shuffle = shuffle + metadata.dataset_config.shuffle_seed = shuffle_seed + metadata.dataset_config.sort_in_shards = sort_in_shards + metadata.dataset_config.shard_manifests = shard_manifests + + builder.configure(metadata.dataset_config) + + # Concatenate a tarred dataset onto a previous one + builder.create_concatenated_dataset( + base_manifest_path=manifest_path, + manifest_paths=concat_manifest_paths, + metadata=metadata, + target_dir=target_dir, + num_workers=workers, + slice_with_offset=slice_with_offset, + 
only_manifests=only_manifests, + dry_run=dry_run, + ) + + if not dry_run and (DALI_INDEX_SCRIPT_AVAILABLE and dali_index.INDEX_CREATOR_AVAILABLE): + print("Constructing DALI Tarfile Index - ", target_dir) + index_config = dali_index.DALITarredIndexConfig(tar_dir=target_dir, workers=workers) + dali_index.main(index_config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert an existing ASR dataset to tarballs compatible with TarredAudioToTextDataLayer." + ) + parser.add_argument( + "--manifest_path", default=None, type=str, required=False, help="Path to the existing dataset's manifest." + ) + + parser.add_argument( + '--concat_manifest_paths', + nargs='+', + default=None, + type=str, + required=False, + help="Path to the additional dataset's manifests that will be concatenated with base dataset.", + ) + + # Optional arguments + parser.add_argument( + "--target_dir", + default='./tarred', + type=str, + help="Target directory for resulting tarballs and manifest. Defaults to `./tarred`. Creates the path if necessary.", + ) + + parser.add_argument( + "--metadata_path", + required=False, + default=None, + type=str, + help="Path to metadata file for the dataset.", + ) + + parser.add_argument( + "--num_shards", + default=-1, + type=int, + help="Number of shards (tarballs) to create. Used for partitioning data among workers.", + ) + parser.add_argument( + '--max_duration', + default=None, + required=True, + type=float, + help='Maximum duration of audio clip in the dataset. By default, it is None and is required to be set.', + ) + parser.add_argument( + '--min_duration', + default=None, + type=float, + help='Minimum duration of audio clip in the dataset. By default, it is None and will not filter files.', + ) + parser.add_argument( + "--shuffle", + action='store_true', + help="Whether or not to randomly shuffle the samples in the manifest before tarring/sharding.", + ) + + parser.add_argument( + "--keep_files_together", + action='store_true', + help="Whether or not to keep entries from the same file (but different offsets) together when sorting before tarring/sharding.", + ) + parser.add_argument( + "--slice_with_offset", + action='store_true', + help=( + "If set, the audio will be sliced based on `offset` and `duration` fields from the manifest. " + "This is useful for creating datasets from audio segments instead of full files. " + "When unset, the entire audio file is used without slicing, regardless of the offset/duration values in the manifest." + ), + ) + parser.add_argument( + "--sort_in_shards", + action='store_true', + help="Whether or not to sort samples inside the shards based on their duration.", + ) + + parser.add_argument( + "--buckets_num", + type=int, + default=1, + help="Number of buckets to create based on duration.", + ) + + parser.add_argument( + "--dynamic_buckets_num", + type=int, + default=30, + help="Intended for dynamic (on-the-fly) bucketing; this option will not bucket your dataset during tar conversion. " + "Estimates optimal bucket duration bins for a given number of buckets.", + ) + + parser.add_argument("--shuffle_seed", type=int, default=None, help="Random seed for use if shuffling is enabled.") + parser.add_argument( + '--write_metadata', + action='store_true', + help=( + "Flag to write a blank metadata with the current call config. " + "Note that the metadata will not contain the number of shards, " + "and it must be filled out by the user." 
+ ), + ) + parser.add_argument( + "--no_shard_manifests", + action='store_true', + help="Do not write sharded manifests along with the aggregated manifest.", + ) + parser.add_argument( + "--force_codec", + type=str, + default=None, + help=( + "If specified, transcode the audio to the given format. " + "Supports libnsndfile formats (example values: 'opus', 'flac')." + ), + ) + parser.add_argument( + "--only_manifests", + action='store_true', + help=( + "If set, only creates manifests for each shard without creating the actual tar files. " + "This allows you to verify the output structure and content before committing to the full tarball creation process. " + "Each manifest entry will also include the field `abs_audio_filepath`, which stores the absolute path to the original audio file." + ), + ) + parser.add_argument( + "--dry_run", + action='store_true', + help=( + "Run in simulation mode: calculate and display the number of shards and estimated data per shard without reading audio files or writing any output." + ), + ) + parser.add_argument('--workers', type=int, default=1, help='Number of worker processes') + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py b/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py new file mode 100644 index 00000000..1ae64dc5 --- /dev/null +++ b/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import logging +import os +from dataclasses import dataclass + +import hydra +from hydra.core.config_store import ConfigStore +from joblib import Parallel, delayed +from omegaconf import MISSING + +try: + from wds2idx import IndexCreator + + INDEX_CREATOR_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + INDEX_CREATOR_AVAILABLE = False + +""" +python create_dali_tarred_dataset_index.py \ + tar_dir= \ + workers=-1 + +""" + +logging.basicConfig(level=logging.INFO) + + +@dataclass +class DALITarredIndexConfig: + tar_dir: str = MISSING # Path to the existing dataset's manifest + workers: int = -1 # number of worker processes + + +def process_index_path(tar_paths, index_dir): + """ + Appends the folder `{index_dir}` to the filepath of all tarfiles. 
+ Example: + /X/Y/Z/audio_0.tar -> /X/Y/Z/{index_dir}/audio_0.index + """ + index_paths = [] + for path in tar_paths: + basepath, filename = os.path.split(path) + path = filename.replace('.tar', '.index') + path = os.path.join(basepath, path) + base, name = os.path.split(path) + index_path = os.path.join(index_dir, name) + index_paths.append(index_path) + + return index_paths + + +def build_index(tarpath, indexfile): + with IndexCreator(tarpath, indexfile) as index: + index.create_index() + + +@hydra.main(config_path=None, config_name='index_config', version_base="1.1") +def main(cfg: DALITarredIndexConfig): + if not INDEX_CREATOR_AVAILABLE: + logging.error("`wds2idx` is not installed. Please install NVIDIA DALI >= 1.11") + exit(1) + + tar_files = list(glob.glob(os.path.join(cfg.tar_dir, "*.tar"))) + + index_dir = os.path.join(cfg.tar_dir, "dali_index") + if not os.path.exists(index_dir): + os.makedirs(index_dir, exist_ok=True) + + index_paths = process_index_path(tar_files, index_dir) + + with Parallel(n_jobs=cfg.workers, verbose=len(tar_files)) as parallel: + _ = parallel(delayed(build_index)(tarpath, indexfile) for tarpath, indexfile in zip(tar_files, index_paths)) + + logging.info("Finished constructing index files !") + + +ConfigStore.instance().store(name='index_config', node=DALITarredIndexConfig) + + +if __name__ == '__main__': + main() \ No newline at end of file From bcc6b7727c4ce3279d98bbe8717a12a2bbcb582e Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Mon, 21 Jul 2025 09:21:25 -0700 Subject: [PATCH 02/21] Docs and import setup added Signed-off-by: Sasha Meister --- docs/src/sdp/api.rst | 3 +++ sdp/processors/__init__.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/docs/src/sdp/api.rst b/docs/src/sdp/api.rst index 4eecb933..d3275b85 100644 --- a/docs/src/sdp/api.rst +++ b/docs/src/sdp/api.rst @@ -299,6 +299,9 @@ Files management .. autodata:: sdp.processors.RemoveFiles :annotation: +.. 
autodata:: sdp.processors.ConvertToTarredAudioDataset + :annotation: + Data filtering '''''''''''''' diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index 79f2205e..7a74f5de 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -144,6 +144,12 @@ from sdp.processors.manage_files.remove import ( RemoveFiles, ) +from sdp.processors.manage_files.remove import ( + RemoveFiles, +) +from sdp.processors.manage_files.convert_to_tarred_audio_dataset import ( + ConvertToTarredAudioDataset, +) from sdp.processors.nemo.asr_inference import ASRInference from sdp.processors.nemo.estimate_bandwidth import EstimateBandwidth from sdp.processors.nemo.pc_inference import PCInference From 62d8e9ee85ee5f33f36c5578819f5a5ae658d4c5 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Tue, 22 Jul 2025 07:25:00 -0700 Subject: [PATCH 03/21] Tests added Signed-off-by: Sasha Meister --- requirements/main.txt | 2 + tests/test_convert_to_tarred_audio_dataset.py | 217 ++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 tests/test_convert_to_tarred_audio_dataset.py diff --git a/requirements/main.txt b/requirements/main.txt index 99c030b4..7380636e 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -25,3 +25,5 @@ datasets>=2.14.0,<3.0.0 # for some processers, additionally https://github.com/NVIDIA/NeMo is required # for some processers, additionally nemo_text_processing is required # for mcv: apt-get update && apt-get upgrade -y && apt-get install -y sox libsox-fmt-all + +# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]" \ No newline at end of file diff --git a/tests/test_convert_to_tarred_audio_dataset.py b/tests/test_convert_to_tarred_audio_dataset.py new file mode 100644 index 00000000..2b864547 --- /dev/null +++ b/tests/test_convert_to_tarred_audio_dataset.py @@ -0,0 +1,217 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test module for the ConvertToTarredAudioDataset processor. + +This test validates the correctness of the audio sharding logic under different configurations. +It generates a synthetic dataset of random WAV files with varying durations, +then checks if the processor correctly shards, buckets, and outputs the expected manifest entries. + +Test optimization is achieved by using a pytest fixture to generate WAV files only once per test session. 
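+Run it directly with: `python -m pytest tests/test_convert_to_tarred_audio_dataset.py`.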
+""" + +import os +import tempfile +import shutil +import wave +import numpy as np +import json +from typing import List, Dict +import pytest + +from sdp.processors import ConvertToTarredAudioDataset + +EPSILON = 1e-7 +NUM_WORKERS = max(1, os.cpu_count() // 2) + +def generate_random_wav(audio_filepath: str, duration: float = 1.0, sample_rate: int = 16000): + """Generate a mono 16-bit WAV file with random data of specified duration.""" + num_samples = int(duration * sample_rate) + audio_data = np.random.randint(-32768, 32767, num_samples, dtype=np.int16) + + with wave.open(audio_filepath, 'w') as wf: + wf.setnchannels(1) + wf.setsampwidth(2) # 16-bit PCM + wf.setframerate(sample_rate) + wf.writeframes(audio_data.tobytes()) + + +def generate_wav_subset(output_dir: str, min_duration: float = 0.1, max_duration: float = 70.0, num_samples: int = 70, sr: int = 16000) -> List[Dict]: + """Generate a list of WAV files with increasing durations and return their metadata.""" + step = (max_duration - min_duration) / num_samples + durations = [min_duration + i * step for i in range(num_samples)] + + samples = [] + for i, duration in enumerate(durations): + sample_id = f'audio_{i}' + audio_filepath = os.path.join(output_dir, f'{sample_id}.wav') + generate_random_wav(audio_filepath, duration, sr) + samples.append({ + 'sample_id': sample_id, + 'audio_filepath': audio_filepath, + 'duration': duration + }) + return samples + + +def write_manifest(samples: List[Dict], output_manifest_filepath: str): + """Write a list of samples to a JSON lines manifest file.""" + with open(output_manifest_filepath, 'w', encoding='utf8') as manifest: + for sample in samples: + manifest.write(json.dumps(sample) + '\n') + + +def read_manifest(manifest_filepath: str) -> List[Dict]: + """Read a JSON lines manifest file and return the list of samples.""" + with open(manifest_filepath, 'r', encoding='utf8') as manifest: + return [json.loads(line) for line in manifest] + + +def strip_fields(samples: List[Dict], exclude_keys: List[str] = ['audio_filepath', 'abs_audio_filepath']) -> List[Dict]: + """Remove specified keys from sample dictionaries.""" + return [{k: v for k, v in d.items() if k not in exclude_keys} for d in samples] + +def get_expected_result( + samples: List[Dict], + num_shards=8, + buckets_num=1, + min_duration=1.0, + max_duration=40.0, + **kwargs +) -> List[Dict]: + """ + Generate expected manifest given the parameters, matching real processor behavior: + - sorted by duration + - filtered by min/max + - split into buckets + - distributed by filling shard 0 completely, then shard 1, etc. 
+ - discard leftover samples that don't fit evenly + """ + result = [] + EPSILON = 1e-7 + + def process_bucket(bucket_samples: List[Dict], bucket_idx: int) -> None: + bucket_samples = sorted(bucket_samples, key=lambda s: s['duration']) + total = len(bucket_samples) + per_shard = total // num_shards + usable = per_shard * num_shards + trimmed = bucket_samples[:usable] + + shard_size = usable // num_shards + for shard_id in range(num_shards): + start = shard_id * shard_size + end = start + shard_size + for s in trimmed[start:end]: + s = s.copy() + if buckets_num > 1: + # Only add bucket_id when bucketing is enabled (buckets_num > 1) + s['bucket_id'] = bucket_idx + s['shard_id'] = shard_id + result.append(s) + + # Fast path when no bucketing is requested + if buckets_num == 1: + filtered_samples = [ + s for s in samples + if min_duration <= s['duration'] < max_duration # strict upper bound (<) to match processor logic + ] + process_bucket(filtered_samples, 0) + return result + + step = (max_duration + EPSILON - min_duration) / buckets_num + + for i in range(buckets_num): + bucket_min = min_duration + i * step + bucket_max = bucket_min + step + + # Strict upper bound (<) for all but the last bucket. + upper = bucket_max + (EPSILON if i == buckets_num - 1 else 0.0) + bucket_samples = [ + s for s in samples + if bucket_min <= s['duration'] < upper + ] + process_bucket(bucket_samples, i) + + return result + +# 🔧 Pytest fixture that generates audio samples once per test session +@pytest.fixture(scope="session") +def prepared_samples(): + """ + Generate and cache a set of audio samples for all test runs. + Files are created in a temporary directory and deleted after the session. + """ + safe_dir = tempfile.mkdtemp() + samples = generate_wav_subset(safe_dir, min_duration=0.1, max_duration=70.0, num_samples=70, sr=16000) + + yield samples + + # Cleanup after session ends + shutil.rmtree(safe_dir) + + +# Configuration parameters to test different behaviors of the processor +test_configs = [ + dict(num_shards=8, min_duration=1.0, max_duration=40.0, workers=NUM_WORKERS), + dict(num_shards=4, buckets_num=2, min_duration=1.0, max_duration=40.0, workers=NUM_WORKERS), + dict(num_shards=8, buckets_num=1, min_duration=1.0, max_duration=40.0, only_manifests=True, workers=NUM_WORKERS), +] + +@pytest.mark.parametrize("cfg", test_configs) +def test_convert_to_tarred_audio_dataset(prepared_samples, cfg): + """ + Test ConvertToTarredAudioDataset with different sharding and bucketing configurations. + Checks both the manifest contents and the existence of expected output files. 
+ """ + with tempfile.TemporaryDirectory() as output_dir: + input_manifest = os.path.join(output_dir, 'input.json') + output_manifest = os.path.join(output_dir, 'output.json') + cfg['target_dir'] = os.path.join(output_dir, 'tarred_dataset') + cfg['sort_in_shards'] = True + + # Write manifest + write_manifest(prepared_samples, input_manifest) + + # Run processor + processor = ConvertToTarredAudioDataset( + input_manifest_file=input_manifest, + output_manifest_file=output_manifest, + **cfg + ) + processor.process() + + # Compare output manifest with expected values + output_samples = sorted(strip_fields(read_manifest(output_manifest)), key=lambda x: x['duration']) + expected_samples = sorted(strip_fields(get_expected_result(prepared_samples, **cfg)), key=lambda x: x['duration']) + assert output_samples == expected_samples + + # Check existence of tar and manifest files + base_dir = cfg['target_dir'] + bucket_dirs = ( + [os.path.join(base_dir, f"bucket{i+1}") for i in range(cfg.get('buckets_num', 1))] + if cfg.get('buckets_num', 1) > 1 + else [base_dir] + ) + + for b_dir in bucket_dirs: + for shard in range(cfg['num_shards']): + # Tar files should exist unless we run in `only_manifests` mode. + if not cfg.get('only_manifests', False): + assert os.path.exists(os.path.join(b_dir, f'audio_{shard}.tar')) + + # Per-shard manifests are written in `sharded_manifests/manifest_{shard}.json`, + # unless this feature is explicitly disabled via `no_shard_manifests`. + if not cfg.get('no_shard_manifests', False): + assert os.path.exists(os.path.join(b_dir, 'sharded_manifests', f'manifest_{shard}.json')) \ No newline at end of file From 9218835722cddcbdd8246de7ffcac383da140685 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Tue, 22 Jul 2025 07:48:18 -0700 Subject: [PATCH 04/21] Removed duplicated import Signed-off-by: Sasha Meister --- sdp/processors/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index ad2258b6..47e9aa15 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -146,9 +146,6 @@ from sdp.processors.manage_files.remove import ( RemoveFiles, ) -from sdp.processors.manage_files.remove import ( - RemoveFiles, -) from sdp.processors.manage_files.convert_to_tarred_audio_dataset import ( ConvertToTarredAudioDataset, ) From 9c705dc97c2697711535d43e3c7224a18392ab05 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Tue, 22 Jul 2025 08:02:54 -0700 Subject: [PATCH 05/21] Add tabulate to docs requirements Signed-off-by: Sasha Meister --- requirements/docs.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/docs.txt b/requirements/docs.txt index fd4d45ce..2f44117e 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -4,3 +4,4 @@ Sphinx sphinx-book-theme sphinx-copybutton sphinxext-opengraph +tabulate \ No newline at end of file From b05351846a717e2afa129db43f27b56f544ec242 Mon Sep 17 00:00:00 2001 From: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Date: Tue, 22 Jul 2025 18:50:40 +0200 Subject: [PATCH 06/21] Set env var PATH for test --- tests/test_convert_to_tarred_audio_dataset.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_convert_to_tarred_audio_dataset.py b/tests/test_convert_to_tarred_audio_dataset.py index 2b864547..93345995 100644 --- a/tests/test_convert_to_tarred_audio_dataset.py +++ b/tests/test_convert_to_tarred_audio_dataset.py @@ -33,6 +33,11 @@ from sdp.processors import ConvertToTarredAudioDataset +import os + 
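+# Guard for stripped-down environments (e.g. minimal CI shells) where PATH may be
+# unset; the audio tooling used by the tests can shell out to system binaries.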
+if 'PATH' not in os.environ: + os.environ['PATH'] = '/usr/bin:/bin' + EPSILON = 1e-7 NUM_WORKERS = max(1, os.cpu_count() // 2) @@ -214,4 +219,4 @@ def test_convert_to_tarred_audio_dataset(prepared_samples, cfg): # Per-shard manifests are written in `sharded_manifests/manifest_{shard}.json`, # unless this feature is explicitly disabled via `no_shard_manifests`. if not cfg.get('no_shard_manifests', False): - assert os.path.exists(os.path.join(b_dir, 'sharded_manifests', f'manifest_{shard}.json')) \ No newline at end of file + assert os.path.exists(os.path.join(b_dir, 'sharded_manifests', f'manifest_{shard}.json')) From d443d063c4927e66ec4a72b781d26f1d2aca9f1b Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 23 Jul 2025 03:53:23 -0700 Subject: [PATCH 07/21] set PATH inside ci Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8761604a..d9c444d0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,6 +85,7 @@ jobs: AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }} AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} CLEAN_UP_TMP_PATH: 1 + PATH: /usr/local/bin:/usr/bin:/bin run: | wget https://uit.stanford.edu/sites/default/files/2023/10/11/incommon-rsa-ca2.pem #downloading cert manually [for CORAL] From a039493946e3e73b2b2bfe1564b922882b05db94 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 23 Jul 2025 03:55:49 -0700 Subject: [PATCH 08/21] tmp change Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d9c444d0..9a101de3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -92,7 +92,7 @@ jobs: sudo cp incommon-rsa-ca2.pem /usr/local/share/ca-certificates/incommon-rsa-server-ca-2.crt # [cert for CORAL] sudo update-ca-certificates # [cert for CORAL] set -o pipefail # this will make sure next line returns non-0 exit code if tests fail - python -m pytest tests/ --junitxml=pytest.xml --ignore=tests/test_tts_sdp_end_to_end.py --cov-report=term-missing:skip-covered --cov=sdp --durations=30 -rs | tee pytest-coverage.txt + python -m pytest tests/convert_to_tarred_audio_dataset.py #--junitxml=pytest.xml --ignore=tests/test_tts_sdp_end_to_end.py --cov-report=term-missing:skip-covered --cov=sdp --durations=30 -rs | tee pytest-coverage.txt # TODO: add some way to see if e2e tests were skipped From 9a0df3f2d8d4cef76622c3e0bbd9629b30e8bd00 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 23 Jul 2025 04:01:32 -0700 Subject: [PATCH 09/21] tmp change Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9a101de3..409e5881 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -92,7 +92,7 @@ jobs: sudo cp incommon-rsa-ca2.pem /usr/local/share/ca-certificates/incommon-rsa-server-ca-2.crt # [cert for CORAL] sudo update-ca-certificates # [cert for CORAL] set -o pipefail # this will make sure next line returns non-0 exit code if tests fail - python -m pytest tests/convert_to_tarred_audio_dataset.py #--junitxml=pytest.xml --ignore=tests/test_tts_sdp_end_to_end.py --cov-report=term-missing:skip-covered --cov=sdp --durations=30 -rs | tee pytest-coverage.txt + python -m pytest tests/test_convert_to_tarred_audio_dataset.py 
#--junitxml=pytest.xml --ignore=tests/test_tts_sdp_end_to_end.py --cov-report=term-missing:skip-covered --cov=sdp --durations=30 -rs | tee pytest-coverage.txt # TODO: add some way to see if e2e tests were skipped From 1d61adf4acf41731bdcaeb10c5c39aa589e59aff Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 23 Jul 2025 05:08:05 -0700 Subject: [PATCH 10/21] tmp change Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 409e5881..111ab582 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -72,7 +72,7 @@ jobs: sudo apt-get install -y libsndfile1 ffmpeg sox libsox-fmt-mp3 pip install pytorch_lightning pip install Cython wheel # need to pre-install to avoid error in nemo installation - pip install nemo-toolkit[asr,nlp]==1.23.0 + pip install nemo-toolkit[asr,nlp]==2.3.2 pip install nemo_text_processing pip install -r requirements/huggingface.txt pip install certifi #this needed to avoid problems with certificates [COORAL] From fe2088b5f1584caaa4498055cfdd3b97e02dbcac Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 23 Jul 2025 05:14:11 -0700 Subject: [PATCH 11/21] tmp change Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 111ab582..520dd908 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,7 +85,7 @@ jobs: AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }} AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} CLEAN_UP_TMP_PATH: 1 - PATH: /usr/local/bin:/usr/bin:/bin +# PATH: /usr/local/bin:/usr/bin:/bin run: | wget https://uit.stanford.edu/sites/default/files/2023/10/11/incommon-rsa-ca2.pem #downloading cert manually [for CORAL] From ad968aa1e496b8754630e8e5153f04e37c349b1c Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 23 Jul 2025 05:28:15 -0700 Subject: [PATCH 12/21] Switch to NeMo 2.3.2 Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 3 +- requirements/main.txt | 4 +- .../utils/frame_vad_infer_postprocess.yaml | 2 +- .../asr/nemo/utils/speech_to_text_with_vad.py | 7 +- .../asr/nemo/utils/transcribe_speech.py | 277 ++++++++++++------ 5 files changed, 189 insertions(+), 104 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 520dd908..a0132d67 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,14 +85,13 @@ jobs: AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }} AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} CLEAN_UP_TMP_PATH: 1 -# PATH: /usr/local/bin:/usr/bin:/bin run: | wget https://uit.stanford.edu/sites/default/files/2023/10/11/incommon-rsa-ca2.pem #downloading cert manually [for CORAL] sudo cp incommon-rsa-ca2.pem /usr/local/share/ca-certificates/incommon-rsa-server-ca-2.crt # [cert for CORAL] sudo update-ca-certificates # [cert for CORAL] set -o pipefail # this will make sure next line returns non-0 exit code if tests fail - python -m pytest tests/test_convert_to_tarred_audio_dataset.py #--junitxml=pytest.xml --ignore=tests/test_tts_sdp_end_to_end.py --cov-report=term-missing:skip-covered --cov=sdp --durations=30 -rs | tee pytest-coverage.txt + python -m pytest tests/ --junitxml=pytest.xml --ignore=tests/test_tts_sdp_end_to_end.py --cov-report=term-missing:skip-covered --cov=sdp --durations=30 -rs | tee pytest-coverage.txt # TODO: add some way to see if e2e tests were 
skipped diff --git a/requirements/main.txt b/requirements/main.txt index cdbffac9..5d0a4c37 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -22,7 +22,7 @@ jiwer>=3.1.0,<4.0.0 pyarrow>=8.0.0,<14.0.0 datasets>=2.14.0,<3.0.0 # toloka-kit # Temporarily disabled due to Toloka's technical pause; keep as reference for past and future API support -# for some processers, additionally https://github.com/NVIDIA/NeMo is required +# for some processers, additionally https://github.com/NVIDIA/NeMo 2.3.2 is required # for some processers, additionally nemo_text_processing is required # for mcv: apt-get update && apt-get upgrade -y && apt-get install -y sox libsox-fmt-all @@ -30,4 +30,4 @@ datasets>=2.14.0,<3.0.0 # pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper # export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` # for vLLMInference processor is required: pip install "optree>=0.13.0" vllm -# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]" +# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.3.2" diff --git a/sdp/processors/inference/asr/nemo/utils/frame_vad_infer_postprocess.yaml b/sdp/processors/inference/asr/nemo/utils/frame_vad_infer_postprocess.yaml index 1d00eca6..81eec7b4 100644 --- a/sdp/processors/inference/asr/nemo/utils/frame_vad_infer_postprocess.yaml +++ b/sdp/processors/inference/asr/nemo/utils/frame_vad_infer_postprocess.yaml @@ -36,4 +36,4 @@ out_manifest_filepath: null # if not specify it will automatically be "manifest_ # json manifest line example -# {"audio_filepath": "/path/to/audio_file.wav", "offset": 0, "duration": 1.23, "label": "infer", "text": "-"} +# {"audio_filepath": "/path/to/audio_file.wav", "offset": 0, "duration": 1.23, "label": "infer", "text": "-"} \ No newline at end of file diff --git a/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py b/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py index 05c28cae..649155c5 100644 --- a/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py +++ b/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py @@ -57,9 +57,8 @@ import contextlib import json import os - import time -from dataclasses import dataclass, is_dataclass, field +from dataclasses import dataclass, field, is_dataclass from pathlib import Path from typing import Callable, Optional @@ -73,7 +72,7 @@ from nemo.collections.asr.data import feature_to_text_dataset from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import ASRModel, EncDecClassificationModel -from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules import CTCDecodingConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest from nemo.collections.asr.parts.utils.vad_utils import ( @@ -646,4 +645,4 @@ def run_asr_inference(manifest_filepath, cfg, record_fn) -> str: if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py b/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py index 
bb04047b..7ca29238 100644 --- a/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py +++ b/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,19 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This file is copied over from https://github.com/NVIDIA/NeMo/blob/v1.23.0/examples/asr/transcribe_speech.py. -# It is currently only compatible with NeMo v1.23.0. To use a different version of NeMo, please modify the file. - -import contextlib +import json import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass, field, is_dataclass from typing import List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl +import numpy as np import torch from omegaconf import OmegaConf, open_dict -from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecMultiTaskModel +from nemo.collections.asr.models import EncDecCTCModel, EncDecHybridRNNTCTCModel, EncDecRNNTModel +from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt from nemo.collections.asr.modules.conformer_encoder import ConformerChangeConfig from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig @@ -34,12 +33,13 @@ from nemo.collections.asr.parts.utils.transcribe_utils import ( compute_output_filename, prepare_audio_data, + restore_transcription_order, setup_model, - transcribe_partial_audio, write_transcription, ) from nemo.core.config import hydra_runner from nemo.utils import logging +from nemo.utils.timers import SimpleTimer """ Transcribe audio file on a single CPU/GPU. Useful for transcription of moderate amounts of audio data. 
@@ -48,21 +48,17 @@ model_path: path to .nemo ASR checkpoint pretrained_name: name of pretrained ASR model (from NGC registry) audio_dir: path to directory with audio files - dataset_manifest: path to dataset JSON manifest file (in NeMo format) - - compute_timestamps: Bool to request greedy time stamp information (if the model supports it) + dataset_manifest: path to dataset JSON manifest file (in NeMo formats compute_langs: Bool to request language ID information (if the model supports it) + timestamps: Bool to request greedy time stamp information (if the model supports it) by default None (Optionally: You can limit the type of timestamp computations using below overrides) - ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) - rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) - - (Optionally: You can limit the type of timestamp computations using below overrides) - ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) - rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) + ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word, segment]) + rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word, segment]) output_filename: Output filename where the transcriptions will be written batch_size: batch size during inference + presort_manifest: sorts the provided manifest by audio length for faster inference (default: True) cuda: Optional int to enable or disable execution of model on certain CUDA device. allow_mps: Bool to allow using MPS (Apple Silicon M-series GPU) device if available @@ -79,6 +75,8 @@ langid: Str used for convert_num_to_words during groundtruth cleaning use_cer: Bool to use Character Error Rate (CER) or Word Error Rate (WER) + calculate_rtfx: Bool to calculate the RTFx throughput to transcribe the input dataset. + # Usage ASR model can be specified by either "model_path" or "pretrained_name". Data for transcription can be defined with either "audio_dir" or "dataset_manifest". @@ -95,7 +93,7 @@ clean_groundtruth_text=True \ langid='en' \ batch_size=32 \ - compute_timestamps=False \ + timestamps=False \ compute_langs=False \ cuda=0 \ amp=True \ @@ -106,23 +104,30 @@ @dataclass class ModelChangeConfig: + """ + Sub-config for changes specific to the Conformer Encoder + """ - # Sub-config for changes specific to the Conformer Encoder - conformer: ConformerChangeConfig = ConformerChangeConfig() + conformer: ConformerChangeConfig = field(default_factory=ConformerChangeConfig) @dataclass class TranscriptionConfig: + """ + Transcription Configuration for audio to text transcription. 
+ """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model audio_dir: Optional[str] = None # Path to a directory which contains audio files dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest - channel_selector: Optional[ - Union[int, str] - ] = None # Used to select a single channel from multichannel audio, or use average across channels + channel_selector: Optional[Union[int, str]] = ( + None # Used to select a single channel from multichannel audio, or use average across channels + ) audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation + presort_manifest: bool = True # Significant inference speedup on short-form data due to padding reduction # General configs output_filename: Optional[str] = None @@ -132,10 +137,11 @@ class TranscriptionConfig: pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one. random_seed: Optional[int] = None # seed number going to be used in seed_everything() - # Set to True to output greedy timestamp information (only supported models) - compute_timestamps: bool = False - # set to True if need to return full alignment information - preserve_alignment: bool = False + # Set to True to output greedy timestamp information (only supported models) and returns full alignment hypotheses + timestamps: Optional[bool] = None + + # Set to True to return hypotheses instead of text from the transcribe function + return_hypotheses: bool = False # Set to True to output language ID information compute_langs: bool = False @@ -147,19 +153,33 @@ class TranscriptionConfig: allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) amp: bool = False amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp + compute_dtype: Optional[str] = ( + None # "float32", "bfloat16" or "float16"; if None (default): bfloat16 if available else float32 + ) + matmul_precision: str = "high" # Literal["highest", "high", "medium"] audio_type: str = "wav" # Recompute model transcription, even if the output folder exists with scores. overwrite_transcripts: bool = True # Decoding strategy for CTC models - ctc_decoding: CTCDecodingConfig = CTCDecodingConfig() + ctc_decoding: CTCDecodingConfig = field(default_factory=CTCDecodingConfig) # Decoding strategy for RNNT models - rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig(fused_batch_size=-1) + # enable CUDA graphs for transcription + rnnt_decoding: RNNTDecodingConfig = field(default_factory=lambda: RNNTDecodingConfig(fused_batch_size=-1)) # Decoding strategy for AED models - multitask_decoding: MultiTaskDecodingConfig = MultiTaskDecodingConfig() + multitask_decoding: MultiTaskDecodingConfig = field(default_factory=MultiTaskDecodingConfig) + # Prompt slots for prompted models, e.g. Canary-1B. 
Examples of acceptable prompt inputs: + # Implicit single-turn assuming default role='user' (works with Canary-1B) + # +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes + # Explicit single-turn prompt: + # +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es + # +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes + # Explicit multi-turn prompt: + # +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]' + prompt: dict = field(default_factory=dict) # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models decoder_type: Optional[str] = None @@ -167,7 +187,7 @@ class TranscriptionConfig: att_context_size: Optional[list] = None # Use this for model-specific changes before transcription - model_change: ModelChangeConfig = ModelChangeConfig() + model_change: ModelChangeConfig = field(default_factory=ModelChangeConfig) # Config for word / character error rate calculation calculate_wer: bool = True @@ -179,20 +199,22 @@ class TranscriptionConfig: # if True, will also skip writing anything to the output file return_transcriptions: bool = False - # Set to False to return text instead of hypotheses from the transcribe function, so as to save memory - return_hypotheses: bool = True - # key for groundtruth text in manifest gt_text_attr_name: str = "text" + gt_lang_attr_name: str = "lang" + + extract_nbest: bool = False # Extract n-best hypotheses from the model - # Use model's transcribe() function instead of transcribe_partial_audio() by default - # Only use transcribe_partial_audio() when the audio is too long to fit in memory - # Your manifest input should have `offset` field to use transcribe_partial_audio() - allow_partial_transcribe: bool = False + calculate_rtfx: bool = False + warmup_steps: int = 0 # by default - no warmup + run_steps: int = 1 # by default - single run @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]: + """ + Transcribes the input audio and can be used to infer with Encoder-Decoder models. 
+ """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') for key in cfg: @@ -217,6 +239,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis logging.info(f"Will apply on-the-fly augmentation on samples during transcription: {augmentor} ") # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) if cfg.cuda is None: if torch.cuda.is_available(): device = [0] # use 0th CUDA device @@ -247,11 +270,29 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis asr_model.set_trainer(trainer) asr_model = asr_model.eval() + if (cfg.compute_dtype is not None and cfg.compute_dtype != "float32") and cfg.amp: + raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32") + + amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + + compute_dtype: torch.dtype + if cfg.compute_dtype is None: + can_use_bfloat16 = (not cfg.amp) and map_location.type == "cuda" and torch.cuda.is_bf16_supported() + if can_use_bfloat16: + compute_dtype = torch.bfloat16 + else: + compute_dtype = torch.float32 + else: + assert cfg.compute_dtype in {"float32", "bfloat16", "float16"} + compute_dtype = getattr(torch, cfg.compute_dtype) + + asr_model.to(compute_dtype) + # we will adjust this flag if the model does not support it - compute_timestamps = cfg.compute_timestamps compute_langs = cfg.compute_langs - # has to be True if timestamps are required - preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment + + if cfg.timestamps: + cfg.return_hypotheses = True # Check whether model and decoder type match if isinstance(asr_model, EncDecCTCModel): @@ -260,7 +301,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis elif isinstance(asr_model, EncDecHybridRNNTCTCModel): if cfg.decoder_type and cfg.decoder_type not in ['ctc', 'rnnt']: raise ValueError('Hybrid model only support ctc or rnnt decoding!') - else: # rnnt model, there could be other models needs to be addressed. 
+ elif isinstance(asr_model, EncDecRNNTModel): if cfg.decoder_type and cfg.decoder_type != 'rnnt': raise ValueError('RNNT model only support rnnt decoding!') @@ -271,7 +312,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'): if isinstance(asr_model.decoding, MultiTaskDecoding): cfg.multitask_decoding.compute_langs = cfg.compute_langs - cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment + if cfg.extract_nbest: + cfg.multitask_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True asr_model.change_decoding_strategy(cfg.multitask_decoding) elif cfg.decoder_type is not None: # TODO: Support compute_langs in CTC eventually @@ -279,9 +322,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis raise ValueError("CTC models do not support `compute_langs` at the moment") decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding - decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it - if 'preserve_alignments' in decoding_cfg: - decoding_cfg.preserve_alignments = preserve_alignment + if cfg.extract_nbest: + decoding_cfg.beam.return_best_hypothesis = False + cfg.return_hypotheses = True if 'compute_langs' in decoding_cfg: decoding_cfg.compute_langs = cfg.compute_langs if hasattr(asr_model, 'cur_decoder'): @@ -291,17 +334,19 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis # Check if ctc or rnnt model elif hasattr(asr_model, 'joint'): # RNNT model + if cfg.extract_nbest: + cfg.rnnt_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True cfg.rnnt_decoding.fused_batch_size = -1 - cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps cfg.rnnt_decoding.compute_langs = cfg.compute_langs - if 'preserve_alignments' in cfg.rnnt_decoding: - cfg.rnnt_decoding.preserve_alignments = preserve_alignment asr_model.change_decoding_strategy(cfg.rnnt_decoding) else: if cfg.compute_langs: raise ValueError("CTC models do not support `compute_langs` at the moment.") - cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps + if cfg.extract_nbest: + cfg.ctc_decoding.beam.return_best_hypothesis = False + cfg.return_hypotheses = True asr_model.change_decoding_strategy(cfg.ctc_decoding) @@ -311,31 +356,16 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis isinstance(asr_model, EncDecHybridRNNTCTCModel) and cfg.decoder_type == "ctc" ): cfg.decoding = cfg.ctc_decoding + elif isinstance(asr_model.decoding, MultiTaskDecoding): + cfg.decoding = cfg.multitask_decoding else: cfg.decoding = cfg.rnnt_decoding - if isinstance(asr_model, EncDecMultiTaskModel): - # Special case for EncDecMultiTaskModel, where the input manifest is directly passed into the model's transcribe() function - partial_audio = False - filepaths = cfg.dataset_manifest - assert cfg.dataset_manifest is not None - else: - # prepare audio filepaths and decide wether it's partial audio - filepaths, partial_audio = prepare_audio_data(cfg) + filepaths, sorted_manifest_path = prepare_audio_data(cfg) - if not cfg.allow_partial_transcribe: - # by defatul, use model's transcribe() function, unless partial audio is required - partial_audio = False + remove_path_after_done = sorted_manifest_path if sorted_manifest_path is not None else None - # setup AMP (optional) - if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 
'amp') and hasattr(torch.cuda.amp, 'autocast'): - logging.info("AMP enabled!\n") - autocast = torch.cuda.amp.autocast - else: - - @contextlib.contextmanager - def autocast(dtype=None): - yield + filepaths = sorted_manifest_path if sorted_manifest_path is not None else filepaths # Compute output filename cfg = compute_output_filename(cfg, model_name) @@ -350,37 +380,82 @@ def autocast(dtype=None): # transcribe audio - amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 + if cfg.calculate_rtfx: + total_duration = 0.0 + + with open(cfg.dataset_manifest, "rt") as fh: + for line in fh: + item = json.loads(line) + if "duration" not in item: + raise ValueError( + f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} \ + lacks a 'duration' field." + ) + total_duration += item["duration"] + + if cfg.warmup_steps == 0: + logging.warning( + "RTFx measurement enabled, but warmup_steps=0. " + "At least one warmup step is recommended to measure RTFx" + ) - with autocast(dtype=amp_dtype): + timer = SimpleTimer() + model_measurements = [] + with torch.amp.autocast('cuda' if torch.cuda.is_available() else 'cpu', dtype=amp_dtype, enabled=cfg.amp): with torch.no_grad(): - if partial_audio: - transcriptions = transcribe_partial_audio( - asr_model=asr_model, - path2manifest=cfg.dataset_manifest, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - return_hypotheses=cfg.return_hypotheses, - channel_selector=cfg.channel_selector, - augmentor=augmentor, - decoder_type=cfg.decoder_type, - ) - else: + override_cfg = asr_model.get_transcribe_config() + override_cfg.batch_size = cfg.batch_size + override_cfg.num_workers = cfg.num_workers + override_cfg.return_hypotheses = cfg.return_hypotheses + override_cfg.channel_selector = cfg.channel_selector + override_cfg.augmentor = augmentor + override_cfg.text_field = cfg.gt_text_attr_name + override_cfg.lang_field = cfg.gt_lang_attr_name + override_cfg.timestamps = cfg.timestamps + if hasattr(override_cfg, "prompt"): + override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt)) + + device = next(asr_model.parameters()).device + for run_step in range(cfg.warmup_steps + cfg.run_steps): + if run_step < cfg.warmup_steps: + logging.info(f"Running warmup step {run_step}") + # reset timer + timer.reset() + timer.start(device=device) + # call transcribe transcriptions = asr_model.transcribe( - paths2audio_files=filepaths, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - return_hypotheses=cfg.return_hypotheses, - channel_selector=cfg.channel_selector, - augmentor=augmentor, + audio=filepaths, + override_config=override_cfg, + timestamps=cfg.timestamps, ) + # stop timer, log time + timer.stop(device=device) + logging.info(f"Model time for iteration {run_step}: {timer.total_sec():.3f}") + if run_step >= cfg.warmup_steps: + model_measurements.append(timer.total_sec()) + + model_measurements_np = np.asarray(model_measurements) + logging.info( + f"Model time avg: {model_measurements_np.mean():.3f}" + + (f" (std: {model_measurements_np.std():.3f})" if cfg.run_steps > 1 else "") + ) - logging.info(f"Finished transcribing {len(filepaths)} files !") + if cfg.dataset_manifest is not None: + logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}") + if cfg.presort_manifest: + transcriptions = restore_transcription_order(cfg.dataset_manifest, transcriptions) + else: + logging.info(f"Finished transcribing {len(filepaths)} files !") logging.info(f"Writing transcriptions into 
file: {cfg.output_filename}") - # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis + # if transcriptions form a tuple of (best_hypotheses, all_hypotheses) if type(transcriptions) == tuple and len(transcriptions) == 2: - transcriptions = transcriptions[0] + if cfg.extract_nbest: + # extract all hypotheses if exists + transcriptions = transcriptions[1] + else: + # extract just best hypothesis + transcriptions = transcriptions[0] if cfg.return_transcriptions: return transcriptions @@ -392,10 +467,15 @@ def autocast(dtype=None): model_name, filepaths=filepaths, compute_langs=compute_langs, - compute_timestamps=compute_timestamps, + timestamps=cfg.timestamps, ) logging.info(f"Finished writing predictions to {output_filename}!") + # clean-up + if cfg.presort_manifest is not None: + if remove_path_after_done is not None: + os.unlink(remove_path_after_done) + if cfg.calculate_wer: output_manifest_w_wer, total_res, _ = cal_write_wer( pred_manifest=output_filename, @@ -410,6 +490,13 @@ def autocast(dtype=None): logging.info(f"Writing prediction and error rate of each sample to {output_manifest_w_wer}!") logging.info(f"{total_res}") + if cfg.calculate_rtfx: + rtfx_measurements = total_duration / model_measurements_np + logging.info( + f"Model RTFx on the dataset: {rtfx_measurements.mean():.3f}" + + (f" (std: {rtfx_measurements.std():.3f})" if cfg.run_steps > 1 else "") + ) + return cfg From 834844642f0b27377af7a75bd66c2adc0d044a62 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 23 Jul 2025 06:08:23 -0700 Subject: [PATCH 13/21] Fixed PATH inside ci, and missed import level Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 1 + .../inference/asr/nemo/utils/speech_to_text_with_vad.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a0132d67..9aab518e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,6 +85,7 @@ jobs: AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }} AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} CLEAN_UP_TMP_PATH: 1 + PATH: /usr/local/bin:/usr/bin:/bin run: | wget https://uit.stanford.edu/sites/default/files/2023/10/11/incommon-rsa-ca2.pem #downloading cert manually [for CORAL] diff --git a/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py b/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py index 649155c5..2c734754 100644 --- a/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py +++ b/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py @@ -72,7 +72,7 @@ from nemo.collections.asr.data import feature_to_text_dataset from nemo.collections.asr.metrics.wer import word_error_rate from nemo.collections.asr.models import ASRModel, EncDecClassificationModel -from nemo.collections.asr.parts.submodules import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest from nemo.collections.asr.parts.utils.vad_utils import ( From 76f69ba2b5c882033d7f175dad1caa8c67640c3f Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 23 Jul 2025 06:53:07 -0700 Subject: [PATCH 14/21] Changed NeMo to 2.3.1 Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 3 +-- requirements/main.txt | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git 
a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9aab518e..14f64412 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -72,7 +72,7 @@ jobs: sudo apt-get install -y libsndfile1 ffmpeg sox libsox-fmt-mp3 pip install pytorch_lightning pip install Cython wheel # need to pre-install to avoid error in nemo installation - pip install nemo-toolkit[asr,nlp]==2.3.2 + pip install nemo-toolkit[asr,nlp]==2.3.1 pip install nemo_text_processing pip install -r requirements/huggingface.txt pip install certifi #this needed to avoid problems with certificates [COORAL] @@ -85,7 +85,6 @@ jobs: AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }} AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} CLEAN_UP_TMP_PATH: 1 - PATH: /usr/local/bin:/usr/bin:/bin run: | wget https://uit.stanford.edu/sites/default/files/2023/10/11/incommon-rsa-ca2.pem #downloading cert manually [for CORAL] diff --git a/requirements/main.txt b/requirements/main.txt index 5d0a4c37..ff81b0ab 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -22,7 +22,7 @@ jiwer>=3.1.0,<4.0.0 pyarrow>=8.0.0,<14.0.0 datasets>=2.14.0,<3.0.0 # toloka-kit # Temporarily disabled due to Toloka's technical pause; keep as reference for past and future API support -# for some processers, additionally https://github.com/NVIDIA/NeMo 2.3.2 is required +# for some processers, additionally https://github.com/NVIDIA/NeMo 2.3.1 is required # for some processers, additionally nemo_text_processing is required # for mcv: apt-get update && apt-get upgrade -y && apt-get install -y sox libsox-fmt-all @@ -30,4 +30,4 @@ datasets>=2.14.0,<3.0.0 # pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper # export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` # for vLLMInference processor is required: pip install "optree>=0.13.0" vllm -# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.3.2" +# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.3.1" From bdea1f8859c1f4bb0081c7da89f1e84a0bd4f948 Mon Sep 17 00:00:00 2001 From: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Date: Fri, 25 Jul 2025 12:46:46 +0400 Subject: [PATCH 15/21] Added threading backend --- .../utils/convert_to_tarred_audio_dataset.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py index c06462a9..d7e11688 100644 --- a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py +++ b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py @@ -94,6 +94,10 @@ from tabulate import tabulate from tqdm import tqdm +PARALLEL_BACKEND = "loky" +if os.environ.get("USE_THREADING_BACKEND") == "1": + PARALLEL_BACKEND = "threading" + try: import create_dali_tarred_dataset_index as dali_index @@ -288,13 +292,12 @@ def create_new_dataset( manifest_folder, _ = os.path.split(manifest_path) - with Parallel(n_jobs=num_workers, verbose=config.num_shards) as parallel: - # Call parallel tarfile construction - new_entries_list = parallel( + with parallel_backend(backend, n_jobs=num_workers): + new_entries_list = Parallel(verbose=config.num_shards)( 
delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder, only_manifests) for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices)) ) - + if config.shard_manifests: sharded_manifests_dir = target_dir + '/sharded_manifests' if not os.path.exists(sharded_manifests_dir): @@ -492,9 +495,8 @@ def create_concatenated_dataset( manifest_folder, _ = os.path.split(base_manifest_path) - with Parallel(n_jobs=num_workers, verbose=num_added_shards) as parallel: - # Call parallel tarfile construction - new_entries_list = parallel( + with parallel_backend(backend, n_jobs=num_workers): + new_entries_list = Parallel(verbose=config.num_shards)( delayed(self._create_shard)( entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder, only_manifests ) @@ -1018,4 +1020,4 @@ def create_tar_datasets( ) parser.add_argument('--workers', type=int, default=1, help='Number of worker processes') args = parser.parse_args() - main(args) \ No newline at end of file + main(args) From e165b502c79f08429352ae2de82b0473f568b3e6 Mon Sep 17 00:00:00 2001 From: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Date: Fri, 25 Jul 2025 12:49:30 +0400 Subject: [PATCH 16/21] Fix parallel related issues --- .../manage_files/utils/convert_to_tarred_audio_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py index d7e11688..52ff5882 100644 --- a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py +++ b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py @@ -89,7 +89,7 @@ import numpy as np import soundfile as sf -from joblib import Parallel, delayed +from joblib import Parallel, delayed, parallel_backend from omegaconf import DictConfig, OmegaConf, open_dict from tabulate import tabulate from tqdm import tqdm @@ -292,7 +292,7 @@ def create_new_dataset( manifest_folder, _ = os.path.split(manifest_path) - with parallel_backend(backend, n_jobs=num_workers): + with parallel_backend(PARALLEL_BACKEND, n_jobs=num_workers): new_entries_list = Parallel(verbose=config.num_shards)( delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder, only_manifests) for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices)) @@ -495,7 +495,7 @@ def create_concatenated_dataset( manifest_folder, _ = os.path.split(base_manifest_path) - with parallel_backend(backend, n_jobs=num_workers): + with parallel_backend(PARALLEL_BACKEND, n_jobs=num_workers): new_entries_list = Parallel(verbose=config.num_shards)( delayed(self._create_shard)( entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder, only_manifests From 710132209bc52ea387718c62a92614c6c669141f Mon Sep 17 00:00:00 2001 From: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Date: Fri, 25 Jul 2025 12:50:32 +0400 Subject: [PATCH 17/21] USE_THREADING_BACKEND = 1 --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 14f64412..88a81538 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -85,6 +85,7 @@ jobs: AWS_SECRET_KEY: ${{ secrets.AWS_SECRET_KEY }} AWS_ACCESS_KEY: ${{ secrets.AWS_ACCESS_KEY }} CLEAN_UP_TMP_PATH: 1 + USE_THREADING_BACKEND: 1 run: | wget https://uit.stanford.edu/sites/default/files/2023/10/11/incommon-rsa-ca2.pem #downloading cert manually [for 
CORAL] From d8fa030f98cfb7794906074c34a05c230e1ae3bf Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Fri, 25 Jul 2025 03:45:40 -0700 Subject: [PATCH 18/21] NeMo version to 2.2.1 Signed-off-by: Sasha Meister --- .github/workflows/tests.yml | 2 +- requirements/main.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 88a81538..b3e9dade 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -72,7 +72,7 @@ jobs: sudo apt-get install -y libsndfile1 ffmpeg sox libsox-fmt-mp3 pip install pytorch_lightning pip install Cython wheel # need to pre-install to avoid error in nemo installation - pip install nemo-toolkit[asr,nlp]==2.3.1 + pip install nemo-toolkit[asr,nlp]==2.2.1 pip install nemo_text_processing pip install -r requirements/huggingface.txt pip install certifi #this needed to avoid problems with certificates [COORAL] diff --git a/requirements/main.txt b/requirements/main.txt index ff81b0ab..b4f11e73 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -22,7 +22,7 @@ jiwer>=3.1.0,<4.0.0 pyarrow>=8.0.0,<14.0.0 datasets>=2.14.0,<3.0.0 # toloka-kit # Temporarily disabled due to Toloka's technical pause; keep as reference for past and future API support -# for some processers, additionally https://github.com/NVIDIA/NeMo 2.3.1 is required +# for some processers, additionally https://github.com/NVIDIA/NeMo 2.2.1 is required # for some processers, additionally nemo_text_processing is required # for mcv: apt-get update && apt-get upgrade -y && apt-get install -y sox libsox-fmt-all @@ -30,4 +30,4 @@ datasets>=2.14.0,<3.0.0 # pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper # export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` # for vLLMInference processor is required: pip install "optree>=0.13.0" vllm -# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.3.1" +# for ConvertToTarredAudioDatasetConfig processor can be additionally required: pip install lhotse "nemo-toolkit[common]==2.2.1" From c79005252be94350c8f58bfee8df09fc1ba8f310 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 30 Jul 2025 04:22:26 -0700 Subject: [PATCH 19/21] =?UTF-8?q?Changes=20addressing=20the=20reviewer?= =?UTF-8?q?=E2=80=99s=20comments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sasha Meister --- .../inference/asr/nemo/utils/transcribe_speech.py | 3 +++ .../manage_files/convert_to_tarred_audio_dataset.py | 12 +++++++----- .../utils/convert_to_tarred_audio_dataset.py | 4 ++++ .../utils/create_dali_tarred_dataset_index.py | 4 +++- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py b/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py index 7ca29238..41714bd9 100644 --- a/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py +++ b/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# This file is copied over from https://github.com/NVIDIA/NeMo/blob/v2.2.1/examples/asr/transcribe_speech.py +# It is currently only compatible with NeMo v2.2.1 To use a different version of NeMo, please modify the file. + import json import os from dataclasses import dataclass, field, is_dataclass diff --git a/sdp/processors/manage_files/convert_to_tarred_audio_dataset.py b/sdp/processors/manage_files/convert_to_tarred_audio_dataset.py index 93f4a970..833c3167 100644 --- a/sdp/processors/manage_files/convert_to_tarred_audio_dataset.py +++ b/sdp/processors/manage_files/convert_to_tarred_audio_dataset.py @@ -23,6 +23,7 @@ from sdp.processors.base_processor import BaseProcessor from sdp.processors.manage_files.utils.convert_to_tarred_audio_dataset import create_tar_datasets +from sdp.logging import logger @dataclass class ConvertToTarredAudioDatasetConfig: @@ -128,8 +129,8 @@ def process(self): bucket_config.target_dir = os.path.join(self.cfg.target_dir, f"bucket{i_bucket+1}") - print(f"Creating bucket {i_bucket+1} with min_duration={bucket_config.min_duration} and max_duration={bucket_config.max_duration} ...") - print(f"Results are being saved at: {bucket_config.target_dir}.") + logger.info(f"Creating bucket {i_bucket+1} with min_duration={bucket_config.min_duration} and max_duration={bucket_config.max_duration} ...") + logger.info(f"Results are being saved at: {bucket_config.target_dir}.") # Create tarred dataset for the current bucket create_tar_datasets( @@ -143,10 +144,11 @@ def process(self): for line in tqdm(bin_f, desc="Writing output manifest.."): entry = json.loads(line) entry['bucket_id'] = i_bucket - line = json.dumps(entry) - fout.writelines(f'{line}\n') + #line = json.dumps(entry) + json.dump(entry, fout, ensure_ascii=False) + fout.write('\n') - print(f"Bucket {i_bucket+1} is created.") + logger.info(f"Bucket {i_bucket+1} is created.") else: # No bucketing — create single tarred dataset diff --git a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py index 52ff5882..f7713e58 100644 --- a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py +++ b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py @@ -75,6 +75,10 @@ --write_metadata """ + +# This file is copied over from https://github.com/NVIDIA/NeMo/blob/main/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +# It is currently compatible with NeMo v2.2.1 To use a different version of NeMo, please modify the file. + import argparse import copy import json diff --git a/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py b/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py index 1ae64dc5..a2c794f3 100644 --- a/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py +++ b/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# This file is copied over from https://github.com/NVIDIA/NeMo/blob/main/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +# It is currently compatible with NeMo v2.2.1 To use a different version of NeMo, please modify the file. 
+ import glob import logging import os @@ -38,7 +41,6 @@ logging.basicConfig(level=logging.INFO) - @dataclass class DALITarredIndexConfig: tar_dir: str = MISSING # Path to the existing dataset's manifest From 13d19c03c85f85a3dcab288e886d4cf67052376c Mon Sep 17 00:00:00 2001 From: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Date: Wed, 30 Jul 2025 16:18:06 +0400 Subject: [PATCH 20/21] Apply suggestions from code review Co-authored-by: lilithgrigoryan <38436437+lilithgrigoryan@users.noreply.github.com> --- .../manage_files/utils/convert_to_tarred_audio_dataset.py | 3 ++- .../manage_files/utils/create_dali_tarred_dataset_index.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py index f7713e58..ae1a12eb 100644 --- a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py +++ b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py @@ -76,7 +76,8 @@ """ -# This file is copied over from https://github.com/NVIDIA/NeMo/blob/main/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +# This file is copied over from +https://github.com/NVIDIA/NeMo/blob/v2.2.1/scripts/speech_recognition/convert_to_tarred_audio_dataset.py # It is currently compatible with NeMo v2.2.1 To use a different version of NeMo, please modify the file. import argparse diff --git a/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py b/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py index a2c794f3..e6086c23 100644 --- a/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py +++ b/sdp/processors/manage_files/utils/create_dali_tarred_dataset_index.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# This file is copied over from https://github.com/NVIDIA/NeMo/blob/main/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +# This file is copied over from https://github.com/NVIDIA/NeMo/blob/v2.2.1/scripts/speech_recognition/convert_to_tarred_audio_dataset.py # It is currently compatible with NeMo v2.2.1 To use a different version of NeMo, please modify the file. import glob From 574107fe7e740e6771e887a45037f0dc000fb697 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Wed, 30 Jul 2025 06:13:24 -0700 Subject: [PATCH 21/21] Docs fix Signed-off-by: Sasha Meister --- .../manage_files/utils/convert_to_tarred_audio_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py index ae1a12eb..33431b84 100644 --- a/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py +++ b/sdp/processors/manage_files/utils/convert_to_tarred_audio_dataset.py @@ -76,8 +76,7 @@ """ -# This file is copied over from -https://github.com/NVIDIA/NeMo/blob/v2.2.1/scripts/speech_recognition/convert_to_tarred_audio_dataset.py +# This file is copied over from https://github.com/NVIDIA/NeMo/blob/v2.2.1/scripts/speech_recognition/convert_to_tarred_audio_dataset.py # It is currently compatible with NeMo v2.2.1 To use a different version of NeMo, please modify the file. import argparse
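
A minimal usage sketch for the processor introduced in this series (not part of any patch above): all paths and thresholds are placeholders, the keyword arguments mirror ConvertToTarredAudioDatasetConfig, and USE_THREADING_BACKEND=1 opts into the joblib threading backend added in PATCH 15-17. The environment variable must be set before importing sdp.processors, because the tarring utility reads it once at import time.

import os

# USE_THREADING_BACKEND is read once, at import time, to pick the joblib backend
# (threading instead of the default loky), so set it before the import below.
os.environ["USE_THREADING_BACKEND"] = "1"

from sdp.processors import ConvertToTarredAudioDataset

processor = ConvertToTarredAudioDataset(
    input_manifest_file="manifests/train.json",          # placeholder input manifest
    output_manifest_file="manifests/train_tarred.json",  # placeholder output manifest
    target_dir="tarred_dataset",                         # tar shards are written here
    num_shards=8,
    min_duration=1.0,
    max_duration=40.0,
    buckets_num=2,       # >1 adds a bucket_id field and writes bucket{i} subdirectories
    sort_in_shards=True,
    workers=4,
)
processor.process()

With buckets_num=2, each output-manifest sample gains a bucket_id field and the tars land in bucket1/ and bucket2/ under target_dir, matching what the test in tests/test_convert_to_tarred_audio_dataset.py asserts.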