4 changes: 4 additions & 0 deletions sdp/processors/__init__.py
@@ -54,6 +54,10 @@
SubIfASRSubstitution,
SubMakeLowercase,
SubRegex,
GetWER,
GetCER,
GetEdgeCER,
GetLenDiffRatio,
)
from sdp.processors.modify_manifest.data_to_dropbool import (
DropASRError,
2 changes: 1 addition & 1 deletion sdp/processors/datasets/commoncrawl/__init__.py
@@ -36,4 +36,4 @@
TrainDevTestSplitCC,
TxtToVtt,
UseSonar,
)
)
356 changes: 356 additions & 0 deletions sdp/processors/modify_manifest/data_to_data.py
@@ -16,6 +16,12 @@
import os
import re
from typing import Dict, List
import jiwer
import editdistance
import itertools
from tqdm.contrib.concurrent import process_map
from tqdm import tqdm
import json

import soundfile as sf

@@ -525,3 +531,353 @@ def finalize(self, metrics):
for word, count in total_counter_sorted.items():
logger.info(f"{word} {count}")
super().finalize(metrics)

class GetWER(BaseParallelProcessor):
"""
Processor that computes the Word Error Rate (WER) between reference text and hypothesis text.
    The WER is computed as the word-level Levenshtein distance between the two texts, normalized
    by the number of words in the reference text, and is reported in percent.

Args:
reference_text_field (str): Key to get the reference text from the data.
hypothesis_text_field (str): Key to get the hypothesis text from the data.
output_metric_field (str): Key to put the computed WER value.

Returns:
All the same fields as in the input manifest plus the output_metric_field containing
the computed WER value.
"""

def __init__(
self,
reference_text_field: str = "text",
hypothesis_text_field: str = "pred_text",
output_metric_field: str = "wer",
**kwargs,
):
super().__init__(**kwargs)
self.reference_text_field = reference_text_field
self.hypothesis_text_field = hypothesis_text_field
self.output_metric_field = output_metric_field
self.word_dist = 0
self.num_words = 0

def process(self):
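        # Chunk the input manifest, process entries in parallel, write the surviving entries to the
        # output manifest, and accumulate word_dist / num_words so that finalize() can report the overall WER.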
self.prepare()
os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True)
metrics = []

with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
for manifest_chunk in self._chunk_manifest():
# this will unroll all inner lists
data = itertools.chain(
*process_map(
self.process_dataset_entry,
manifest_chunk,
max_workers=self.max_workers,
chunksize=self.chunksize,
)
)
for data_entry in tqdm(data):
metrics.append(data_entry.metrics)
if data_entry.data is None:
continue
json.dump(data_entry.data, fout, ensure_ascii=False)
self.number_of_entries += 1
self.total_duration += data_entry.data.get("duration", 0)
self.word_dist += data_entry.metrics.get("word_dist", 0)
self.num_words += data_entry.metrics.get("num_words", 0)
fout.write("\n")

self.finalize(metrics)

def process_dataset_entry(self, data_entry):
reference_text = data_entry[self.reference_text_field]
hypothesis_text = data_entry[self.hypothesis_text_field]

ref_words_amount = len(reference_text.split())
hyp_words_amount = len(hypothesis_text.split())

        if ref_words_amount == 0 or hyp_words_amount == 0:
            if ref_words_amount == hyp_words_amount:
                word_dist = 0
            else:
                # one of the texts is empty: every word of the non-empty one counts as an error
                word_dist = max(ref_words_amount, hyp_words_amount)
        else:
            word_dist_measures = jiwer.compute_measures(reference_text, hypothesis_text)
            word_dist = word_dist_measures['substitutions'] + word_dist_measures['insertions'] + word_dist_measures['deletions']

        # guard against empty references to avoid division by zero
        wer_value = word_dist / max(ref_words_amount, 1)
data_entry[self.output_metric_field] = round(wer_value * 100, 2)

        return [DataEntry(data=data_entry, metrics={'word_dist': word_dist, 'num_words': ref_words_amount})]

def finalize(self, metrics: List):
logger.info("Total number of entries after processing: %d", self.number_of_entries)
if self.total_duration != 0:
logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600)

logger.info("Overall Word Error Rate (WER): %.2f%%", self.word_dist / self.num_words * 100)


class GetCER(BaseParallelProcessor):
"""
Processor that computes the Character Error Rate (CER) between reference text and hypothesis text.
    The CER is computed as the character-level Levenshtein distance between the two texts,
    normalized by the number of characters in the reference text, and is reported in percent.

Args:
reference_text_field (str): Key to get the reference text from the data.
hypothesis_text_field (str): Key to get the hypothesis text from the data.
output_metric_field (str): Key to put the computed CER value.

Returns:
All the same fields as in the input manifest plus the output_metric_field containing
the computed CER value.
"""

def __init__(
self,
reference_text_field: str = "text",
hypothesis_text_field: str = "pred_text",
output_metric_field: str = "cer",
**kwargs,
):
super().__init__(**kwargs)
self.reference_text_field = reference_text_field
self.hypothesis_text_field = hypothesis_text_field
self.output_metric_field = output_metric_field
self.char_dist = 0
self.num_chars = 0

def process(self):
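        # Same parallel loop as in GetWER.process, but accumulating char_dist / num_chars
        # so that finalize() can report the overall CER.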
self.prepare()
os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True)
metrics = []

with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
for manifest_chunk in self._chunk_manifest():
# this will unroll all inner lists
data = itertools.chain(
*process_map(
self.process_dataset_entry,
manifest_chunk,
max_workers=self.max_workers,
chunksize=self.chunksize,
)
)
for data_entry in tqdm(data):
metrics.append(data_entry.metrics)
if data_entry.data is None:
continue
json.dump(data_entry.data, fout, ensure_ascii=False)
self.number_of_entries += 1
self.total_duration += data_entry.data.get("duration", 0)
self.char_dist += data_entry.metrics.get("char_dist", 0)
self.num_chars += data_entry.metrics.get("num_chars", 0)
fout.write("\n")

self.finalize(metrics)

def process_dataset_entry(self, data_entry):
reference_text = data_entry[self.reference_text_field]
hypothesis_text = data_entry[self.hypothesis_text_field]

ref_chars_amount = len(reference_text)
hyp_chars_amount = len(hypothesis_text)

        if ref_chars_amount == 0 or hyp_chars_amount == 0:
            if ref_chars_amount == hyp_chars_amount:
                char_dist = 0
            else:
                # one of the texts is empty: every character of the non-empty one counts as an error
                char_dist = max(ref_chars_amount, hyp_chars_amount)
        else:
            char_dist = editdistance.eval(reference_text, hypothesis_text)

        # guard against empty references to avoid division by zero
        cer_value = char_dist / max(ref_chars_amount, 1)
data_entry[self.output_metric_field] = round(cer_value * 100, 2)

        return [DataEntry(data=data_entry, metrics={'char_dist': char_dist, 'num_chars': ref_chars_amount})]

def finalize(self, metrics: List):
logger.info("Total number of entries after processing: %d", self.number_of_entries)
if self.total_duration != 0:
logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600)

logger.info("Overall Character Error Rate (CER): %.2f%%", self.char_dist / self.num_chars * 100)


class GetEdgeCER(BaseParallelProcessor):
"""
    Processor that computes the Character Error Rate (CER) over a fixed-length window of characters
    taken from the start or the end of the reference and hypothesis texts (controlled by the `edge`
    parameter), reported in percent.

Args:
reference_text_field (str): Key to get the reference text from the data.
hypothesis_text_field (str): Key to get the hypothesis text from the data.
edge (str): Specifies whether to compute CER for the 'start' or 'end' edge of the texts.
edge_len (int): Length of the edge window.
output_metric_field (str): Key to put the computed edge CER value.

Returns:
All the same fields as in the input manifest plus the output_metric_field containing
the computed edge CER value.
"""

def __init__(
self,
reference_text_field: str = "text",
hypothesis_text_field: str = "pred_text",
edge: str = "start",
edge_len: int = 10,
output_metric_field: str = "start_cer",
**kwargs,
):
super().__init__(**kwargs)
self.reference_text_field = reference_text_field
self.hypothesis_text_field = hypothesis_text_field
self.edge = edge
self.edge_len = edge_len
self.output_metric_field = output_metric_field
self.edge_cer_sum = 0

def process(self):
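        # Same parallel loop, accumulating the per-entry edge CER values so that finalize()
        # can report their mean.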
self.prepare()
os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True)
metrics = []

with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
for manifest_chunk in self._chunk_manifest():
# this will unroll all inner lists
data = itertools.chain(
*process_map(
self.process_dataset_entry,
manifest_chunk,
max_workers=self.max_workers,
chunksize=self.chunksize,
)
)
for data_entry in tqdm(data):
metrics.append(data_entry.metrics)
if data_entry.data is None:
continue
json.dump(data_entry.data, fout, ensure_ascii=False)
self.number_of_entries += 1
self.total_duration += data_entry.data.get("duration", 0)
self.edge_cer_sum += data_entry.data.get(self.output_metric_field, 0)
fout.write("\n")

self.finalize(metrics)

def process_dataset_entry(self, data_entry):
        if self.edge == "start":
            start_idx = 0
            end_idx = self.edge_len
        elif self.edge == "end":
            start_idx = -self.edge_len
            end_idx = None  # slice to the end of the string; -1 would drop the last character
        else:
            raise ValueError(f"Invalid `edge` parameter value ({self.edge}). Please select either `start` or `end`.")

reference_text_edge = data_entry[self.reference_text_field][start_idx : end_idx]
hypothesis_text_edge = data_entry[self.hypothesis_text_field][start_idx : end_idx]

ref_chars_amount = len(reference_text_edge)
hyp_chars_amount = len(hypothesis_text_edge)

        if ref_chars_amount == 0 or hyp_chars_amount == 0:
            if ref_chars_amount == hyp_chars_amount:
                char_dist = 0
            else:
                # one of the edge windows is empty: every character of the non-empty one counts as an error
                char_dist = max(ref_chars_amount, hyp_chars_amount)
        else:
            char_dist = editdistance.eval(reference_text_edge, hypothesis_text_edge)

        # guard against empty reference windows to avoid division by zero
        edge_cer_value = char_dist / max(ref_chars_amount, 1)
data_entry[self.output_metric_field] = round(edge_cer_value * 100, 2)

return [DataEntry(data=data_entry)]

def finalize(self, metrics: List):
logger.info("Total number of entries after processing: %d", self.number_of_entries)
if self.total_duration != 0:
logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600)

logger.info(f"Mean {self.edge} Character Error Rate (CER): {round(self.edge_cer_sum / self.number_of_entries, 2)}%")


class GetLenDiffRatio(BaseParallelProcessor):
"""
    Processor that computes the word-count length difference ratio between reference and hypothesis
    texts: the absolute difference in the number of words, normalized by the number of words in the
    reference text, and reported in percent.

Args:
reference_text_field (str): Key to get the reference text from the data.
hypothesis_text_field (str): Key to get the hypothesis text from the data.
output_metric_field (str): Key to put the computed length difference ratio.

Returns:
All the same fields as in the input manifest plus the output_metric_field containing
the computed length difference ratio.
"""

def __init__(
self,
reference_text_field: str = "text",
hypothesis_text_field: str = "pred_text",
output_metric_field: str = "len_diff_ratio",
**kwargs,
):
super().__init__(**kwargs)
self.reference_text_field = reference_text_field
self.hypothesis_text_field = hypothesis_text_field
self.output_metric_field = output_metric_field
self.words_len_diff_ratio_sum = 0

def process(self):
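        # Same parallel loop, accumulating the per-entry length difference ratios so that finalize()
        # can report their mean.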
self.prepare()
os.makedirs(os.path.dirname(self.output_manifest_file), exist_ok=True)
metrics = []

with open(self.output_manifest_file, "wt", encoding="utf8") as fout:
for manifest_chunk in self._chunk_manifest():
# this will unroll all inner lists
data = itertools.chain(
*process_map(
self.process_dataset_entry,
manifest_chunk,
max_workers=self.max_workers,
chunksize=self.chunksize,
)
)
for data_entry in tqdm(data):
metrics.append(data_entry.metrics)
if data_entry.data is None:
continue
json.dump(data_entry.data, fout, ensure_ascii=False)
self.number_of_entries += 1
self.total_duration += data_entry.data.get("duration", 0)
self.words_len_diff_ratio_sum += data_entry.data.get(self.output_metric_field, 0)
fout.write("\n")

self.finalize(metrics)

def process_dataset_entry(self, data_entry):
reference_text = data_entry[self.reference_text_field]
hypothesis_text = data_entry[self.hypothesis_text_field]

ref_words_amount = len(reference_text.split())
hyp_words_amount = len(hypothesis_text.split())

eps = 1e-9
len_diff_ratio = 1.0 * abs(ref_words_amount - hyp_words_amount) / max(ref_words_amount, eps)

data_entry[self.output_metric_field] = round(len_diff_ratio * 100, 2)

return [DataEntry(data=data_entry)]

def finalize(self, metrics: List):
logger.info("Total number of entries after processing: %d", self.number_of_entries)
if self.total_duration != 0:
logger.info("Total audio duration (hours) after processing: %.2f", self.total_duration / 3600)

logger.info(f"Mean Text Length Difference Ratio (in words): {round(self.words_len_diff_ratio_sum / self.number_of_entries, 2)}%")