From a81f7874aab3885752f398910ef42e8b71b8fe19 Mon Sep 17 00:00:00 2001
From: Sasha Meister
Date: Sun, 17 Mar 2024 20:26:17 +0000
Subject: [PATCH 1/6] YouTube German config and new processors

Signed-off-by: Sasha Meister
---
 dataset_configs/youtube/de.yaml               |  77 +++++++
 sdp/processors/__init__.py                    |   2 +-
 .../datasets/commoncrawl/__init__.py          |   8 +-
 sdp/processors/datasets/youtube/__init__.py   |  17 ++
 .../datasets/youtube/aggregate_segments.py    |  68 ++++++
 .../youtube/create_initial_manifest.py        |  90 ++++++++
 sdp/processors/datasets/youtube/utils.py      | 103 +++++++++
 sdp/processors/nemo/asr_inference.py          |  32 +++
 .../nemo/transcribe_speech_parallel.py        | 208 ++++++++++++++++++
 9 files changed, 600 insertions(+), 5 deletions(-)
 create mode 100644 dataset_configs/youtube/de.yaml
 create mode 100644 sdp/processors/datasets/youtube/__init__.py
 create mode 100644 sdp/processors/datasets/youtube/aggregate_segments.py
 create mode 100644 sdp/processors/datasets/youtube/create_initial_manifest.py
 create mode 100644 sdp/processors/datasets/youtube/utils.py
 create mode 100644 sdp/processors/nemo/transcribe_speech_parallel.py

diff --git a/dataset_configs/youtube/de.yaml b/dataset_configs/youtube/de.yaml
new file mode 100644
index 00000000..451bdc43
--- /dev/null
+++ b/dataset_configs/youtube/de.yaml
@@ -0,0 +1,77 @@
+processors_to_run: "0:"
+base_dir: "/ws/test_subset"
+workspace_dir: "/ws/test_subset_out"
+lang: de
+min_duration: 1
+max_duration: 40
+
+processors:
+  # Create initial manifests based on pairs of .opus audio + .srt transcript (with ground-truth timestamps)
+  - _target_: sdp.processors.datasets.youtube.CreateInitialManifest
+    data_dir: ${base_dir}
+    output_audio_dir: ${workspace_dir}/audio/wav_samples
+    output_manifest_file: ${workspace_dir}/manifest1.json
+    chunksize: 10
+    in_memory_chunksize: 400
+
+  # Aggregate ground-truth segments into longer ones based on a duration threshold
+  - _target_: sdp.processors.datasets.youtube.AggregateSegments
+    max_duration: ${max_duration}
+    output_segments_audio_dir: ${workspace_dir}/audio/wav_segments
+    output_manifest_file: ${workspace_dir}/manifest2.json
+
+  # Filter out samples whose duration is outside the 1-40 s range
+  - _target_: sdp.processors.DropHighLowDuration
+    output_manifest_file: ${workspace_dir}/manifest3.json
+    low_duration_threshold: ${min_duration}
+    high_duration_threshold: ${max_duration}
+
+  # Identify language of the text
+  - _target_: sdp.processors.datasets.commoncrawl.TextLid
+    output_manifest_file: ${workspace_dir}/manifest4.json
+    input_text_key: orig_text
+    output_lang_key: text_lang
+    device: cuda
+    pretrained_model: "jb2k/bert-base-multilingual-cased-language-detection"
+    drop_text_duplicates: True
+
+  - _target_: sdp.processors.datasets.commoncrawl.Lang2Iso
+    output_manifest_file: ${workspace_dir}/manifest5.json
+    input_lang_key: text_lang
+    output_lang_key: text_lang
+
+  ## Filter out samples with text in non-target language
+  - _target_: sdp.processors.PreserveByValue
+    output_manifest_file: ${workspace_dir}/manifest6.json
+    input_value_key: text_lang
+    target_value: ${lang}
+
+  # Identify language of the audio
+  - _target_: sdp.processors.datasets.commoncrawl.AudioLid
+    output_manifest_file: ${workspace_dir}/manifest7.json
+    input_audio_key: audio_filepath
+    output_lang_key: audio_lang
+    device: cuda
+    pretrained_model: "langid_ambernet"
+
+  ## Filter out samples with audio in non-target language
+  - _target_: sdp.processors.PreserveByValue
+    output_manifest_file: ${workspace_dir}/manifest8.json
+    input_value_key: audio_lang
+    target_value: ${lang}
+
+  # ASR Inference
+  - _target_: sdp.processors.ASRInferenceParallel
+    output_manifest_file: ${workspace_dir}/manifest9.json
+    pretrained_model: nvidia/stt_${lang}_fastconformer_hybrid_large_pc
+    batch_size: 64
+    devices: 4
+
+
+
+
+
+
+
+
+
diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py
index f7a896e1..2ab441c5 100644
--- a/sdp/processors/__init__.py
+++ b/sdp/processors/__init__.py
@@ -74,5 +74,5 @@ from sdp.processors.modify_manifest.make_letters_uppercase_after_period import (
     MakeLettersUppercaseAfterPeriod,
 )
-from sdp.processors.nemo.asr_inference import ASRInference
+from sdp.processors.nemo.asr_inference import ASRInference, ASRInferenceParallel
 from sdp.processors.nemo.pc_inference import PCInference

diff --git a/sdp/processors/datasets/commoncrawl/__init__.py b/sdp/processors/datasets/commoncrawl/__init__.py
index b4fe3020..7ee1c072 100644
--- a/sdp/processors/datasets/commoncrawl/__init__.py
+++ b/sdp/processors/datasets/commoncrawl/__init__.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .commoncrawl import UseSonar, BLEUScore, Subprocess, NmtSubprocess, PreserveByValue, \
-    Lang2Iso, SplitByVttSentence, SplitByVtt, AudioLid, TextLid, AllVttText, TxtToVtt, \
-    ReadParquet, CreateInitialManifestCC, FfmpegConvert, ASR_HF, AlignerSubprocess, \
-    SplitByAligner, JoinBy, EvalBandwidth, CreateInitialManifestExt, AudioDuration, \
+from .commoncrawl import UseSonar, BLEUScore, Subprocess, NmtSubprocess, \
+    Lang2Iso, SplitByVttSentence, AudioLid, TextLid, AllVttText, TxtToVtt, \
+    ReadParquet, CreateInitialManifestCC, ASR_HF, AlignerSubprocess, \
+    SplitByAligner, JoinBy, EvalBandwidth, CreateInitialManifestExt, \
     TrainDevTestSplitCC, DropAbsPath, GetSpecificFiles, CopyFiles, ManifestToUtf8

diff --git a/sdp/processors/datasets/youtube/__init__.py b/sdp/processors/datasets/youtube/__init__.py
new file mode 100644
index 00000000..8ee20226
--- /dev/null
+++ b/sdp/processors/datasets/youtube/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .create_initial_manifest import CreateInitialManifest
+from .utils import parse_srt
+from .aggregate_segments import *
\ No newline at end of file
diff --git a/sdp/processors/datasets/youtube/aggregate_segments.py b/sdp/processors/datasets/youtube/aggregate_segments.py
new file mode 100644
index 00000000..f5aaef07
--- /dev/null
+++ b/sdp/processors/datasets/youtube/aggregate_segments.py
@@ -0,0 +1,68 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pydub import AudioSegment
+import os
+
+from sdp.processors.base_processor import BaseParallelProcessor
+from sdp.processors.datasets.youtube.utils import RawSegment, AggregatedSegment, get_audio_segment
+
+
+class AggregateSegments(BaseParallelProcessor):
+    def __init__(
+        self,
+        max_duration: float = 40.0,
+        crop_audio_segments: bool = True,
+        output_segments_audio_dir: str = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.max_duration = max_duration
+        self.crop_audio_segments = crop_audio_segments
+        self.output_segments_audio_dir = output_segments_audio_dir
+
+    def prepare(self):
+        if self.crop_audio_segments and self.output_segments_audio_dir:
+            os.makedirs(self.output_segments_audio_dir, exist_ok=True)
+
+    def process_dataset_entry(self, data_entry: dict):
+        sample_id = data_entry['sample_id']
+        segmnets = data_entry['segments']
+        agg_segments = []
+
+        first_segment = RawSegment(**segmnets[0])
+        agg_segment = AggregatedSegment(segment=first_segment, segment_id=1, sample_id=sample_id,
+                                        output_audio_dir = self.output_segments_audio_dir)
+
+        for segment in segmnets[1 : ]:
+            segment = RawSegment(**segment)
+
+            # Flush the current aggregate and start a new one when the timestamps
+            # are unreliable or adding this segment would exceed max_duration.
+            if (not agg_segment.duration_match or
+                    agg_segment.duration >= self.max_duration or
+                    segment.end_time - agg_segment.start_time >= self.max_duration):
+                agg_segments.append(agg_segment.to_dataentry())
+                agg_segment = AggregatedSegment(segment=segment,
+                                                segment_id=len(agg_segments) + 1, sample_id=sample_id,
+                                                output_audio_dir = self.output_segments_audio_dir)
+            else:
+                agg_segment.aggregate(segment)
+
+        # Flush the last aggregate once all segments have been consumed.
+        agg_segments.append(agg_segment.to_dataentry())
+
+        if self.crop_audio_segments:
+            audio = AudioSegment.from_wav(data_entry['audio_filepath'])
+            for agg_segment in agg_segments:
+                get_audio_segment(audio = audio,
+                                  start_time = agg_segment.data['start_time'],
+                                  end_time = agg_segment.data['end_time'],
+                                  output_audio_filepath = agg_segment.data['audio_filepath'])
+
+        return agg_segments
\ No newline at end of file
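
The duration-threshold logic above is easier to see stripped of the processor machinery. Below is a minimal standalone sketch of the same greedy aggregation (illustrative only, not part of the patch); it ignores the duration_match check that the real processor also applies:

    # Greedy left-to-right aggregation, as in AggregateSegments: merge segments
    # until adding the next one would push the group past max_duration.
    def aggregate(segments, max_duration=40.0):
        groups = []
        cur_start, cur_end, cur_text = segments[0]
        for start, end, text in segments[1:]:
            if end - cur_start >= max_duration:
                groups.append((cur_start, cur_end, cur_text))
                cur_start, cur_end, cur_text = start, end, text
            else:
                cur_end, cur_text = end, f"{cur_text} {text}"
        groups.append((cur_start, cur_end, cur_text))  # flush the last group
        return groups

    segs = [(0.0, 15.0, "a"), (16.0, 30.0, "b"), (31.0, 45.0, "c")]
    print(aggregate(segs))  # [(0.0, 30.0, 'a b'), (31.0, 45.0, 'c')]

Like the processor, the sketch assumes a non-empty segment list; PATCH 6/6 below adds the guard for the empty case.

diff --git a/sdp/processors/datasets/youtube/create_initial_manifest.py b/sdp/processors/datasets/youtube/create_initial_manifest.py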
new file mode 100644
index 00000000..3bca6ee1
--- /dev/null
+++ b/sdp/processors/datasets/youtube/create_initial_manifest.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Dict
+from glob import glob
+
+from sdp.logging import logger
+from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
+from sdp.processors.datasets.youtube.utils import parse_srt, Sample
+from sdp.utils.common import ffmpeg_convert
+
+
+class CreateInitialManifest(BaseParallelProcessor):
+    def __init__(
+        self,
+        data_dir: str,
+        output_audio_dir: str,
+        audio_file_extension: str = ".opus",
+        target_samplerate: int = 16000,
+        target_nchannels: int = 1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.data_dir = data_dir
+        self.output_audio_dir = output_audio_dir
+        self.audio_file_extension = audio_file_extension
+        self.target_samplerate = target_samplerate
+        self.target_nchannels = target_nchannels
+
+    def _get_manifest(self):
+        audio_filepaths = glob(f"{self.data_dir}/*{self.audio_file_extension}")
+        samples = []
+        for audio_filepath in audio_filepaths:
+            sample = Sample(orig_audio_filepath = audio_filepath)
+            # Derive sample_id from the audio file name
+            sample.sample_id = os.path.basename(audio_filepath).replace(self.audio_file_extension, "")
+
+            # Get the .srt file related to the source audio
+            srt_filepaths = glob(f"{self.data_dir}/{sample.sample_id}.*.srt")
+
+            if len(srt_filepaths) < 1:
+                logger.warning(f"Sample \"{sample.sample_id}\" has no related .srt files. Skipping.")
+                continue
+
+            srt_filepath = srt_filepaths[0]
+            if len(srt_filepaths) > 1:
+                logger.warning(f"Sample \"{sample.sample_id}\" has multiple related .srt files: {', '.join(srt_filepaths)}. \
+                    Only the first file ({srt_filepaths[0]}) will be used for parsing; the other related .srt files will be skipped.")
+
+            sample.srt_filepath = srt_filepath
+            samples.append(sample.to_dataentry())
+
+        return samples
+
+    def prepare(self):
+        os.makedirs(self.output_audio_dir, exist_ok=True)
+
+    def read_manifest(self):
+        data_entries = self._get_manifest()
+        return data_entries
+
+    def process_dataset_entry(self, data_entry: DataEntry):
+        # Convert orig_audio_filepath to .wav
+        data_entry.data['audio_filepath'] = os.path.join(
+            self.output_audio_dir,
+            os.path.basename(data_entry.data['orig_audio_filepath']).replace(self.audio_file_extension, ".wav"))
+
+        ffmpeg_convert(input_file=data_entry.data['orig_audio_filepath'],
+                       output_wav=data_entry.data['audio_filepath'],
+                       sample_rate=self.target_samplerate,
+                       num_channels=self.target_nchannels)
+
+        if not os.path.exists(data_entry.data['audio_filepath']):
+            return []
+
+        # Parse segments from .srt
+        segments = parse_srt(data_entry.data['srt_filepath'], verify_duration = True, wav_filepath=data_entry.data['audio_filepath'])
+
+        if len(segments) > 0:
+            data_entry.data['segments'] = [segment.__dict__ for segment in segments]
+
+        return [data_entry]
\ No newline at end of file
diff --git a/sdp/processors/datasets/youtube/utils.py b/sdp/processors/datasets/youtube/utils.py
new file mode 100644
index 00000000..9f5c9c5e
--- /dev/null
+++ b/sdp/processors/datasets/youtube/utils.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pysrt
+from pydub import AudioSegment
+from dataclasses import dataclass
+import re
+import os
+from sdp.processors.base_processor import DataEntry
+
+
+@dataclass
+class RawSegment:
+    segment_id: int = None
+    start_time: float = None
+    end_time: float = None
+    duration: float = None
+    duration_match: bool = None
+    orig_text: str = None
+
+    def to_dataentry(self):
+        return DataEntry(data = self.__dict__)
+
+
+class AggregatedSegment(RawSegment):
+    def __init__(self, segment: dict, segment_id: int, sample_id: str, output_audio_dir: str):
+        super().__init__(**segment.__dict__)
+        self.segment_id = f"{sample_id}_{str(segment_id).zfill(4)}"
+        self.audio_filepath = os.path.join(output_audio_dir, f'{self.segment_id}.wav') if output_audio_dir is not None else None
+
+    def aggregate(self, segment):
+        self.end_time = segment.end_time
+        self.duration = self.end_time - self.start_time
+        self.orig_text = re.sub(r"\s+", " ", f"{self.orig_text} {segment.orig_text}".strip())
+
+
+@dataclass
+class Sample:
+    sample_id: str = None
+    srt_filepath: str = None
+    orig_audio_filepath: str = None
+    audio_filepath: str = None
+    segments: list[RawSegment | AggregatedSegment] = None
+
+    def to_dataentry(self):
+        data = self.__dict__
+        data['segments'] = [segment.__dict__ for segment in data['segments']] if data['segments'] is not None else []
+        return DataEntry(data = data)
+
+
+def get_audio_segment(audio, start_time: float, end_time: float, output_audio_filepath: str = None):
+    # pydub slices audio in milliseconds
+    start_time = start_time * 1000
+    end_time = end_time * 1000
+    audio_segment = audio[start_time : end_time]
+
+    if output_audio_filepath:
+        audio_segment.export(output_audio_filepath, format="wav")
+    return audio_segment
+
+
+def get_audio_segment_duration(audio, start_time, end_time):
+    audio_segment = get_audio_segment(audio, start_time, end_time)
+    return audio_segment.duration_seconds
+
+
+def parse_srt(srt_filepath, verify_duration: bool = True, wav_filepath: str = None):
+    subs = pysrt.open(srt_filepath)
+    srt_segments = []
+
+    if verify_duration and wav_filepath:
+        audio = AudioSegment.from_wav(wav_filepath)
+    else:
+        audio = None
+
+    epsilon = 1e-2
+
+    for sub in subs:
+        segment = RawSegment(segment_id = sub.index,
+                             start_time = sub.start.ordinal / 1000,
+                             end_time = sub.end.ordinal / 1000,
+                             orig_text = sub.text_without_tags)
+
+        duration_by_timestamps = segment.end_time - segment.start_time
+
+        if audio:
+            # Compare the duration cut from the actual audio against the .srt timestamps
+            segment.duration = get_audio_segment_duration(audio, segment.start_time, segment.end_time)
+            segment.duration_match = abs(segment.duration - duration_by_timestamps) < epsilon
+        else:
+            segment.duration = duration_by_timestamps
+
+        srt_segments.append(segment)
+
+    return srt_segments
\ No newline at end of file
diff --git a/sdp/processors/nemo/asr_inference.py b/sdp/processors/nemo/asr_inference.py
index 5af6e254..5c2c1bcb 100644
--- a/sdp/processors/nemo/asr_inference.py
+++ b/sdp/processors/nemo/asr_inference.py
@@ -14,6 +14,7 @@
 
 import os
 import subprocess
+import shutil
 from pathlib import Path
 
 from sdp.processors.base_processor import BaseProcessor
@@ -74,3 +75,34 @@ def process(self):
         shell=True,
         check=True,
     )
+
+
+class ASRInferenceParallel(BaseProcessor):
+    def __init__(
+        self,
+        pretrained_model: str,
+        batch_size: int = 32,
+        devices: int = 2,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.script_path = Path(__file__).parents[1] / "nemo" / "transcribe_speech_parallel.py"
+        self.pretrained_model = pretrained_model
+        self.batch_size = batch_size
+        self.devices = devices
+        self.output_manifest_dir = self.output_manifest_file.replace(".json", "")
+
+    def process(self):
+        subprocess.run(
+            f"python {self.script_path} "
+            f"model={self.pretrained_model} "
+            f"predict_ds.manifest_filepath={self.input_manifest_file} "
+            f"output_path={self.output_manifest_dir} "
+            f"predict_ds.batch_size={self.batch_size} "
+            f"trainer.devices={self.devices} ",
+            shell=True,
+            check=True,
+        )
+
+        os.rename(os.path.join(self.output_manifest_dir, "predictions_all.json"), self.output_manifest_file)
+        shutil.rmtree(self.output_manifest_dir)
\ No newline at end of file
diff --git a/sdp/processors/nemo/transcribe_speech_parallel.py b/sdp/processors/nemo/transcribe_speech_parallel.py
new file mode 100644
index 00000000..c0af8f97
--- /dev/null
+++ b/sdp/processors/nemo/transcribe_speech_parallel.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+# ASR transcribe/inference with multi-GPU/multi-node support for large datasets
+# It supports both tarred and non-tarred datasets
+# Arguments
+#     model: path to a .nemo/PTL checkpoint file or name of a pretrained model
+#     predict_ds: config of the dataset/dataloader
+#     output_path: path to store the predictions
+#     return_predictions: whether to return the predictions as output other than writing into the files
+#     use_cer: whether to calculate the error in terms of CER or use the default WER
+#
+# Results of each GPU/worker are written into a file named 'predictions_{rank}.json', and aggregated results of all workers are written into 'predictions_all.json'
+
+Example for non-tarred datasets:
+
+python transcribe_speech_parallel.py \
+    model=stt_en_conformer_ctc_large \
+    predict_ds.manifest_filepath=/dataset/manifest_file.json \
+    predict_ds.batch_size=16 \
+    output_path=/tmp/
+
+Example for Hybrid-CTC/RNNT models with non-tarred datasets:
+
+python transcribe_speech_parallel.py \
+    model=stt_en_fastconformer_hybrid_large \
+    decoder_type=ctc \
+    predict_ds.manifest_filepath=/dataset/manifest_file.json \
+    predict_ds.batch_size=16 \
+    output_path=/tmp/
+
+Example for tarred datasets:
+
+python transcribe_speech_parallel.py \
+    predict_ds.is_tarred=true \
+    predict_ds.manifest_filepath=/tarred_dataset/tarred_audio_manifest.json \
+    predict_ds.tarred_audio_filepaths=/tarred_dataset/audio__OP_0..127_CL_.tar \
+    ...
+
+By default the trainer uses all the GPUs available and the default precision is FP32.
+By setting the trainer config you may control these settings. For example, to do the predictions with AMP on just two GPUs:
+
+python transcribe_speech_parallel.py \
+    trainer.precision=16 \
+    trainer.devices=2 \
+    ...
+
+You may control the dataloader's config by setting the predict_ds:
+
+python transcribe_speech_parallel.py \
+    predict_ds.num_workers=8 \
+    predict_ds.min_duration=2.0 \
+    predict_ds.sample_rate=16000 \
+    model=stt_en_conformer_ctc_small \
+    ...
+ +""" + + +import itertools +import json +import os +from dataclasses import dataclass, is_dataclass +from typing import Optional + +import pytorch_lightning as ptl +import torch +from omegaconf import MISSING, OmegaConf + +from nemo.collections.asr.data.audio_to_text_dataset import ASRPredictionWriter +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.asr.models import ASRModel, EncDecHybridRNNTCTCModel +from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.core.config import TrainerConfig, hydra_runner +from nemo.utils import logging +from nemo.utils.get_rank import is_global_rank_zero + + +@dataclass +class ParallelTranscriptionConfig: + model: Optional[str] = None # name + predict_ds: ASRDatasetConfig = ASRDatasetConfig(return_sample_id=True, num_workers=4) + output_path: str = MISSING + + # when return_predictions is enabled, the prediction call would keep all the predictions in memory and return them when prediction is done + return_predictions: bool = False + use_cer: bool = False + + # decoding strategy for RNNT models + rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig() + + # decoder type: ctc or rnnt, can be used to switch between CTC and RNNT decoder for Hybrid RNNT/CTC models + decoder_type: Optional[str] = None + # att_context_size can be set for cache-aware streaming models with multiple look-aheads + att_context_size: Optional[list] = None + + trainer: TrainerConfig = TrainerConfig(devices=-1, accelerator="gpu", strategy="ddp") + + +def match_train_config(predict_ds, train_ds): + # It copies the important configurations from the train dataset of the model + # into the predict_ds to be used for prediction. It is needed to match the training configurations. 
+ if train_ds is None: + return + + predict_ds.sample_rate = train_ds.get("sample_rate", 16000) + cfg_name_list = [ + "int_values", + "use_start_end_token", + "blank_index", + "unk_index", + "normalize", + "parser", + "eos_id", + "bos_id", + "pad_id", + ] + + if is_dataclass(predict_ds): + predict_ds = OmegaConf.structured(predict_ds) + for cfg_name in cfg_name_list: + if hasattr(train_ds, cfg_name): + setattr(predict_ds, cfg_name, getattr(train_ds, cfg_name)) + + return predict_ds + + +@hydra_runner(config_name="TranscriptionConfig", schema=ParallelTranscriptionConfig) +def main(cfg: ParallelTranscriptionConfig): + if cfg.model.endswith(".nemo"): + logging.info("Attempting to initialize from .nemo file") + model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu") + elif cfg.model.endswith(".ckpt"): + logging.info("Attempting to initialize from .ckpt file") + model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu") + else: + logging.info( + "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt" + ) + model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu") + + if isinstance(model, EncDecHybridRNNTCTCModel) and cfg.decoder_type is not None: + model.change_decoding_strategy(decoder_type=cfg.decoder_type) + + trainer = ptl.Trainer(**cfg.trainer) + + cfg.predict_ds.return_sample_id = True + cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds) + data_loader = model._setup_dataloader_from_config(cfg.predict_ds) + + os.makedirs(cfg.output_path, exist_ok=True) + # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank. + global_rank = trainer.node_rank * trainer.num_devices + int(os.environ.get("LOCAL_RANK", 0)) + output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json") + predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file) + trainer.callbacks.extend([predictor_writer]) + + predictions = trainer.predict(model=model, dataloaders=data_loader, return_predictions=cfg.return_predictions) + if predictions is not None: + predictions = list(itertools.chain.from_iterable(predictions)) + samples_num = predictor_writer.close_output_file() + + logging.info( + f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}." + ) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + samples_num = 0 + pred_text_list = [] + text_list = [] + if is_global_rank_zero(): + output_file = os.path.join(cfg.output_path, f"predictions_all.json") + logging.info(f"Prediction files are being aggregated in {output_file}.") + with open(output_file, 'w') as outf: + for rank in range(trainer.world_size): + input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json") + with open(input_file, 'r') as inpf: + lines = inpf.readlines() + for line in lines: + item = json.loads(line) + pred_text_list.append(item["pred_text"]) + text_list.append(item["text"]) + outf.write(json.dumps(item) + "\n") + samples_num += 1 + wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer) + logging.info( + f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}." 
)
+    logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer))
+
+
+if __name__ == '__main__':
+    main()

From 5b310b6db23f3f34d2c474b61988c78a5a492252 Mon Sep 17 00:00:00 2001
From: Sasha Meister
Date: Mon, 18 Mar 2024 10:46:40 +0000
Subject: [PATCH 2/6] Added Merge Manifests processor

Signed-off-by: Sasha Meister
---
 dataset_configs/youtube/de.yaml               | 15 +++++---
 sdp/processors/datasets/youtube/__init__.py   |  3 +-
 .../datasets/youtube/merge_manifests.py       | 35 +++++++++++++++++++
 3 files changed, 48 insertions(+), 5 deletions(-)
 create mode 100644 sdp/processors/datasets/youtube/merge_manifests.py

diff --git a/dataset_configs/youtube/de.yaml b/dataset_configs/youtube/de.yaml
index 451bdc43..57a084cd 100644
--- a/dataset_configs/youtube/de.yaml
+++ b/dataset_configs/youtube/de.yaml
@@ -1,4 +1,4 @@
-processors_to_run: "0:"
+processors_to_run: "9:"
 base_dir: "/ws/test_subset"
 workspace_dir: "/ws/test_subset_out"
 lang: de
@@ -65,10 +65,17 @@ processors:
     output_manifest_file: ${workspace_dir}/manifest9.json
     pretrained_model: nvidia/stt_${lang}_fastconformer_hybrid_large_pc
     batch_size: 64
-    devices: 4
-
+    devices: 2
+
+  ## Merge manifests
+  - _target_: sdp.processors.datasets.youtube.MergeManifests
+    input_manifest_file: ${workspace_dir}/manifest8.json
+    input_manifest_file2: ${workspace_dir}/manifest9.json
+    output_manifest_file: ${workspace_dir}/manifest10.json
+    key_field: audio_filepath
+    fields_to_merge:
+      - {"pred_text" : "pred_text_pc"}
 
-
diff --git a/sdp/processors/datasets/youtube/__init__.py b/sdp/processors/datasets/youtube/__init__.py
index 8ee20226..119ac1ca 100644
--- a/sdp/processors/datasets/youtube/__init__.py
+++ b/sdp/processors/datasets/youtube/__init__.py
@@ -14,4 +14,5 @@
 
 from .create_initial_manifest import CreateInitialManifest
 from .utils import parse_srt
-from .aggregate_segments import *
\ No newline at end of file
+from .aggregate_segments import *
+from .merge_manifests import MergeManifests
\ No newline at end of file
diff --git a/sdp/processors/datasets/youtube/merge_manifests.py b/sdp/processors/datasets/youtube/merge_manifests.py
new file mode 100644
index 00000000..0860c429
--- /dev/null
+++ b/sdp/processors/datasets/youtube/merge_manifests.py
@@ -0,0 +1,35 @@
+from sdp.processors.base_processor import BaseParallelProcessor, DataEntry
+import json
+
+
+class MergeManifests(BaseParallelProcessor):
+    def __init__(
+        self, input_manifest_file2: str, fields_to_merge: list, key_field: str = "audio_filepath",
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.input_manifest_file2 = input_manifest_file2
+        self.manifest2_dict = {}
+        # fields_to_merge is a list of single-item dicts: {source_field: target_field}
+        self.fields_to_merge = fields_to_merge
+        self.key_field = key_field
+
+    def prepare(self):
+        with open(self.input_manifest_file2, 'r') as manifest:
+            for line in manifest:
+                whole_sample = json.loads(line)
+                key_value = whole_sample[self.key_field]
+                sample = {}
+                for field_names_dict in self.fields_to_merge:
+                    curr_field_name = next(iter(field_names_dict))
+                    sample[curr_field_name] = whole_sample[curr_field_name]
+
+                self.manifest2_dict[key_value] = sample
+
+    def process_dataset_entry(self, data_entry: dict):
+        key_value = data_entry[self.key_field]
+        for field_names_dict in self.fields_to_merge:
+            curr_field_name = next(iter(field_names_dict))
+            new_field_name = field_names_dict[curr_field_name]
+            data_entry[new_field_name] = self.manifest2_dict[key_value][curr_field_name]
+        return [DataEntry(data=data_entry)]
\ No newline at end of file
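
MergeManifests is effectively a left join of two JSONL manifests on key_field, with the joined fields renamed on the way in. A minimal sketch of the same join in plain Python (file names are illustrative only; a missing key raises KeyError, matching the processor's behavior):

    import json

    # Index manifest B (the ASR predictions) by audio_filepath, then attach its
    # pred_text to each entry of manifest A under the new name pred_text_pc.
    with open("manifest9.json") as f:
        by_key = {entry["audio_filepath"]: entry for entry in map(json.loads, f)}

    with open("manifest8.json") as f, open("manifest10.json", "w") as out:
        for line in f:
            entry = json.loads(line)
            entry["pred_text_pc"] = by_key[entry["audio_filepath"]]["pred_text"]
            out.write(json.dumps(entry) + "\n")

From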
32247203d52a5b84518cb68e4dcaf843d154f7c2 Mon Sep 17 00:00:00 2001 From: Sasha Meister Date: Mon, 18 Mar 2024 11:13:49 +0000 Subject: [PATCH 3/6] Clean de.yaml pipeline config Signed-off-by: Sasha Meister --- dataset_configs/youtube/de.yaml | 179 +++++++++++++++++++++++++++++++- 1 file changed, 174 insertions(+), 5 deletions(-) diff --git a/dataset_configs/youtube/de.yaml b/dataset_configs/youtube/de.yaml index 57a084cd..f3a09037 100644 --- a/dataset_configs/youtube/de.yaml +++ b/dataset_configs/youtube/de.yaml @@ -1,9 +1,14 @@ -processors_to_run: "9:" -base_dir: "/ws/test_subset" -workspace_dir: "/ws/test_subset_out" +processors_to_run: "25:" +base_dir: "/ws/test_subset/" +workspace_dir: "/ws/test_subset_out/" + +# filters lang: de -min_duration: 1 -max_duration: 40 +min_duration: 1.0 +max_duration: 40.0 +max_wer: 75.0 +max_cer: 30.0 + processors: # Create initial manifests based on pairs of .opus audio + .srt transcript (with ground-truth timestamps) @@ -76,9 +81,173 @@ processors: fields_to_merge: - {"pred_text" : "pred_text_pc"} + # Filter out samples with empty pred_text_pc + - _target_: sdp.processors.DropIfRegexMatch + output_manifest_file: ${workspace_dir}/manifest11.json + text_key: pred_text_pc + regex_patterns: + - "^\\s*$" + + # Preprocess orig text for audio-based TN + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest12.json + duplicate_fields: {"orig_text" : "pre_normalized"} + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest13.json + text_key: pre_normalized + regex_params_list: + - {"pattern": '\\[hn]', "repl" : " "} + - {"pattern": "\\s+", "repl" : " "} + - {"pattern": "\\[", "repl" : " "} + - {"pattern": "\\]", "repl" : " "} + - {"pattern": "!", "repl" : "."} + - {"pattern": "\\)", "repl" : " "} + - {"pattern": "\\(", "repl" : " "} + - {"pattern": "“", "repl" : " "} + - {"pattern": "„", "repl" : " "} + - {"pattern": "–", "repl" : " "} + - {"pattern": ";", "repl" : ","} + - {"pattern": "'", "repl" : " "} + - {"pattern": "…", "repl" : "."} + - {"pattern": "«", "repl" : " "} + - {"pattern": "»", "repl" : " "} + - {"pattern": "’", "repl" : " "} + - {"pattern": "‘", "repl" : " "} + - {"pattern": "”", "repl" : " "} + - {"pattern": "—", "repl" : " "} + - {"pattern": "´", "repl" : " "} + - {"pattern": "″", "repl" : " "} + - {"pattern": "`", "repl" : " "} + - {"pattern": "\\|", "repl" : " "} + - {"pattern": "−", "repl" : " "} + - {"pattern": "‟", "repl" : " "} + - {"pattern": "‒", "repl" : " "} + - {"pattern": " ", "repl" : " "} + - {"pattern": "", "repl" : " "} + - {"pattern": "‐", "repl" : " "} + - {"pattern": "ʻ", "repl" : " "} + - {"pattern": "′", "repl" : " "} + - {"pattern": "\\\\", "repl" : " "} + - {"pattern": "^\\s?\\.\\.\\.", "repl" : ""} + - {"pattern": "\\s?\\.\\.\\.$", "repl" : "."} + + ## Remove extra space + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest14.json + text_key: pre_normalized + regex_params_list: + - {"pattern": "\\s+", "repl" : " "} + - {"pattern": "^\\s+", "repl" : ""} + - {"pattern": "\\s+$", "repl" : ""} + + ## Filter out samples out of Regex + - _target_: sdp.processors.DropIfNoneOfRegexMatch + output_manifest_file: ${workspace_dir}/manifest15.json + text_key: pre_normalized + regex_patterns: + - "^[ !#$%&'*+,\\-.0-9:=?ABCDEFGHIJKLMNOPQRSTUVWXYZ^_abcdefghijklmnopqrstuvwxyz{}~£¥°²³µÄÖÜßäöüμω₩€/]+$" + # Run audio based TN + - _target_: sdp.processors.datasets.commoncrawl.Subprocess + output_manifest_file: 
${workspace_dir}/manifest16.json + input_manifest_arg: "--manifest" + output_manifest_arg: "--output_filename" + arg_separator: "=" + cmd: "python /ws/NeMo-text-processing/nemo_text_processing/text_normalization/normalize_with_audio.py \ + --language=${lang} --n_jobs=-1 --batch_size=600 --manifest_text_field=pre_normalized --manifest_asr_pred_field=pred_text_pc \ + --cache_dir=${workspace_dir}/cache \ + --whitelist=/ws/NeMo-text-processing/nemo_text_processing/text_normalization/${lang}/data/whitelist.tsv" + # Post-normalization processing + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest17.json + duplicate_fields: {"normalized" : "post_normalized"} + ## Extra chars removing from normalized text + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest18.json + text_key: post_normalized + regex_params_list: + - {"pattern": "['\\-:{}\\/]", "repl" : " "} + - {"pattern": "!", "repl" : "."} + - {"pattern": "\\s+", "repl" : " "} + - {"pattern": "^\\s+", "repl" : ""} + - {"pattern": "\\s+$", "repl" : ""} + + ## Remove samples with chars out of list (letters, comma, period, question mark, space) + - _target_: sdp.processors.DropIfNoneOfRegexMatch + output_manifest_file: ${workspace_dir}/manifest19.json + text_key: post_normalized + regex_patterns: + - "^[a-zA-ZäÄöÖüÜß,\\.?\\s]+$" + + # Create text field with lowercased clean "post_normalized" + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest20.json + duplicate_fields: {"post_normalized" : "text"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest21.json + text_key: "text" + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest22.json + text_key: "text" + regex_params_list: + - {"pattern": "[\\.\\?\\,]", "repl" : " "} + - {"pattern": "\\s+", "repl" : " "} + - {"pattern": "^\\s+", "repl" : ""} + - {"pattern": "\\s+$", "repl" : ""} + + # Create pred_text field with lowercased clean "pred_text_pc" + - _target_: sdp.processors.DuplicateFields + output_manifest_file: ${workspace_dir}/manifest23.json + duplicate_fields: {"pred_text_pc" : "pred_text"} + + - _target_: sdp.processors.SubMakeLowercase + output_manifest_file: ${workspace_dir}/manifest24.json + text_key: "pred_text" + + - _target_: sdp.processors.SubRegex + output_manifest_file: ${workspace_dir}/manifest25.json + text_key: "pred_text" + regex_params_list: + - {"pattern": "[\\.\\?\\,]", "repl" : " "} + - {"pattern": "\\s+", "repl" : " "} + - {"pattern": "^\\s+", "repl" : ""} + - {"pattern": "\\s+$", "repl" : ""} + + # Filtration + - _target_: sdp.processors.DropHighCER + output_manifest_file: ${workspace_dir}/manifest26.json + cer_threshold: ${max_cer} + text_key: "text" + pred_text_key: "pred_text" + + - _target_: sdp.processors.DropHighWER + output_manifest_file: ${workspace_dir}/manifest27.json + wer_threshold: ${max_wer} + text_key: "text" + pred_text_key: "pred_text" + + # Finalization + - _target_: sdp.processors.KeepOnlySpecifiedFields + output_manifest_file: ${workspace_dir}/manifest28.json + fields_to_keep: ["audio_filepath", "duration", "post_normalized"] + + - _target_: sdp.processors.RenameFields + output_manifest_file: ${workspace_dir}/manifest29.json + rename_fields: {"post_normalized":"text"} + + - _target_: sdp.processors.datasets.commoncrawl.CopyFiles + file_field: audio_filepath + path_to_copy: ${workspace_dir}/clean_data/audio/ + path_levels: 1 + - _target_: 
sdp.processors.datasets.commoncrawl.DropAbsPath
+    output_manifest_file: ${workspace_dir}/clean_data/${lang}_manifest.json
+    path_key: audio_filepath
+    abs_path_to_drop: ${workspace_dir}

From 595cf660b57556f81ba47ec56e0ab1cc9c1151f5 Mon Sep 17 00:00:00 2001
From: Sasha Meister
Date: Mon, 18 Mar 2024 16:57:46 +0000
Subject: [PATCH 4/6] Fix Lang2Iso

Signed-off-by: Sasha Meister
---
 dataset_configs/youtube/de.yaml                    | 2 +-
 sdp/processors/datasets/commoncrawl/commoncrawl.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/dataset_configs/youtube/de.yaml b/dataset_configs/youtube/de.yaml
index f3a09037..45330933 100644
--- a/dataset_configs/youtube/de.yaml
+++ b/dataset_configs/youtube/de.yaml
@@ -1,4 +1,4 @@
-processors_to_run: "25:"
+processors_to_run: "1:5"
 base_dir: "/ws/test_subset/"
 workspace_dir: "/ws/test_subset_out/"
 
diff --git a/sdp/processors/datasets/commoncrawl/commoncrawl.py b/sdp/processors/datasets/commoncrawl/commoncrawl.py
index 045949fa..f879a81d 100644
--- a/sdp/processors/datasets/commoncrawl/commoncrawl.py
+++ b/sdp/processors/datasets/commoncrawl/commoncrawl.py
@@ -1107,7 +1107,7 @@ def __init__(
         }
 
     def process_dataset_entry(self, data_entry):
-        data_entry[self.output_lang_key] = self.iso_m[data_entry[self.input_lang_key]]
+        data_entry[self.output_lang_key] = self.iso_m.get(data_entry[self.input_lang_key], None)
         return [DataEntry(data=data_entry)]
 

From 9d0fa8b127d456bec485f2a3486109c98333b237 Mon Sep 17 00:00:00 2001
From: Sasha Meister
Date: Mon, 18 Mar 2024 17:38:32 +0000
Subject: [PATCH 5/6] fix typo

---
 dataset_configs/youtube/de.yaml                       | 2 +-
 sdp/processors/datasets/youtube/aggregate_segments.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/dataset_configs/youtube/de.yaml b/dataset_configs/youtube/de.yaml
index 45330933..4f0ca9d9 100644
--- a/dataset_configs/youtube/de.yaml
+++ b/dataset_configs/youtube/de.yaml
@@ -1,4 +1,4 @@
-processors_to_run: "1:5"
+processors_to_run: "0:"
 base_dir: "/ws/test_subset/"
 workspace_dir: "/ws/test_subset_out/"
 
diff --git a/sdp/processors/datasets/youtube/aggregate_segments.py b/sdp/processors/datasets/youtube/aggregate_segments.py
index f5aaef07..b60a3c09 100644
--- a/sdp/processors/datasets/youtube/aggregate_segments.py
+++ b/sdp/processors/datasets/youtube/aggregate_segments.py
@@ -35,14 +35,14 @@ def prepare(self):
 
     def process_dataset_entry(self, data_entry: dict):
         sample_id = data_entry['sample_id']
-        segmnets = data_entry['segments']
+        segments = data_entry['segments']
         agg_segments = []
 
-        first_segment = RawSegment(**segmnets[0])
+        first_segment = RawSegment(**segments[0])
         agg_segment = AggregatedSegment(segment=first_segment, segment_id=1, sample_id=sample_id,
                                         output_audio_dir = self.output_segments_audio_dir)
 
-        for segment in segmnets[1 : ]:
+        for segment in segments[1 : ]:

From 6ba9856881431e20f7c2952b01925f29a1ba2646 Mon Sep 17 00:00:00 2001
From: Sasha Meister
Date: Mon, 18 Mar 2024 17:40:29 +0000
Subject: [PATCH 6/6] fix empty list error - IndexError: list index out of
 range

---
 sdp/processors/datasets/youtube/aggregate_segments.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/sdp/processors/datasets/youtube/aggregate_segments.py b/sdp/processors/datasets/youtube/aggregate_segments.py
index b60a3c09..d97524c4 100644
--- a/sdp/processors/datasets/youtube/aggregate_segments.py
+++ b/sdp/processors/datasets/youtube/aggregate_segments.py
@@ -38,6 +38,9 @@ def process_dataset_entry(self, data_entry: dict):
         segments = data_entry['segments']
         agg_segments = []
 
+        if len(segments) == 0:
+            return agg_segments
+
         first_segment = RawSegment(**segments[0])
         agg_segment = AggregatedSegment(segment=first_segment, segment_id=1, sample_id=sample_id,
                                         output_audio_dir = self.output_segments_audio_dir)
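
The contract of this final fix is small enough to pin down in a standalone check. A hypothetical pytest-style sketch (the aggregation body is stubbed out, since instantiating AggregateSegments itself requires a fully configured BaseParallelProcessor):

    # Regression sketch for the PATCH 6/6 guard: an entry whose 'segments'
    # list is empty must yield no output segments instead of raising
    # IndexError on segments[0].
    def aggregate_entry(data_entry):
        segments = data_entry['segments']
        agg_segments = []
        if len(segments) == 0:  # the guard added in this patch
            return agg_segments
        ...  # aggregation proceeds as in AggregateSegments.process_dataset_entry
        return agg_segments

    def test_empty_segments_yield_no_output():
        assert aggregate_entry({'sample_id': 'x', 'segments': []}) == []

Entries whose .srt yields no segments do reach AggregateSegments with an empty 'segments' list (Sample.to_dataentry initializes it to []), which is exactly the case this guard covers.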