diff --git a/dataset_configs/portuguese/unlabeled/config.yaml b/dataset_configs/portuguese/unlabeled/config.yaml index b10e64ac..a2aa0740 100644 --- a/dataset_configs/portuguese/unlabeled/config.yaml +++ b/dataset_configs/portuguese/unlabeled/config.yaml @@ -75,21 +75,21 @@ processors: output_manifest_file: ${manifest_dir}/vad input_manifest_arg: "manifest_filepath" output_manifest_arg: "output_dir" - cmd: 'python sdp/processors/nemo/speech_to_text_with_vad.py audio_type=wav vad_model=vad_multilingual_frame_marblenet vad_config=sdp/processors/nemo/frame_vad_infer_postprocess.yaml' + cmd: 'python sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py audio_type=wav vad_model=vad_multilingual_frame_marblenet vad_config=sdp/processors/inference/asr/nemo/utils/frame_vad_infer_postprocess.yaml' - _target_: sdp.processors.RenameFields input_manifest_file: ${manifest_dir}/vad/temp_manifest_vad_rttm-onset0.3-offset0.3-pad_onset0.2-pad_offset0.2-min_duration_on0.2-min_duration_off0.2-filter_speech_firstTrue.json output_manifest_file: ${manifest_dir}/manifest7.json rename_fields: {"audio_filepath":"source_filepath"} - - _target_: sdp.processors.nemo.rttm.GetRttmSegments + - _target_: sdp.processors.GetRttmSegments output_manifest_file: ${manifest_dir}/manifest8.json rttm_key: rttm_file output_file_key: audio_segments duration_key: duration duration_threshold: 20.0 - - _target_: sdp.processors.nemo.rttm.SplitAudioFile + - _target_: sdp.processors.SplitAudioFile output_manifest_file: ${manifest_dir}/manifest9.json splited_audio_dir: ${workspace_dir}/splited_wavs/ segments_key: audio_segments diff --git a/docs/src/sdp/api.rst b/docs/src/sdp/api.rst index 7f8eef4b..4c8f5aff 100644 --- a/docs/src/sdp/api.rst +++ b/docs/src/sdp/api.rst @@ -184,9 +184,6 @@ used in the downstream processing for additional enhancement or filtering. .. autodata:: sdp.processors.ASRTransformers :annotation: -.. autodata:: sdp.processors.EstimateBandwidth - :annotation: - .. autodata:: sdp.processors.tts.pyannote.PyAnnoteDiarizationAndOverlapDetection :annotation: @@ -202,6 +199,15 @@ used in the downstream processing for additional enhancement or filtering. .. autodata:: sdp.processors.tts.metrics.BandwidthEstimationProcessor :annotation: +.. autodata:: sdp.processors.FasterWhisperInference + :annotation: + +.. autodata:: sdp.processors.vLLMInference + :annotation: + +.. autodata:: sdp.processors.AudioLid + :annotation: + Text-only processors #################### @@ -246,6 +252,9 @@ Data modifications .. autodata:: sdp.processors.ListToEntries :annotation: +.. autodata:: sdp.processors.EstimateBandwidth + :annotation: + Data filtering '''''''''''''' @@ -364,6 +373,18 @@ Data filtering .. autodata:: sdp.processors.RejectIfBanned :annotation: +.. autodata:: sdp.processors.DetectWhisperHallucinationFeatures + :annotation: + +.. autodata:: sdp.processors.CleanQwenGeneration + :annotation: + +.. autodata:: sdp.processors.GetRttmSegments + :annotation: + +.. 
autodata:: sdp.processors.SplitAudioFile + :annotation: + Miscellaneous ############# diff --git a/requirements/main.txt b/requirements/main.txt index 99c030b4..9b8858d4 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -25,3 +25,7 @@ datasets>=2.14.0,<3.0.0 # for some processers, additionally https://github.com/NVIDIA/NeMo is required # for some processers, additionally nemo_text_processing is required # for mcv: apt-get update && apt-get upgrade -y && apt-get install -y sox libsox-fmt-all +# for FasterWhisperInference processor is required: + # pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper + # export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` +# for vLLMInference processor is required: pip install "optree>=0.13.0" vllm \ No newline at end of file diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index ab0d05ed..b25ce04d 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -81,7 +81,6 @@ from sdp.processors.huggingface.create_initial_manifest import ( CreateInitialManifestHuggingFace, ) -from sdp.processors.huggingface.speech_recognition import ASRTransformers from sdp.processors.modify_manifest.common import ( AddConstantFields, ApplyInnerJoin, @@ -119,6 +118,7 @@ SubRegex, ListToEntries, LambdaExpression, + EstimateBandwidth, ) from sdp.processors.modify_manifest.data_to_dropbool import ( DropASRError, @@ -141,6 +141,16 @@ from sdp.processors.modify_manifest.make_letters_uppercase_after_period import ( MakeLettersUppercaseAfterPeriod, ) +from sdp.processors.inference.asr.nemo.asr_inference import ASRInference +from sdp.processors.inference.asr.nemo.lid_inference import AudioLid +from sdp.processors.inference.asr.faster_whisper.faster_whisper_inference import FasterWhisperInference +from sdp.processors.inference.asr.transformers.speech_recognition import ASRTransformers +from sdp.processors.inference.asr.utils.whisper_hallucinations import DetectWhisperHallucinationFeatures +from sdp.processors.inference.asr.utils.rttm import GetRttmSegments, SplitAudioFile +from sdp.processors.inference.nlp.nemo.pc_inference import PCInference +from sdp.processors.inference.llm.vllm.vllm import vLLMInference +from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration + from sdp.processors.manage_files.convert_audio import ( FfmpegConvert, SoxConvert, @@ -151,10 +161,7 @@ from sdp.processors.manage_files.remove import ( RemoveFiles, ) -from sdp.processors.nemo.asr_inference import ASRInference -from sdp.processors.nemo.estimate_bandwidth import EstimateBandwidth -from sdp.processors.nemo.lid_inference import AudioLid -from sdp.processors.nemo.pc_inference import PCInference + from sdp.processors.toloka.accept_if import AcceptIfWERLess from sdp.processors.toloka.create_pool import CreateTolokaPool from sdp.processors.toloka.create_project import CreateTolokaProject diff --git a/sdp/processors/inference/asr/faster_whisper/faster_whisper_inference.py b/sdp/processors/inference/asr/faster_whisper/faster_whisper_inference.py new file mode 100644 index 00000000..b8e419d9 --- /dev/null +++ b/sdp/processors/inference/asr/faster_whisper/faster_whisper_inference.py @@ -0,0 +1,495 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import json +from copy import deepcopy +from tqdm import tqdm +import librosa +from dataclasses import dataclass, field, asdict, is_dataclass +from typing import List, Optional, Any, Dict +from omegaconf import OmegaConf, MISSING + +from sdp.logging import logger +from multiprocessing import Pool +import traceback + +from sdp.processors.base_processor import BaseProcessor + +""" +This module implements `FasterWhisperInference`, a multiprocessing-compatible audio transcription +processor using the FasterWhisper library. + +It reads an input manifest, runs inference on available devices (GPU/CPU), and outputs predictions, +including optional timestamp and word-level information. + +Classes: + - InferenceConfig: Configuration for whisper decoding and inference behavior. + - ModelConfig: Configuration for the Whisper model loading. + - DatasetConfig: Configuration for dataset input/output handling. + - WhisperInferenceConfig: Combined config container. + - FasterWhisperInference: Main processor class for transcribing input audio files in parallel. +""" + +def serialize(obj): + """ + Recursively serializes a dataclass, list, or dict to a JSON-compatible structure. + + Args: + obj (Any): Object to serialize (dataclass, list, or dict). + + Returns: + JSON-serializable version of the object. + """ + if is_dataclass(obj): + return asdict(obj) + elif isinstance(obj, list): + return [serialize(item) for item in obj] + elif isinstance(obj, dict): + return {k: serialize(v) for k, v in obj.items()} + return obj + +@dataclass +class InferenceConfig: + """ + Configuration for FasterWhisper inference. 
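+    During transcription, these fields are passed as keyword arguments to WhisperModel.transcribe().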
+ To know more about the parameters, refer to the documentation: + https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/transcribe.py#L303 + """ + language: Optional[str] = None + task: str = "transcribe" + log_progress: bool = False + beam_size: int = 5 + best_of: int = 5 + patience: float = 1 + length_penalty: float = 1 + repetition_penalty: float = 1 + no_repeat_ngram_size: int = 0 + temperature: List[float] = field(default_factory=lambda: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) + compression_ratio_threshold: Optional[float] = 2.4 + log_prob_threshold: Optional[float] = -1.0 + no_speech_threshold: Optional[float] = 0.6 + condition_on_previous_text: bool = True + prompt_reset_on_temperature: float = 0.5 + initial_prompt: Optional[Any] = None + prefix: Optional[str] = None + suppress_blank: bool = True + suppress_tokens: Optional[List[int]] = field(default_factory=lambda: [-1]) + without_timestamps: bool = True + max_initial_timestamp: float = 1.0 + word_timestamps: bool = False + prepend_punctuations: str = "\"'“¿([{-" + append_punctuations: str = "\"'.。,,!!??::”)]}、" + multilingual: bool = False + vad_filter: bool = True + + try: + from faster_whisper.vad import VadOptions + vad_parameters: Optional[VadOptions] = None + except ModuleNotFoundError: + pass + + max_new_tokens: Optional[int] = None + chunk_length: Optional[int] = None + clip_timestamps: Optional[Any] = "0" + hallucination_silence_threshold: Optional[float] = None + hotwords: Optional[str] = None + language_detection_threshold: Optional[float] = 0.5 + language_detection_segments: int = 1 + +@dataclass +class ModelConfig: + model_size_or_path: str = MISSING + device: str = "auto" + device_index: Optional[List[int]] = field(default_factory=lambda: [0]) + compute_type: str = "default" + cpu_threads: int = 0 + num_workers: int = 1 + download_root: Optional[str] = None + local_files_only: bool = False + files: Optional[dict] = None + + +@dataclass +class DatasetConfig: + manifest_filepath: str = MISSING + output_dir: str = MISSING + skip_corrupted: bool = False + save_timestamps_separately: bool = True + offset: bool = False + + +@dataclass +class WhisperInferenceConfig: + model: ModelConfig = field(default_factory=lambda: ModelConfig()) + dataset: DatasetConfig = field(default_factory=lambda: DatasetConfig()) + inference: InferenceConfig = field(default_factory=lambda: InferenceConfig()) + + +class FasterWhisperInference(BaseProcessor): + """ + Processor that performs parallel audio transcription using the FasterWhisper model. + + This class reads a manifest of audio files, transcribes them using multiprocessing + (each device or CPU thread handles a portion), and writes results in a NeMo-compatible manifest. + + Args: + input_manifest_file (str): Path to the input manifest. + output_manifest_file (Optional[str]): Path to the output manifest (default: `/predictions_all.json`). + model_size_or_path (str): Whisper model path or model name (e.g., 'base', 'medium'). + device (str): Device type to use ('auto', 'cuda', or 'cpu'). + num_devices (int): Number of workers/devices to use (-1 = all available). + model_download_root (Optional[str]): Directory where model checkpoints will be downloaded. + output_dir (Optional[str]): Directory to store output predictions and timestamps. + skip_corrupted_audios (bool): Whether to skip audio files that raise exceptions. + save_timestamps_separately (bool): If True, saves segment/word timestamps as separate files. 
+ slice_by_offset (bool): If True, slices audio using offset/duration before inference. + inference (Optional[Dict]): Additional inference parameters for Whisper. + language_detection_only (bool): If True, only perform language detection. + in_memory_chunksize (int): Number of samples to load per worker at once. + audio_filepath_field (str): Name of the field in manifest pointing to audio path. + + Returns: + A final merged manifest file where each line corresponds to the transcription result of an input audio sample. + The manifest is assembled from multiple per-worker (rank) manifest files, each produced by a separate device or process. + + Each entry contains the following fields: + + - ``language`` (str, optional): Detected language (if language detection is enabled). + - ``language_probability`` (float, optional): Confidence score of detected language. + - ``pred_text`` (str): Final transcribed text obtained by concatenating all segment texts. + + One of the following timestamp representations will also be included, depending on the value of `save_timestamps_separately`: + + - If ``save_timestamps_separately=False``: + - ``segments`` (List[Dict]): List of segment dictionaries with start/end timestamps and transcribed text. + + - If ``save_timestamps_separately=True``: + - ``segments`` (str): Path to a JSON file containing segment-level timestamps. + - ``words`` (str, optional): Path to a JSON file containing word-level timestamps (if `word_timestamps=True`). + + The final combined manifest is written to ``output_manifest_file``, which defaults to ``/predictions_all.json``. + + .. note:: + Make sure to install the following packages before using this processor: + + pip install pytorch-lightning nvidia-cublas-cu12 nvidia-cudnn-cu12==9.* faster_whisper + + Additionally, ensure that the dynamic libraries for cuBLAS and cuDNN are discoverable at runtime: + + export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'` + + This is required for CUDA backend components to function correctly when using FasterWhisper with GPU acceleration. + + For detailed configuration options and advanced usage of FasterWhisper, refer to the official repository: + https://github.com/SYSTRAN/faster-whisper + + Example: + .. 
code-block:: yaml + + - _target_: sdp.processors.FasterWhisperInference + input_manifest_file: /your/input/manifest.json + output_manifest_file: /your/output/manifest.json + model_size_or_path: base + """ + def __init__(self, + input_manifest_file: str, + output_manifest_file: Optional[str] = None, + model_size_or_path: str = "base", + device: str = "auto", + num_devices: int = -1, + compute_type: str = "default", + model_download_root: Optional[str] = None, + output_dir: Optional[str] = None, + skip_corrupted_audios: bool = False, + save_timestamps_separately: bool = True, + slice_by_offset: bool = False, + inference: Optional[Dict] = {}, + language_detection_only: bool = False, + in_memory_chunksize: int = 100000, + audio_filepath_field: str = 'audio_filepath', + ): + + super().__init__(input_manifest_file = input_manifest_file, + output_manifest_file = output_manifest_file, + ) + + #DatasetConfig setup + if not self.output_manifest_file and not output_dir: + raise ValueError("Either `output_manifest_file` or `output_dir` must be provided.") + + if not output_dir: + output_dir = os.path.splitext(self.output_manifest_file)[0] + + if not self.output_manifest_file: + self.output_manifest_file = os.path.join(output_dir, 'predictions_all.json') + + dataset_cfg = DatasetConfig(manifest_filepath = self.input_manifest_file, + output_dir = output_dir, + skip_corrupted = skip_corrupted_audios, + save_timestamps_separately = save_timestamps_separately, + offset = slice_by_offset) + + #InferenceConfig setup + inference_cfg = OmegaConf.structured(InferenceConfig(**inference)) + + #ModelConfig setup + device, device_ids = self.setup_devices(device, num_devices) + self.device_ids = device_ids + model_cfg = ModelConfig(model_size_or_path = model_size_or_path, + device = device, compute_type = compute_type, + download_root = model_download_root) + + #GeneralConfig setup + self.config = WhisperInferenceConfig(model=model_cfg, + dataset=dataset_cfg, + inference=inference_cfg, + ) + + #Additional args + self.audio_filepath_field = audio_filepath_field + self.language_detection_only = language_detection_only + self.in_memory_chunksize = in_memory_chunksize + + @staticmethod + def setup_devices(device: str = "auto", num_devices: int = -1): + """ + Determines device type and number of workers to use for inference. + + Returns: + Tuple[str, List[int]]: Selected device type and list of device indices. + """ + try: + import torch + TORCH_AVAILABLE = True + except ImportError: + TORCH_AVAILABLE = False + + if device in ["cuda", "auto"] and TORCH_AVAILABLE: + cuda_available_workers = torch.cuda.device_count() + if cuda_available_workers == 0: + if device == "cuda": + raise RuntimeError("GPU was requested, but no CUDA devices are available.") + else: + logger.warning("No GPU found in auto mode — switching to CPU.") + device = "cpu" + else: + logger.info("CUDA devices found. 
GPU will be used as workers.") + device = "cuda" + elif device == "cpu": + logger.info("CPU will be used as workers.") + else: + raise ValueError(f"Invalid device type: {device}") + + if device == "cuda": + max_available_workers = cuda_available_workers + else: + max_available_workers = os.cpu_count() + + if num_devices < -1 or num_devices == 0: + raise ValueError(f"Invalid number of workers: {num_devices}.") + elif num_devices == -1: + workers = max_available_workers + logger.info(f"Using {workers} {device.upper()} worker(s).") + elif num_devices > max_available_workers: + workers = max_available_workers + logger.warning(f"Requested {num_devices} {device.upper()} workers, but only {max_available_workers} {device.upper()} available — using {workers}.") + else: + workers = num_devices + logger.info(f"Using {workers} {device.upper()} worker(s).") + + device_ids = list(range(workers)) + return device, device_ids + + def prepare(self): + """ + Creates output directories required for storing prediction and timestamp files. + """ + os.makedirs(self.config.dataset.output_dir, exist_ok = True) + if self.config.dataset.save_timestamps_separately: + os.makedirs(os.path.join(self.config.dataset.output_dir, "segments"), exist_ok = True) + if self.config.inference.word_timestamps: + os.makedirs(os.path.join(self.config.dataset.output_dir, "words"), exist_ok = True) + + def _chunk_manifest(self): + """Splits the manifest into smaller chunks defined by ``in_memory_chunksize``.""" + manifest_chunk = [] + for idx, data_entry in enumerate(self.read_manifest(), 1): + manifest_chunk.append(data_entry) + if idx % self.in_memory_chunksize == 0: + yield manifest_chunk + manifest_chunk = [] + if manifest_chunk: + yield manifest_chunk + + def read_manifest(self): + """Reading the input manifest file.""" + if not self.input_manifest_file: + raise NotImplementedError("Override this method if no input manifest file is used") + with open(self.input_manifest_file, "rt", encoding="utf8") as fin: + for line in fin: + yield json.loads(line) + + def _get_entries_for_device(self, device_id: int): + """ + Yields manifest entries assigned to a given device. + + Uses round-robin assignment of sorted entries by duration. 
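+        For example, with two devices, device 0 receives entries 0, 2, 4, ... of each duration-sorted chunk,
+        and device 1 receives entries 1, 3, 5, ...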
+ """ + for chunk in self._chunk_manifest(): + chunk.sort(key=lambda entry: entry["duration"]) + batch = chunk[device_id::len(self.device_ids)] + for entry in batch: + yield entry + + def _get_audio_segment(self, audio_filepath: str, offset: float, duration: float): + """Loads a segment of audio based on offset and duration.""" + audio, sr = librosa.load(audio_filepath, sr=None) + start_sample = int(offset * sr) + end_sample = int((offset + duration) * sr) + audio_segment = audio[start_sample : end_sample] + return audio_segment + + def _write_timestamps(self, filename: str, segments: List[Dict]): + """Saves timestamp information (segments and optionally word-level) to separate files.""" + + output_segments_filepath = os.path.join(self.config.dataset.output_dir, 'segments', f'{filename}.json') + sample_words = [] + with open(output_segments_filepath, 'w', encoding = 'utf8') as output_manifest: + for segment in segments: + words = segment.pop('words') + if self.config.inference.word_timestamps: + for word in words: + word['segment_id'] = segment['id'] + sample_words.append(word) + + line = json.dumps(segment) + output_manifest.writelines(f'{line}\n') + + def _write_words(words: List[Dict]): + output_manifest_filepath = os.path.join(self.config.dataset.output_dir, 'words', f'{filename}.json') + with open(output_manifest_filepath, 'w', encoding = 'utf8') as output_manifest: + for word in words: + line = json.dumps(word) + output_manifest.writelines(f'{line}\n') + return output_manifest_filepath + + output_words_filepath = None + if self.config.inference.word_timestamps: + output_words_filepath = _write_words(output_words_filepath, sample_words) + + return dict(segments = output_segments_filepath, words = output_words_filepath) + + def _transcribe(self, device_id: int): + """"" + Transcribes all samples assigned to a given device. + + Loads the Whisper model, reads samples, performs language detection or full transcription, + and writes predictions to a device-specific output file. + """ + + from faster_whisper import WhisperModel + from faster_whisper.audio import decode_audio + + model_cfg = deepcopy(self.config.model) + model_cfg.device_index = [device_id] if model_cfg.device == "cuda" else [0] + model = WhisperModel(**asdict(model_cfg)) + + inference_cfg = OmegaConf.to_container(self.config.inference, resolve=True) + + output_manifest_file = os.path.join(self.config.dataset.output_dir, f'prediction_{device_id}.json') + + with open(output_manifest_file, 'w', encoding='utf8') as fout: + for entry in tqdm(self._get_entries_for_device(device_id), desc = f"Transcribing ({self.config.model.device.upper()} {device_id})"): + audio_filepath = entry[self.audio_filepath_field] + + if self.language_detection_only: + try: + audio = decode_audio(audio_filepath) + features = model.feature_extractor(audio) + language, language_probability, all_language_probs = model.detect_language(features = features, + vad_filter = self.config.inference.vad_filter, + vad_parameters = self.config.inference.vad_parameters, + language_detection_segments = self.config.inference.language_detection_segments, + language_detection_threshold = self.config.inference.language_detection_threshold) + except Exception: + if self.config.dataset.skip_corrupted: + logger.warning(f"Sample can't be processed: {audio_filepath}. 
Skipping.") + continue + else: + traceback.print_exc() + exit(1) + + result = dict(language = language, language_probability = language_probability) + else: + try: + if self.config.dataset.offset: + audio = self._get_audio_segment(audio_filepath, entry['offset'], entry['duration']) + else: + audio = audio_filepath + + segments, info = model.transcribe(audio = audio, **inference_cfg) + + except Exception: + if self.config.dataset.skip_corrupted: + logger.warning(f"Sample can't be transcribed: {audio_filepath}. Skipping.") + continue + else: + traceback.print_exc() + exit(1) + + result = serialize(info) + result.pop('all_language_probs', None) + result.pop('transcription_options', None) + result.pop('vad_options', None) + + _segments = [] + for segment in segments: + _segments.append(serialize(segment)) + segments = _segments + + if self.config.dataset.save_timestamps_separately: + audio_filename = os.path.splitext(os.path.basename(audio_filepath))[0] + timestamps_filepaths = self._write_timestamps(audio_filename, segments) + result.update(timestamps_filepaths) + else: + result['segments'] = segments + + pred_text = ' '.join(str(segment['text']) for segment in segments).strip() + result['pred_text'] = pred_text + + entry.update(result) + fout.write(json.dumps(entry, ensure_ascii=False) + "\n") + fout.flush() + + return output_manifest_file + + def process(self): + """ + Main entry point for the processor. + + Prepares directories, distributes transcription tasks across devices, and aggregates results + into the final output manifest file. + """ + self.prepare() + + with Pool(processes=len(self.device_ids)) as pool: + output_rank_manifests = pool.map(self._transcribe, self.device_ids) + + with open(self.output_manifest_file, 'w', encoding='utf8') as output_manifest: + for rank_manifest_filepath in tqdm(output_rank_manifests, desc = "Writing output manifest.."): + with open(rank_manifest_filepath, 'r', encoding='utf8') as rank_manifest: + for line in rank_manifest: + output_manifest.writelines(line) \ No newline at end of file diff --git a/sdp/processors/nemo/asr_inference.py b/sdp/processors/inference/asr/nemo/asr_inference.py similarity index 97% rename from sdp/processors/nemo/asr_inference.py rename to sdp/processors/inference/asr/nemo/asr_inference.py index 4359f320..a826fc54 100644 --- a/sdp/processors/nemo/asr_inference.py +++ b/sdp/processors/inference/asr/nemo/asr_inference.py @@ -49,7 +49,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.script_path = Path(__file__).parents[1] / "nemo" / "transcribe_speech.py" + self.script_path = Path(__file__).parent / "utils" / "transcribe_speech.py" self.pretrained_model = pretrained_model self.batch_size = batch_size diff --git a/sdp/processors/nemo/lid_inference.py b/sdp/processors/inference/asr/nemo/lid_inference.py similarity index 100% rename from sdp/processors/nemo/lid_inference.py rename to sdp/processors/inference/asr/nemo/lid_inference.py diff --git a/sdp/processors/nemo/frame_vad_infer_postprocess.yaml b/sdp/processors/inference/asr/nemo/utils/frame_vad_infer_postprocess.yaml similarity index 100% rename from sdp/processors/nemo/frame_vad_infer_postprocess.yaml rename to sdp/processors/inference/asr/nemo/utils/frame_vad_infer_postprocess.yaml diff --git a/sdp/processors/nemo/speech_to_text_with_vad.py b/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py similarity index 99% rename from sdp/processors/nemo/speech_to_text_with_vad.py rename to 
sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py index 6fdd183d..05c28cae 100644 --- a/sdp/processors/nemo/speech_to_text_with_vad.py +++ b/sdp/processors/inference/asr/nemo/utils/speech_to_text_with_vad.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/sdp/processors/nemo/transcribe_speech.py b/sdp/processors/inference/asr/nemo/utils/transcribe_speech.py similarity index 100% rename from sdp/processors/nemo/transcribe_speech.py rename to sdp/processors/inference/asr/nemo/utils/transcribe_speech.py diff --git a/sdp/processors/huggingface/speech_recognition.py b/sdp/processors/inference/asr/transformers/speech_recognition.py similarity index 100% rename from sdp/processors/huggingface/speech_recognition.py rename to sdp/processors/inference/asr/transformers/speech_recognition.py diff --git a/sdp/processors/nemo/rttm.py b/sdp/processors/inference/asr/utils/rttm.py similarity index 98% rename from sdp/processors/nemo/rttm.py rename to sdp/processors/inference/asr/utils/rttm.py index 394014ca..892f1803 100644 --- a/sdp/processors/nemo/rttm.py +++ b/sdp/processors/inference/asr/utils/rttm.py @@ -1,12 +1,10 @@ import os -from typing import Dict, List, Union +from typing import Dict import soundfile as sf -from tqdm import tqdm from sdp.logging import logger from sdp.processors.base_processor import BaseParallelProcessor, DataEntry -from sdp.utils.common import load_manifest class GetRttmSegments(BaseParallelProcessor): diff --git a/sdp/processors/inference/asr/utils/whisper_hallucinations.py b/sdp/processors/inference/asr/utils/whisper_hallucinations.py new file mode 100644 index 00000000..c5ddf35c --- /dev/null +++ b/sdp/processors/inference/asr/utils/whisper_hallucinations.py @@ -0,0 +1,130 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sdp.processors.base_processor import BaseParallelProcessor, DataEntry + + +class DetectWhisperHallucinationFeatures(BaseParallelProcessor): + """ + Computes hallucination-related features for ASR model outputs (e.g., Whisper transcripts). + + This processor analyzes the transcript text and flags common hallucination patterns by computing + boolean features such as: + - Repeated or low-diversity n-grams (`hall_repeated_ngrams`) + + **Example:** + + .. code-block:: text + + yes yes yes yes yes yes yes yes yes yes yes yes + + - Unusually long or disproportionately long words (`hall_long_word`) + + **Example:** + + .. code-block:: text + + short mid reallyreallyreallyreallyreallyreallyreallylong + + - Matches with known hallucinated phrases (`hall_frequent_single_word`) + + **Example:** + + .. 
code-block:: text + + lorem ipsum dolor sit amet + + It appends these features to each entry in the manifest for downstream filtering or analysis. + + Args: + common_hall_file (str): Path to a file with known hallucinated phrases, one per line. + unique_words_threshold (float): Maximum allowed share of unique words before marking as repeated. Default is 0.4. + long_word_threshold (int): Minimum character length for a word to be considered "long". Default is 25. + long_word_rel_threshold (float): Relative length ratio between the longest and second-longest word. Default is 3. + char_rate_threshold (float): [Unused in current logic, retained for compatibility]. Default is 4. + text_field (str): Key in the data entry that contains the transcript. Default is 'text'. + **kwargs: Additional keyword arguments passed to `BaseParallelProcessor`. + + Returns: + A manifest with additional boolean fields for hallucination detection. + """ + + def __init__( + self, + common_hall_file, + unique_words_threshold=0.4, + long_word_threshold=25, + long_word_rel_threshold=3, + char_rate_threshold=4, + text_field='text', + **kwargs, + ): + super().__init__(**kwargs) + self.unique_words_threshold = unique_words_threshold + self.long_word_threshold = long_word_threshold + self.long_word_rel_threshold = long_word_rel_threshold + self.char_rate_threshold = char_rate_threshold # Currently unused + self.text_field = text_field + + # Load common hallucination phrases into memory + with open(common_hall_file, 'r') as f: + self.common_hall_phrases = [line.strip() for line in f] + + def repeated_ngrams(self, words): + """ + Flags entries with low lexical diversity (i.e., repeated n-grams). + + Returns True if the fraction of unique words is below the threshold. + """ + unique_words_share = len(set(words)) / len(words) + return unique_words_share <= self.unique_words_threshold + + def long_word(self, words): + """ + Detects unusually long words or sharp differences in word lengths. + + Returns True if the longest word is above the absolute threshold or much longer + than the second-longest word. + """ + word_lengths = sorted([len(word) for word in words]) + + if word_lengths[-1] >= self.long_word_threshold: + return True + + if len(words) > 1: + diff = (word_lengths[-1] - word_lengths[-2]) / word_lengths[-2] + return diff >= self.long_word_rel_threshold + + return False + + def frequent_single_word(self, text): + """ + Checks if the cleaned transcript matches any known hallucinated phrase. + """ + cleaned_text = text.strip().replace('.', '').replace('?', '').replace('!', '') + return cleaned_text in self.common_hall_phrases + + def process_dataset_entry(self, data_entry): + """ + Processes a single manifest entry and appends hallucination features. + """ + text = data_entry[self.text_field] + words = text.split() + + # Compute hallucination indicators + data_entry['hall_repeated_ngrams'] = self.repeated_ngrams(words) + data_entry['hall_long_word'] = self.long_word(words) + data_entry['hall_frequent_single_word'] = self.frequent_single_word(text) + + return [DataEntry(data=data_entry)] \ No newline at end of file diff --git a/sdp/processors/inference/llm/utils/qwen_cleaning.py b/sdp/processors/inference/llm/utils/qwen_cleaning.py new file mode 100644 index 00000000..e3b78459 --- /dev/null +++ b/sdp/processors/inference/llm/utils/qwen_cleaning.py @@ -0,0 +1,131 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import string + +from sdp.logging import logger +from sdp.processors.base_processor import BaseParallelProcessor, DataEntry +from sdp.utils.metrics_computation import get_cer + + +class CleanQwenGeneration(BaseParallelProcessor): + """ + A processor that filters and post-processes model generations, replacing them with + reference text if they are considered low quality based on character error rate (CER) + and uppercase letter proportion. + + This processor is typically used after running a generation model (e.g., Qwen) to clean + up outputs and ensure alignment with reference transcriptions. + + Args: + cer_threshold (float): Maximum allowable character error rate (CER) between the + normalized generation and reference text. If exceeded, the generation is + replaced by the reference. + upper_case_threshold (float): Threshold for the proportion of uppercase letters in + the generation. If the ratio exceeds this value, the generation is replaced. + generation_field (str): Key in the input manifest for the model-generated text. + text_field (str): Key in the input manifest for the reference (target) text. + **kwargs: Additional arguments passed to the `BaseParallelProcessor`. + + Returns: + A manifest where each entry contains the cleaned generation in the specified + generation field. If a replacement occurred, it is recorded in the metrics. + + Metrics: + - 1 if the generation was replaced with the reference text. + - 0 if the generation was left as-is. + """ + + def __init__( + self, + cer_threshold=10, + upper_case_threshold=0.6, + generation_field='generation', + text_field='text', + **kwargs, + ): + super().__init__(**kwargs) + self.cer_threshold = cer_threshold + self.upper_case_threshold = upper_case_threshold + self.generation_field = generation_field + self.text_field = text_field + + def clean(self, generation): + """Remove template prompts and special tokens from model generation.""" + if "<|endoftext|>" in generation: + generation = generation.split("<|endoftext|>")[0] + + if "Input transcript:" in generation: + generation = generation.replace("Input transcript:", "") + + if "Output:" in generation: + generation = generation.replace("Output:", "") + + if "Output transcript:" in generation: + generation = generation.replace("Output transcript:", "") + + if "\n" in generation: + generation = generation.replace("\n", "") + + return generation + + def maybe_replace_with_text(self, generation, text): + """ + Determine if generation should be replaced with reference text based on + CER and uppercase ratio. 
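+        Returns a (text, flag) pair: the reference text with flag 1 if a replacement is made,
+        otherwise the generation with flag 0.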
+ """ + chars = generation.replace(' ', '') + total_chars = len(chars) + + # Replace if generation is empty + if not total_chars: + return text, 1 + + # Replace if excessive capitalization + uppercase_count = sum(1 for char in chars if char.isupper()) + if uppercase_count / total_chars > self.upper_case_threshold: + return text, 1 + + # Normalize both strings for CER comparison + normalized_text = text.lower().translate(str.maketrans('', '', string.punctuation)).strip() + normalized_generation = generation.lower().translate(str.maketrans('', '', string.punctuation)).strip() + + if not normalized_text: + return text, 1 + + cer = get_cer(normalized_text, normalized_generation) + + if cer > self.cer_threshold: + return text, 1 + + return generation, 0 + + def process_dataset_entry(self, data_entry): + """Process a single entry from the manifest: clean and validate generation.""" + text = data_entry[self.text_field] + generation = data_entry[self.generation_field] + + generation = self.clean(generation) + maybe_replaced_generation, replaced = self.maybe_replace_with_text(generation, text) + + data_entry[self.generation_field] = maybe_replaced_generation.strip() + + return [DataEntry(data=data_entry, metrics=replaced)] + + def finalize(self, metrics): + """Log the total number of replaced generations.""" + logger.info( + f"Num of utterances that were replaced by text: {sum(metrics)}" + ) + super().finalize(metrics) \ No newline at end of file diff --git a/sdp/processors/inference/llm/vllm/vllm.py b/sdp/processors/inference/llm/vllm/vllm.py new file mode 100644 index 00000000..9ef9e89c --- /dev/null +++ b/sdp/processors/inference/llm/vllm/vllm.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import yaml +import json +from tqdm import tqdm + +from sdp.processors.base_processor import BaseProcessor + + +class vLLMInference(BaseProcessor): + """ + A processor that performs inference using a vLLM model on entries from an input manifest. + + This class supports three prompt configuration modes: + - a static prompt template (`prompt`) + - a field in each entry containing the prompt (`prompt_field`) + - a YAML file containing the prompt structure (`prompt_file`) + + The prompts are converted into chat-style input using a tokenizer chat template, + passed to the vLLM engine for generation, and the results are written to an output manifest. + + Args: + prompt (str, optional): A fixed prompt used for all entries. + prompt_field (str, optional): The key in each entry that holds the prompt template. + prompt_file (str, optional): Path to a YAML file containing the prompt structure. + generation_field (str): Name of the output field to store generated text. Default is 'generation'. + model (dict): Parameters to initialize the vLLM model. + inference (dict): Sampling parameters passed to vLLM.SamplingParams. + apply_chat_template (dict): Arguments passed to the tokenizer's `apply_chat_template` method. 
+ **kwargs: Passed to the BaseProcessor (includes `input_manifest_file` and `output_manifest_file`). + + Raises: + ValueError: If zero or more than one prompt configuration methods are used simultaneously. + + Returns: + A line-delimited JSON manifest where each entry includes the original fields + plus a field with the generated output. + + .. note:: + For detailed parameter options, refer to the following documentation: + + - model: https://docs.vllm.ai/en/latest/api/vllm/index.html#vllm.LLM + - inference: https://docs.vllm.ai/en/v0.6.4/dev/sampling_params.html + - apply_chat_template: https://huggingface.co/docs/transformers/main/en/chat_templating#applychattemplate + + Make sure to install `optree>=0.13.0` and `vllm` before using this processor: + pip install "optree>=0.13.0" vllm + + """ + + def __init__(self, + prompt: str = None, + prompt_field: str = None, + prompt_file: str = None, + generation_field: str = 'generation', + model: dict = {}, + inference: dict = {}, + apply_chat_template: dict = {}, + **kwargs): + + from vllm import SamplingParams + from transformers import AutoTokenizer + + super().__init__(**kwargs) + + self.prompt = prompt + self.prompt_field = prompt_field + self.generation_field = generation_field + + # Ensure that exactly one prompt method is used + prompt_args_counter = sum([prompt is not None, prompt_field is not None, prompt_file is not None]) + if prompt_args_counter < 1: + raise ValueError(f'One of `prompt`, `prompt_field` or `prompt_file` should be provided.') + elif prompt_args_counter > 1: + err = [] + if prompt: + err.append(f'`prompt` ({prompt})') + if prompt_field: + err.append(f'`prompt_field` ({prompt_field})') + if prompt_file: + err.append(f'`prompt_file` ({prompt_file})') + raise ValueError(f'Found more than one prompt values: {", ".join(err)}.') + + if prompt_file: + self.prompt = self._read_prompt_file(prompt_file) + + self.model_params = model + self.sampling_params = SamplingParams(**inference) + self.chat_template_params = apply_chat_template + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_params['model']) + + def _read_prompt_file(self, prompt_filepath): + """Read a YAML file with a chat-style prompt template.""" + with open(prompt_filepath, 'r') as prompt: + return yaml.safe_load(prompt) + + def get_entry_prompt(self, data_entry): + """Format the prompt for a single data entry using the chat template.""" + entry_chat = [] + prompt = self.prompt + if self.prompt_field: + prompt = data_entry[self.prompt_field] + + for role in prompt: + entry_chat.append(dict( + role=role, + content=prompt[role].format(**data_entry) + )) + + entry_prompt = self.tokenizer.apply_chat_template( + entry_chat, + **self.chat_template_params + ) + + return entry_prompt + + def process(self): + """Main processing function: reads entries, builds prompts, runs generation, writes results.""" + from vllm import LLM + + entries = [] + entry_prompts = [] + + # Read entries and generate prompts + with open(self.input_manifest_file, 'r', encoding='utf8') as fin: + for line in tqdm(fin, desc = "Building prompts: "): + data_entry = json.loads(line) + entries.append(data_entry) + entry_prompt = self.get_entry_prompt(data_entry) + entry_prompts.append(entry_prompt) + + # Run vLLM inference + llm = LLM(**self.model_params) + outputs = llm.generate(entry_prompts, self.sampling_params) + + # Write results to output manifest + with open(self.output_manifest_file, 'w', encoding='utf8') as fout: + for data_entry, output in tqdm(zip(entries, outputs), desc="Writing 
outputs: "): + data_entry[self.generation_field] = output.outputs[0].text + line = json.dumps(data_entry) + fout.writelines(f'{line}\n') \ No newline at end of file diff --git a/sdp/processors/nemo/pc_inference.py b/sdp/processors/inference/nlp/nemo/pc_inference.py similarity index 100% rename from sdp/processors/nemo/pc_inference.py rename to sdp/processors/inference/nlp/nemo/pc_inference.py diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py index 35b4d5b0..09a55011 100644 --- a/sdp/processors/modify_manifest/data_to_data.py +++ b/sdp/processors/modify_manifest/data_to_data.py @@ -22,6 +22,9 @@ from docx import Document from tqdm import tqdm import json +import librosa +import numpy as np +from pathlib import Path from sdp.logging import logger from sdp.processors.base_processor import ( @@ -1231,4 +1234,86 @@ def process_dataset_entry(self, data_entry) -> List[DataEntry]: return [DataEntry(data=data_entry)] def finalize(self, metrics): - super().finalize(metrics) \ No newline at end of file + super().finalize(metrics) + + +class EstimateBandwidth(BaseParallelProcessor): + """ + Adds estimated bandwidth to each utterance in the input manifest file. + + Args: + audio_dir (str): Root directory where audio files are stored. + input_audio_key (str): Manifest key with relative audio paths. + output_bandwidth_key (str): Manifest key to store estimated bandwidth in. + max_seconds (float): The maximum length of audio to use for bandwidth estimation. + By default, uses the first 30 seconds. + sample_rate (int): Sample rate to resample audio to before doing bandwidth estimation. + Defaults to 44100, upsampling the input audio as needed. + n_fft (int): Number of FFT bins to use for bandwidth estimation. Defaults to 512. + hop_length (int): Audio frame hop length to use for bandwidth estimation. + Defaults to 441, corresponding to 0.01 seconds for 44100 sample rate. + top_db (float): top_db treshhold to use for bandwidth estimation. + frequency_threshold (float): Bandwidth estimation finds the highest frequency with mean power spectrum that is + within 'frequency_threshold' dB of its peak power. Defaults to -50 dB. + + Returns: + This processor estimates the bandwidth of the audio file in the`input_audio_key` field and saves the estimate + in the output_bandwidth_key` field. + + Example: + .. 
code-block:: yaml + + - _target_: sdp.processors.EstimateBandwidth + input_manifest_file: ${workspace_dir}/manifest.json + output_manifest_file: ${workspace_dir}/manifest_bandwidth.json + audio_dir: ${workspace_dir}/audio_22khz + max_workers: 8 + """ + + def __init__( + self, + audio_dir: str, + input_audio_key: str = "audio_filepath", + output_bandwidth_key: str = "bandwidth", + max_seconds: float = 30.0, + sample_rate: int = 44100, + n_fft: int = 512, + hop_length: int = 441, + top_db: float = 100.0, + frequency_threshold: float = -50.0, + **kwargs, + ): + super().__init__(**kwargs) + self.audio_directory = Path(audio_dir) + self.input_audio_key = input_audio_key + self.output_bandwidth_key = output_bandwidth_key + self.max_seconds = max_seconds + self.sample_rate = sample_rate + self.n_fft = n_fft + self.hop_length = hop_length + self.top_db = top_db + self.frequency_threshold = frequency_threshold + + def _estimate_bandwidth(self, audio, sample_rate): + spec = librosa.stft(y=audio, n_fft=self.n_fft, hop_length=self.hop_length, window="blackmanharris") + power_spec = np.abs(spec) ** 2 + power_spec = np.mean(power_spec, axis=1) + power_spec = librosa.power_to_db(power_spec, ref=self.n_fft, top_db=self.top_db) + + bandwidth = 0 + peak = np.max(power_spec) + freq_width = sample_rate / self.n_fft + for idx in range(len(power_spec) - 1, -1, -1): + if power_spec[idx] - peak > self.frequency_threshold: + bandwidth = idx * freq_width + break + + return bandwidth + + def process_dataset_entry(self, data_entry): + audio_filename = data_entry[self.input_audio_key] + audio_file = self.audio_directory / audio_filename + audio, sr = librosa.load(path=audio_file, sr=self.sample_rate, duration=self.max_seconds) + bandwidth = self._estimate_bandwidth(audio=audio, sample_rate=sr) + data_entry[self.output_bandwidth_key] = int(bandwidth) + return [DataEntry(data=data_entry)] \ No newline at end of file diff --git a/sdp/processors/nemo/__init__.py b/sdp/processors/nemo/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/sdp/processors/nemo/estimate_bandwidth.py b/sdp/processors/nemo/estimate_bandwidth.py deleted file mode 100644 index 38b261e7..00000000 --- a/sdp/processors/nemo/estimate_bandwidth.py +++ /dev/null @@ -1,87 +0,0 @@ -import librosa -import numpy as np -from pathlib import Path - -from sdp.processors.base_processor import BaseParallelProcessor, DataEntry - - -class EstimateBandwidth(BaseParallelProcessor): - """ - Adds estimated bandwidth to each utterance in the input manifest file. - - Args: - audio_dir (str): Root directory where audio files are stored. - input_audio_key (str): Manifest key with relative audio paths. - output_bandwidth_key (str): Manifest key to store estimated bandwidth in. - max_seconds (float): The maximum length of audio to use for bandwidth estimation. - By default, uses the first 30 seconds. - sample_rate (int): Sample rate to resample audio to before doing bandwidth estimation. - Defaults to 44100, upsampling the input audio as needed. - n_fft (int): Number of FFT bins to use for bandwidth estimation. Defaults to 512. - hop_length (int): Audio frame hop length to use for bandwidth estimation. - Defaults to 441, corresponding to 0.01 seconds for 44100 sample rate. - top_db (float): top_db treshhold to use for bandwidth estimation. - frequency_threshold (float): Bandwidth estimation finds the highest frequency with mean power spectrum that is - within 'frequency_threshold' dB of its peak power. Defaults to -50 dB. 
- - Returns: - This processor estimates the bandwidth of the audio file in the`input_audio_key` field and saves the estimate - in the output_bandwidth_key` field. - - Example: - .. code-block:: yaml - - - _target_: sdp.processors.EstimateBandwidth - input_manifest_file: ${workspace_dir}/manifest.json - output_manifest_file: ${workspace_dir}/manifest_bandwidth.json - audio_dir: ${workspace_dir}/audio_22khz - max_workers: 8 - """ - - def __init__( - self, - audio_dir: str, - input_audio_key: str = "audio_filepath", - output_bandwidth_key: str = "bandwidth", - max_seconds: float = 30.0, - sample_rate: int = 44100, - n_fft: int = 512, - hop_length: int = 441, - top_db: float = 100.0, - frequency_threshold: float = -50.0, - **kwargs, - ): - super().__init__(**kwargs) - self.audio_directory = Path(audio_dir) - self.input_audio_key = input_audio_key - self.output_bandwidth_key = output_bandwidth_key - self.max_seconds = max_seconds - self.sample_rate = sample_rate - self.n_fft = n_fft - self.hop_length = hop_length - self.top_db = top_db - self.frequency_threshold = frequency_threshold - - def _estimate_bandwidth(self, audio, sample_rate): - spec = librosa.stft(y=audio, n_fft=self.n_fft, hop_length=self.hop_length, window="blackmanharris") - power_spec = np.abs(spec) ** 2 - power_spec = np.mean(power_spec, axis=1) - power_spec = librosa.power_to_db(power_spec, ref=self.n_fft, top_db=self.top_db) - - bandwidth = 0 - peak = np.max(power_spec) - freq_width = sample_rate / self.n_fft - for idx in range(len(power_spec) - 1, -1, -1): - if power_spec[idx] - peak > self.frequency_threshold: - bandwidth = idx * freq_width - break - - return bandwidth - - def process_dataset_entry(self, data_entry): - audio_filename = data_entry[self.input_audio_key] - audio_file = self.audio_directory / audio_filename - audio, sr = librosa.load(path=audio_file, sr=self.sample_rate, duration=self.max_seconds) - bandwidth = self._estimate_bandwidth(audio=audio, sample_rate=sr) - data_entry[self.output_bandwidth_key] = int(bandwidth) - return [DataEntry(data=data_entry)] diff --git a/tests/test_data_to_data.py b/tests/test_data_to_data.py index a18e40e8..9dd3278a 100644 --- a/tests/test_data_to_data.py +++ b/tests/test_data_to_data.py @@ -23,6 +23,9 @@ LambdaExpression, ) +from sdp.processors.inference.llm.utils.qwen_cleaning import CleanQwenGeneration +from sdp.processors.inference.asr.utils.whisper_hallucinations import DetectWhisperHallucinationFeatures + test_params_list = [] test_params_list.extend( @@ -197,6 +200,87 @@ ] ) +test_params_list.extend( + [ + # Case: generation is fine, no replacement + ( + CleanQwenGeneration, + {"cer_threshold": 10, "upper_case_threshold": 0.6}, + {"text": "hello world", "generation": "hello world"}, + [{"text": "hello world", "generation": "hello world"}], + ), + + # Case: generation is completely uppercase → replaced + ( + CleanQwenGeneration, + {"cer_threshold": 10, "upper_case_threshold": 0.5}, + {"text": "hello world", "generation": "HELLO WORLD"}, + [{"text": "hello world", "generation": "hello world"}], + ), + + # Case: generation contains <|endoftext|> and prompt remnants → cleaned + ( + CleanQwenGeneration, + {}, + {"text": "hello", "generation": "Input transcript: hello\nOutput transcript: hello<|endoftext|>"}, + [{"text": "hello", "generation": "hello"}], + ), + + # Case: generation is too different → high CER → replaced + ( + CleanQwenGeneration, + {"cer_threshold": 0.2}, + {"text": "hello world", "generation": "xyz abc"}, + [{"text": "hello world", "generation": "hello 
world"}], + ), + + # Case: generation is empty → replaced + ( + CleanQwenGeneration, + {}, + {"text": "reference", "generation": ""}, + [{"text": "reference", "generation": "reference"}], + ), + + # Case: text is empty → fallback to replacement + ( + CleanQwenGeneration, + {}, + {"text": "", "generation": "some output"}, + [{"text": "", "generation": ""}], + ), + ] +) + +@pytest.mark.parametrize( + "text,expected_flags", + [ + # repeated n-grams + ("yes yes yes yes yes", {"hall_repeated_ngrams": True, "hall_long_word": False, "hall_frequent_single_word": False}), + # long word + ("short reallyreallyreallyreallyreallyreallyreallylong", {"hall_repeated_ngrams": False, "hall_long_word": True, "hall_frequent_single_word": False}), + # known hallucinated phrase + ("lorem ipsum dolor sit amet", {"hall_repeated_ngrams": False, "hall_long_word": False, "hall_frequent_single_word": True}), + # no hallucination + ("this is a normal sentence", {"hall_repeated_ngrams": False, "hall_long_word": False, "hall_frequent_single_word": False}), + ] +) +def test_detect_whisper_hallucinations(tmp_path, text, expected_flags): + # prepare common phrases file + common_phrases_path = tmp_path / "common_phrases.txt" + common_phrases_path.write_text("lorem ipsum dolor sit amet\n") + + processor = DetectWhisperHallucinationFeatures( + common_hall_file=str(common_phrases_path), + output_manifest_file=None # assuming it's optional or handled elsewhere + ) + + input_entry = {"text": text} + result_entry = processor.process_dataset_entry(input_entry)[0].data + + # check each expected flag + for key, value in expected_flags.items(): + assert result_entry[key] == value, f"Failed for text='{text}' on key='{key}'" @pytest.mark.parametrize("test_class,class_kwargs,test_input,expected_output", test_params_list, ids=str) def test_data_to_data(test_class, class_kwargs, test_input, expected_output):