diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index 2ab441c5..58cae45b 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -54,6 +54,10 @@ SubIfASRSubstitution, SubMakeLowercase, SubRegex, + GetWER, + GetCER, + GetEdgeCER, + GetLenDiffRatio, ) from sdp.processors.modify_manifest.data_to_dropbool import ( DropASRError, diff --git a/sdp/processors/datasets/commoncrawl/__init__.py b/sdp/processors/datasets/commoncrawl/__init__.py index 815a5549..e20ef3b2 100644 --- a/sdp/processors/datasets/commoncrawl/__init__.py +++ b/sdp/processors/datasets/commoncrawl/__init__.py @@ -36,4 +36,4 @@ TrainDevTestSplitCC, TxtToVtt, UseSonar, -) +) \ No newline at end of file diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py index dd09f8dc..762dc37f 100644 --- a/sdp/processors/modify_manifest/data_to_data.py +++ b/sdp/processors/modify_manifest/data_to_data.py @@ -16,6 +16,12 @@ import os import re from typing import Dict, List +import jiwer +import editdistance +import itertools +from tqdm.contrib.concurrent import process_map +from tqdm import tqdm +import json import soundfile as sf @@ -525,3 +531,353 @@ def finalize(self, metrics): for word, count in total_counter_sorted.items(): logger.info(f"{word} {count}") super().finalize(metrics) + +class GetWER(BaseParallelProcessor): + """ + Processor that computes the Word Error Rate (WER) between reference text and hypothesis text. + The WER is computed as the Levenshtein distance between the two texts normalized by the + number of words in the reference text. + + Args: + reference_text_field (str): Key to get the reference text from the data. + hypothesis_text_field (str): Key to get the hypothesis text from the data. + output_metric_field (str): Key to put the computed WER value. + + Returns: + All the same fields as in the input manifest plus the output_metric_field containing + the computed WER value. + """ + + def __init__( + self, + reference_text_field: str = "text", + hypothesis_text_field: str = "pred_text", + output_metric_field: str = "wer", + **kwargs, + ): + super().__init__(**kwargs) + self.reference_text_field = reference_text_field + self.hypothesis_text_field = hypothesis_text_field + self.output_metric_field = output_metric_field + self.word_dist = 0 + self.num_words = 0 + + def process(self): + self.prepare() + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + metrics = [] + + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + # this will unroll all inner lists + data = itertools.chain( + *process_map( + self.process_dataset_entry, + manifest_chunk, + max_workers=self.max_workers, + chunksize=self.chunksize, + ) + ) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + self.word_dist += data_entry.metrics.get("word_dist", 0) + self.num_words += data_entry.metrics.get("num_words", 0) + fout.write("\n") + + self.finalize(metrics) + + def process_dataset_entry(self, data_entry): + reference_text = data_entry[self.reference_text_field] + hypothesis_text = data_entry[self.hypothesis_text_field] + + ref_words_amount = len(reference_text.split()) + hyp_words_amount = len(hypothesis_text.split()) + + if ref_words_amount == 0 or hyp_words_amount == 0: + if ref_words_amount == hyp_words_amount: + word_dist = 0 + else: + word_dist = ref_words_amount + else: + word_dist_measures = jiwer.compute_measures(reference_text, hypothesis_text) + word_dist = word_dist_measures['substitutions'] + word_dist_measures['insertions'] + word_dist_measures['deletions'] + + wer_value = word_dist / ref_words_amount + data_entry[self.output_metric_field] = round(wer_value * 100, 2) + + return [DataEntry(data=data_entry, metrics = {'word_dist' : word_dist, 'num_words' : ref_words_amount})] + + def finalize(self, metrics: List): + logger.info("Total number of entries after processing: %d", self.number_of_entries) + if self.total_duration != 0: + logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) + + logger.info("Overall Word Error Rate (WER): %.2f%%", self.word_dist / self.num_words * 100) + + +class GetCER(BaseParallelProcessor): + """ + Processor that computes the Character Error Rate (CER) between reference text and hypothesis text. + The CER is computed as the Levenshtein distance between the two texts normalized by the + number of characters in the reference text. + + Args: + reference_text_field (str): Key to get the reference text from the data. + hypothesis_text_field (str): Key to get the hypothesis text from the data. + output_metric_field (str): Key to put the computed CER value. + + Returns: + All the same fields as in the input manifest plus the output_metric_field containing + the computed CER value. + """ + + def __init__( + self, + reference_text_field: str = "text", + hypothesis_text_field: str = "pred_text", + output_metric_field: str = "cer", + **kwargs, + ): + super().__init__(**kwargs) + self.reference_text_field = reference_text_field + self.hypothesis_text_field = hypothesis_text_field + self.output_metric_field = output_metric_field + self.char_dist = 0 + self.num_chars = 0 + + def process(self): + self.prepare() + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + metrics = [] + + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + # this will unroll all inner lists + data = itertools.chain( + *process_map( + self.process_dataset_entry, + manifest_chunk, + max_workers=self.max_workers, + chunksize=self.chunksize, + ) + ) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + self.char_dist += data_entry.metrics.get("char_dist", 0) + self.num_chars += data_entry.metrics.get("num_chars", 0) + fout.write("\n") + + self.finalize(metrics) + + def process_dataset_entry(self, data_entry): + reference_text = data_entry[self.reference_text_field] + hypothesis_text = data_entry[self.hypothesis_text_field] + + ref_chars_amount = len(reference_text) + hyp_chars_amount = len(hypothesis_text) + + if ref_chars_amount == 0 or hyp_chars_amount == 0: + if ref_chars_amount == hyp_chars_amount: + char_dist = 0 + else: + char_dist = ref_chars_amount + else: + char_dist = editdistance.eval(reference_text, hypothesis_text) + + cer_value = char_dist / ref_chars_amount + data_entry[self.output_metric_field] = round(cer_value * 100, 2) + + return [DataEntry(data=data_entry, metrics = {'char_dist' : char_dist, 'num_chars' : ref_chars_amount})] + + def finalize(self, metrics: List): + logger.info("Total number of entries after processing: %d", self.number_of_entries) + if self.total_duration != 0: + logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) + + logger.info("Overall Character Error Rate (CER): %.2f%%", self.char_dist / self.num_chars * 100) + + +class GetEdgeCER(BaseParallelProcessor): + """ + Processor that computes the Character Error Rate (CER) for a specified edge of reference + and hypothesis texts. + + Args: + reference_text_field (str): Key to get the reference text from the data. + hypothesis_text_field (str): Key to get the hypothesis text from the data. + edge (str): Specifies whether to compute CER for the 'start' or 'end' edge of the texts. + edge_len (int): Length of the edge window. + output_metric_field (str): Key to put the computed edge CER value. + + Returns: + All the same fields as in the input manifest plus the output_metric_field containing + the computed edge CER value. + """ + + def __init__( + self, + reference_text_field: str = "text", + hypothesis_text_field: str = "pred_text", + edge: str = "start", + edge_len: int = 10, + output_metric_field: str = "start_cer", + **kwargs, + ): + super().__init__(**kwargs) + self.reference_text_field = reference_text_field + self.hypothesis_text_field = hypothesis_text_field + self.edge = edge + self.edge_len = edge_len + self.output_metric_field = output_metric_field + self.edge_cer_sum = 0 + + def process(self): + self.prepare() + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + metrics = [] + + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + # this will unroll all inner lists + data = itertools.chain( + *process_map( + self.process_dataset_entry, + manifest_chunk, + max_workers=self.max_workers, + chunksize=self.chunksize, + ) + ) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + self.edge_cer_sum += data_entry.data.get(self.output_metric_field, 0) + fout.write("\n") + + self.finalize(metrics) + + def process_dataset_entry(self, data_entry): + if self.edge == "start": + start_idx = 0 + end_idx = self.edge_len + elif self.edge == "end": + start_idx = -self.edge_len + end_idx = -1 + else: + raise ValueError(f"Current `Edge` parameter value ({self.edge}) is incorrect. Please select `start` or `end` edge.") + + reference_text_edge = data_entry[self.reference_text_field][start_idx : end_idx] + hypothesis_text_edge = data_entry[self.hypothesis_text_field][start_idx : end_idx] + + ref_chars_amount = len(reference_text_edge) + hyp_chars_amount = len(hypothesis_text_edge) + + if ref_chars_amount == 0 or hyp_chars_amount == 0: + if ref_chars_amount == hyp_chars_amount: + char_dist = 0 + else: + char_dist = ref_chars_amount + else: + char_dist = editdistance.eval(reference_text_edge, hypothesis_text_edge) + + edge_cer_value = char_dist / ref_chars_amount + data_entry[self.output_metric_field] = round(edge_cer_value * 100, 2) + + return [DataEntry(data=data_entry)] + + def finalize(self, metrics: List): + logger.info("Total number of entries after processing: %d", self.number_of_entries) + if self.total_duration != 0: + logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) + + logger.info(f"Mean {self.edge} Character Error Rate (CER): {round(self.edge_cer_sum / self.number_of_entries, 2)}%") + + +class GetLenDiffRatio(BaseParallelProcessor): + """ + Processor that computes the length difference ratio between reference and hypothesis texts. + + Args: + reference_text_field (str): Key to get the reference text from the data. + hypothesis_text_field (str): Key to get the hypothesis text from the data. + output_metric_field (str): Key to put the computed length difference ratio. + + Returns: + All the same fields as in the input manifest plus the output_metric_field containing + the computed length difference ratio. + """ + + def __init__( + self, + reference_text_field: str = "text", + hypothesis_text_field: str = "pred_text", + output_metric_field: str = "len_diff_ratio", + **kwargs, + ): + super().__init__(**kwargs) + self.reference_text_field = reference_text_field + self.hypothesis_text_field = hypothesis_text_field + self.output_metric_field = output_metric_field + self.words_len_diff_ratio_sum = 0 + + def process(self): + self.prepare() + os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True) + metrics = [] + + with open(self.output_manifest_file, "wt", encoding="utf8") as fout: + for manifest_chunk in self._chunk_manifest(): + # this will unroll all inner lists + data = itertools.chain( + *process_map( + self.process_dataset_entry, + manifest_chunk, + max_workers=self.max_workers, + chunksize=self.chunksize, + ) + ) + for data_entry in tqdm(data): + metrics.append(data_entry.metrics) + if data_entry.data is None: + continue + json.dump(data_entry.data, fout, ensure_ascii=False) + self.number_of_entries += 1 + self.total_duration += data_entry.data.get("duration", 0) + self.words_len_diff_ratio_sum += data_entry.data.get(self.output_metric_field, 0) + fout.write("\n") + + self.finalize(metrics) + + def process_dataset_entry(self, data_entry): + reference_text = data_entry[self.reference_text_field] + hypothesis_text = data_entry[self.hypothesis_text_field] + + ref_words_amount = len(reference_text.split()) + hyp_words_amount = len(hypothesis_text.split()) + + eps = 1e-9 + len_diff_ratio = 1.0 * abs(ref_words_amount - hyp_words_amount) / max(ref_words_amount, eps) + + data_entry[self.output_metric_field] = round(len_diff_ratio * 100, 2) + + return [DataEntry(data=data_entry)] + + def finalize(self, metrics: List): + logger.info("Total number of entries after processing: %d", self.number_of_entries) + if self.total_duration != 0: + logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600) + + logger.info(f"Mean Text Length Difference Ratio (in words): {round(self.words_len_diff_ratio_sum / self.number_of_entries, 2)}%") \ No newline at end of file