From acc18dd67aa3b9bbba40974881a55a48768aee6d Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 9 May 2022 12:27:43 -0700 Subject: [PATCH 01/10] Added changes to speaker_utils Signed-off-by: Taejin Park --- .../asr/parts/utils/speaker_utils.py | 423 ++++++++++++++---- 1 file changed, 325 insertions(+), 98 deletions(-) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 566b26ad054f..96a4ffd8be11 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import json import math import os from copy import deepcopy +from typing import List, Optional, Dict +from functools import reduce import numpy as np import omegaconf @@ -34,19 +37,39 @@ def get_uniqname_from_filepath(filepath): - "return base name from provided filepath" + """ + Return base name from provided filepath + """ if type(filepath) is str: basename = os.path.basename(filepath).rsplit('.', 1)[0] return basename else: raise TypeError("input must be filepath string") +def get_uniq_id_with_dur(meta, deci=2): + """ + Return basename with offset and end time labels + """ + bare_uniq_id = meta['audio_filepath'].split('/')[-1].split('.wav')[0] + if meta['offset'] == None and meta['duration'] == None: + return bare_uniq_id + if meta['offset']: + offset = str(int(round(meta['offset'], deci) * pow(10, deci))) + else: + offset = 0 + if meta['duration']: + endtime = str(int(round(meta['offset'] + meta['duration'], deci) * pow(10, deci))) + else: + endtime = 'NULL' + uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" + return uniq_id + def audio_rttm_map(manifest): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time 
stamps - input: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists + Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists returns: AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files @@ -63,6 +86,7 @@ def audio_rttm_map(manifest): meta = { 'audio_filepath': dic['audio_filepath'], 'rttm_filepath': dic.get('rttm_filepath', None), + 'offset': dic.get('offset', None), 'duration': dic.get('duration', None), 'text': dic.get('text', None), 'num_speakers': dic.get('num_speakers', None), @@ -98,6 +122,9 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ parameters.window_length_in_sec=[1.5,1.0,0.5] parameters.shift_length_in_sec=[0.75,0.5,0.25] parameters.multiscale_weights=[0.33,0.33,0.33] + + In addition, you can also specify session-by-session multiscale weight. In this case, each dictionary key + points to different weights. 
""" checkFloatConfig = [type(var) == float for var in (window_lengths_in_sec, shift_lengths_in_sec)] checkListConfig = [ @@ -245,7 +272,7 @@ def labels_to_pyannote_object(labels, uniq_name=''): def uem_timeline_from_file(uem_file, uniq_name=''): """ - outputs pyannote timeline segments for uem file + Generate pyannote timeline segments for uem file file format UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME @@ -263,7 +290,7 @@ def uem_timeline_from_file(uem_file, uniq_name=''): def labels_to_rttmfile(labels, uniq_id, out_rttm_dir): """ - write rttm file with uniq_id name in out_rttm_dir with time_stamps in labels + Write rttm file with uniq_id name in out_rttm_dir with time_stamps in labels """ filename = os.path.join(out_rttm_dir, uniq_id + '.rttm') with open(filename, 'w') as f: @@ -280,7 +307,7 @@ def labels_to_rttmfile(labels, uniq_id, out_rttm_dir): def rttm_to_labels(rttm_filename): """ - prepares time stamps label list from rttm file + Prepare time stamps label list from rttm file """ labels = [] with open(rttm_filename, 'r') as f: @@ -293,7 +320,6 @@ def rttm_to_labels(rttm_filename): def write_cluster_labels(base_scale_idx, lines_cluster_labels, out_rttm_dir): """ - Write cluster labels that are generated from clustering into a file. Args: base_scale_idx (int): The base scale index which is the highest scale index. @@ -310,7 +336,7 @@ def write_cluster_labels(base_scale_idx, lines_cluster_labels, out_rttm_dir): def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, clustering_params): """ - performs spectral clustering on embeddings with time stamps generated from VAD output + Performs spectral clustering on embeddings with time stamps generated from VAD output Args: embs_and_timestamps (dict): This dictionary contains the following items indexed by unique IDs. 
@@ -363,6 +389,7 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste lines[idx] += tag a = get_contiguous_stamps(lines) labels = merge_stamps(a) + if out_rttm_dir: labels_to_rttmfile(labels, uniq_id, out_rttm_dir) lines_cluster_labels.extend([f'{uniq_id} {seg_line}\n' for seg_line in lines]) @@ -386,16 +413,16 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True): """ - calculates DER, CER, FA and MISS + Calculates DER, CER, FA and MISS Args: - AUDIO_RTTM_MAP : Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation + AUDIO_RTTM_MAP : Dictionary containing information provided from manifestpath + all_reference (list[uniq_name,Annotation]): reference annotations for score calculation + all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. - mapping (dict): Mapping dict containing the mapping speaker label for each audio input + metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. 
+ mapping (dict): Mapping dict containing the mapping speaker label for each audio input < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of @@ -437,110 +464,307 @@ def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ign return None +def get_vad_out_from_rttm_line(rttm_line): + vad_out = rttm_line.strip().split() + if len(vad_out) > 3: + start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7] + else: + start, dur, _ = float(vad_out[0]), float(vad_out[1]), vad_out[2] + start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur)) + return start, dur + +def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id): + """ + Extract offset and duration information from AUDIO_RTTM_MAP dictionary. + If duration information is not specified, a duration value is extracted from the audio file directly. + + Args: + AUDIO_RTTM_MAP (dict): + Dictionary containing RTTM file information, which is indexed by unique file id. + uniq_id (str): + Unique file id + Returns: + offset (float): + The offset value that determines the beginning of the audio stream. + duration (float): + The length of audio stream that is expected to be used. + """ + audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] + if AUDIO_RTTM_MAP[uniq_id].get('duration', None): + duration = round(AUDIO_RTTM_MAP[uniq_id]['duration'], 2) + offset = round(AUDIO_RTTM_MAP[uniq_id]['offset'], 2) + else: + sound = sf.SoundFile(audio_path) + duration = sound.frames / sound.samplerate + offset = 0.0 + return offset, duration + +def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id): + """ + Write the json dictionary into the specified file. + + Args: + outfile: + File pointer that indicates output file path. 
+ AUDIO_RTTM_MAP (dict): + Dictionary containing the input manifest information + uniq_id (str): + Unique file id + overlap_range_list (list): + List containing overlapping ranges between target and source. + """ + audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] + for (stt, end) in overlap_range_list: + meta = {"audio_filepath": audio_path, + "offset": round(stt, 2), + "duration": round(end - stt, 2), + "label": 'UNK', + "uniq_id": uniq_id} + json.dump(meta, outfile) + outfile.write("\n") + +def read_rttm_lines(rttm_file_path): + """ + Read rttm files and return the rttm information lines. + + Args: + rttm_file_path (str): + + Returns: + lines (list): + List containing the strings from the RTTM file. + """ + if rttm_file_path and os.path.exists(rttm_file_path): + f = open(rttm_file_path, 'r') + else: + raise FileNotFoundError( + "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( + rttm_file_path + ) + ) + lines = f.readlines() + return lines + +def isOverlap(rangeA, rangeB): + """ + Check whether two ranges have overlap. + Args: + rangeA (list, tuple): + List or tuple containing start and end value in float. + rangeB (list, tuple): + List or tuple containing start and end value in float. + Returns: + (bool): + Boolean that indicates whether the input ranges have overlap. + """ + start1, end1 = rangeA + start2, end2 = rangeB + return end1 > start2 and end2 > start1 + +def getOverlapRange(rangeA, rangeB): + """ + Calculate the overlapping range between rangeA and rangeB. + Args: + rangeA (list, tuple): + List or tuple containing start and end value in float. + rangeB (list, tuple): + List or tuple containing start and end value in float. + Returns: + (list): + List containing the overlapping range between rangeA and rangeB. 
+ """ + assert isOverlap(rangeA, rangeB), f"There is no overlap between rangeA:{rangeA} and rangeB:{rangeB}" + return [max(rangeA[0], rangeB[0]), min(rangeA[1], rangeB[1])] + +def combine_float_overlaps(ranges): + """ + Args: + ranges(list): + List containing ranges. + Example: [(10.2, 10.83), (10.42, 10.91), (10.45, 12.09)] + Returns: + merged_list (list): + List containing the combined ranges. + Example: [(10.2, 12.09)] + + Combine overlaps with floating point numbers. Since neighboring integers are considered as continuous range, + we need to add 1 to the starting range before merging then subtract 1 from the result range. + """ + ranges_int = [] + for x in ranges: + stt, end = fl2int(x[0])+1, fl2int(x[1]) + if stt == end: + logging.warning(f"The ragne {stt}:{end} is too short to be combined therefore skipped.") + else: + ranges_int.append([stt, end]) + merged_ranges = combine_int_overlaps(ranges_int) + merged_ranges = [[int2fl(x[0]-1), int2fl(x[1])] for x in merged_ranges] + return merged_ranges + +def combine_int_overlaps(ranges): + """ + Merge the range pairs if there is overlap exists between the given ranges. + Refer to the original code at https://stackoverflow.com/a/59378428 + + Args: + ranges(list): + List containing ranges. + Example: [(102, 108), (104, 109), (107, 120)] + Returns: + merged_list (list): + List containing the combined ranges. + Example: [(102, 120)] + + """ + merged_list = reduce( + lambda x, element: x[:-1:] + [(min(*x[-1], *element), max(*x[-1], *element))] + if x[-1][1] >= element[0] - 1 + else x + [element], + ranges[1::], + ranges[0:1], + ) + return merged_list + +def fl2int(x, deci=2): + """ + Convert floating point number to integer. + """ + return int(round(x*pow(10,deci))) + +def int2fl(x, deci=2): + """ + Convert integer to floating point number. 
+ """ + return round(float(x/pow(10,deci)), int(deci)) -def write_rttm2manifest(AUDIO_RTTM_MAP, manifest_file): +def getMergedRanges(label_list_A: List, label_list_B: List) -> List: """ - writes manifest file based on rttm files (or vad table out files). This manifest file would be used by - speaker diarizer to compute embeddings and cluster them. This function also takes care of overlap time stamps + Calculate the merged ranges between label_list_A and label_list_B. Args: - AUDIO_RTTM_MAP: dict containing keys to uniqnames, that contains audio filepath and rttm_filepath as its contents, - these are used to extract oracle vad timestamps. - manifest (str): path to write manifest file + label_list_A (list): + List containing ranges (start and end values) + label_list_B (list): + List containing ranges (start and end values) + Returns: + (list): + List containing the merged ranges + + """ + if label_list_A == [] and label_list_B != []: + return label_list_B + elif label_list_A != [] and label_list_B == []: + return label_list_A + else: + label_list_A = [ [fl2int(x[0]), fl2int(x[1])] for x in label_list_A] + label_list_B = [ [fl2int(x[0]), fl2int(x[1])] for x in label_list_B] + combined = combine_int_overlaps(label_list_A + label_list_B) + return [ [int2fl(x[0]), int2fl(x[1])] for x in combined ] + +def getMinMaxOfRangeList(ranges): + """ + Get the min and max of a given range list. + """ + _max = max([x[1] for x in ranges]) + _min = min([x[0] for x in ranges]) + return _min, _max + +def getSubRangeList(target_range, source_range_list) -> List: + """ + Get the ranges that has overlaps with the target range from the source_range_list. + + Example: + source range: + |===--======---=====---====--| + target range: + |--------================----| + out_range: + |--------===---=====---==----| + + Args: + target_range (list): + A range (a start and end value pair) that defines the target range we want to select. 
+ target_range = [(start, end)] + source_range_list (list): + List containing the subranges that need to be selected. + source_range = [(start0, end0), (start1, end1), ...] + Returns: + out_range (list): + List containing the overlap between target_range and + source_range_list. + """ + if target_range == []: + return [] + else: + out_range = [] + for s_range in source_range_list: + if isOverlap(s_range, target_range): + ovl_range = getOverlapRange(s_range, target_range) + out_range.append(ovl_range) + return out_range + +def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id: bool = False, deci: int = 2) -> str: + """ + Write manifest file based on rttm files (or vad table out files). This manifest file would be used by + speaker diarizer to compute embeddings and cluster them. This function also takes care of overlapping time stamps. + + Args: + AUDIO_RTTM_MAP (dict): + Dictionary containing keys to uniqnames, that contains audio filepath and rttm_filepath as its contents, + these are used to extract oracle vad timestamps. + manifest (str): + The path to the output manifest file. Returns: - manifest (str): path to write manifest file + manifest (str): + The path to the output manifest file.
""" with open(manifest_file, 'w') as outfile: - for key in AUDIO_RTTM_MAP: - rttm_filename = AUDIO_RTTM_MAP[key]['rttm_filepath'] - if rttm_filename and os.path.exists(rttm_filename): - f = open(rttm_filename, 'r') - else: - raise FileNotFoundError( - "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( - rttm_filename - ) + for uniq_id in AUDIO_RTTM_MAP: + rttm_file_path = AUDIO_RTTM_MAP[uniq_id]['rttm_filepath'] + rttm_lines = read_rttm_lines(rttm_file_path) + offset, duration = get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id) + vad_start_end_list_raw = [] + for line in rttm_lines: + start, dur = get_vad_out_from_rttm_line(line) + vad_start_end_list_raw.append([start, start+dur]) + vad_start_end_list = combine_float_overlaps(vad_start_end_list_raw) + if len(vad_start_end_list) == 0: + logging.warning( + f"File ID: {uniq_id}: The VAD label is not containing any speech segments." + ) + elif duration == 0: + logging.warning( + f"File ID: {uniq_id}: The audio file has zero duration." 
) - - audio_path = AUDIO_RTTM_MAP[key]['audio_filepath'] - if AUDIO_RTTM_MAP[key].get('duration', None): - max_duration = AUDIO_RTTM_MAP[key]['duration'] - else: - sound = sf.SoundFile(audio_path) - max_duration = sound.frames / sound.samplerate - - lines = f.readlines() - time_tup = (-1, -1) - for line in lines: - vad_out = line.strip().split() - if len(vad_out) > 3: - start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7] - else: - start, dur, _ = float(vad_out[0]), float(vad_out[1]), vad_out[2] - start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur)) - - if start == 0 and dur == 0: # No speech segments - continue - else: - - if time_tup[0] >= 0 and start > time_tup[1]: - dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0])) - if time_tup[0] < max_duration and dur2 > 0: - meta = { - "audio_filepath": audio_path, - "offset": time_tup[0], - "duration": dur2, - "label": 'UNK', - } - json.dump(meta, outfile) - outfile.write("\n") - else: - logging.warning( - "RTTM label has been truncated since start is greater than duration of audio file" - ) - time_tup = (start, start + dur) - else: - if time_tup[0] == -1: - end_time = start + dur - if end_time > max_duration: - end_time = max_duration - time_tup = (start, end_time) - else: - end_time = max(time_tup[1], start + dur) - if end_time > max_duration: - end_time = max_duration - time_tup = (min(time_tup[0], start), end_time) - dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0])) - if time_tup[0] < max_duration and dur2 > 0: - meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'} - json.dump(meta, outfile) - outfile.write("\n") else: - logging.warning("RTTM label has been truncated since start is greater than duration of audio file") - f.close() + min_vad, max_vad = getMinMaxOfRangeList(vad_start_end_list) + if max_vad > round(offset + duration, deci) or min_vad < offset: + logging.warning("RTTM label has been truncated since start is greater than 
duration of audio file") + overlap_range_list = getSubRangeList(source_range_list=vad_start_end_list, + target_range=[offset, offset+duration]) + write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id) return manifest_file - def segments_manifest_to_subsegments_manifest( segments_manifest_file: str, subsegments_manifest_file: str = None, window: float = 1.5, shift: float = 0.75, min_subsegment_duration: float = 0.05, + include_uniq_id:bool = False ): """ Generate subsegments manifest from segments manifest file - Args - input: + Args: segments_manifest file (str): path to segments manifest file, typically from VAD output subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value - output: + Returns: returns path to subsegment manifest file """ if subsegments_manifest_file is None: @@ -556,11 +780,15 @@ def segments_manifest_to_subsegments_manifest( dic = json.loads(segment) audio, offset, duration, label = dic['audio_filepath'], dic['offset'], dic['duration'], dic['label'] subsegments = get_subsegments(offset=offset, window=window, shift=shift, duration=duration) - + if include_uniq_id and 'uniq_id' in dic: + uniq_id = dic['uniq_id'] + else: + uniq_id = None for subsegment in subsegments: start, dur = subsegment if dur > min_subsegment_duration: - meta = {"audio_filepath": audio, "offset": start, "duration": dur, "label": label} + meta = {"audio_filepath": audio, "offset": start, "duration": dur, "label": label, "uniq_id": uniq_id} + json.dump(meta, subsegments_manifest) subsegments_manifest.write("\n") @@ -569,14 +797,13 @@ def segments_manifest_to_subsegments_manifest( def get_subsegments(offset: float, window: float, shift: float, duration: float): """ - 
return subsegments from a segment of audio file - Args - input: + Return subsegments from a segment of audio file + Args: offset (float): start time of audio segment window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift duration (float): duration of segment - output: + Returns: subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ subsegments = [] @@ -596,10 +823,10 @@ def get_subsegments(offset: float, window: float, shift: float, duration: float) def embedding_normalize(embs, use_std=False, eps=1e-10): """ - mean and l2 length normalize the input speaker embeddings - input: + Mean and l2 length normalize the input speaker embeddings + Args: embs: embeddings of shape (Batch,emb_size) - output: + Returns: embs: normalized embeddings of shape (Batch,emb_size) """ embs = embs - embs.mean(axis=0) From 2b718c34e779aa9e32d4454305707008e254caae Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 9 May 2022 12:51:59 -0700 Subject: [PATCH 02/10] style fix Signed-off-by: Taejin Park --- .../asr/parts/utils/speaker_utils.py | 86 ++++++++++++------- 1 file changed, 53 insertions(+), 33 deletions(-) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 96a4ffd8be11..330d6141ab63 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -16,8 +16,8 @@ import math import os from copy import deepcopy -from typing import List, Optional, Dict from functools import reduce +from typing import Dict, List, Optional import numpy as np import omegaconf @@ -30,7 +30,6 @@ from nemo.collections.asr.parts.utils.nmesc_clustering import COSclustering from nemo.utils import logging - """ This file contains all the utility functions required for speaker embeddings part in diarization scripts """ @@ -46,6 +45,7 @@ def 
get_uniqname_from_filepath(filepath): else: raise TypeError("input must be filepath string") + def get_uniq_id_with_dur(meta, deci=2): """ Return basename with offset and end time labels @@ -176,7 +176,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ elif any(checkListConfig): raise ValueError( - 'You must provide list config for all three parameters: window, shift and multiscale weights.' + 'You must provide a list config for all three parameters: window, shift and multiscale weights.' ) else: return None @@ -239,7 +239,7 @@ def get_contiguous_stamps(stamps): def merge_stamps(lines): """ - merge time stamps of same speaker + Merge time stamps of the same speaker. """ stamps = deepcopy(lines) overlap_stamps = [] @@ -259,7 +259,7 @@ def merge_stamps(lines): def labels_to_pyannote_object(labels, uniq_name=''): """ - converts labels to pyannote object to calculate DER and for visualization + Convert the given labels to pyannote object to calculate DER and for visualization """ annotation = Annotation(uri=uniq_name) for label in labels: @@ -360,7 +360,7 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste cuda = True if not torch.cuda.is_available(): - logging.warning("cuda=False, using CPU for Eigen decompostion. This might slow down the clustering process.") + logging.warning("cuda=False, using CPU for Eigen decomposition. 
This might slow down the clustering process.") cuda = False for uniq_id, value in tqdm(AUDIO_RTTM_MAP.items()): @@ -389,7 +389,7 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste lines[idx] += tag a = get_contiguous_stamps(lines) labels = merge_stamps(a) - + if out_rttm_dir: labels_to_rttmfile(labels, uniq_id, out_rttm_dir) lines_cluster_labels.extend([f'{uniq_id} {seg_line}\n' for seg_line in lines]) @@ -464,6 +464,7 @@ def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ign return None + def get_vad_out_from_rttm_line(rttm_line): vad_out = rttm_line.strip().split() if len(vad_out) > 3: @@ -473,6 +474,7 @@ def get_vad_out_from_rttm_line(rttm_line): start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur)) return start, dur + def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id): """ Extract offset and duration information from AUDIO_RTTM_MAP dictionary. @@ -499,6 +501,7 @@ def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id): offset = 0.0 return offset, duration + def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id): """ Write the json dictionary into the specified file. @@ -515,14 +518,17 @@ def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, """ audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] for (stt, end) in overlap_range_list: - meta = {"audio_filepath": audio_path, - "offset": round(stt, 2), - "duration": round(end - stt, 2), - "label": 'UNK', - "uniq_id": uniq_id} + meta = { + "audio_filepath": audio_path, + "offset": round(stt, 2), + "duration": round(end - stt, 2), + "label": 'UNK', + "uniq_id": uniq_id, + } json.dump(meta, outfile) outfile.write("\n") + def read_rttm_lines(rttm_file_path): """ Read rttm files and return the rttm information lines. 
@@ -545,6 +551,7 @@ def read_rttm_lines(rttm_file_path): lines = f.readlines() return lines + def isOverlap(rangeA, rangeB): """ Check whether two ranges have overlap. @@ -561,6 +568,7 @@ def isOverlap(rangeA, rangeB): start2, end2 = rangeB return end1 > start2 and end2 > start1 + def getOverlapRange(rangeA, rangeB): """ Calculate the overlapping range between rangeA and rangeB. @@ -576,6 +584,7 @@ def getOverlapRange(rangeA, rangeB): assert isOverlap(rangeA, rangeB), f"There is no overlap between rangeA:{rangeA} and rangeB:{rangeB}" return [max(rangeA[0], rangeB[0]), min(rangeA[1], rangeB[1])] + def combine_float_overlaps(ranges): """ Args: @@ -592,15 +601,16 @@ def combine_float_overlaps(ranges): """ ranges_int = [] for x in ranges: - stt, end = fl2int(x[0])+1, fl2int(x[1]) + stt, end = fl2int(x[0]) + 1, fl2int(x[1]) if stt == end: logging.warning(f"The ragne {stt}:{end} is too short to be combined therefore skipped.") else: ranges_int.append([stt, end]) merged_ranges = combine_int_overlaps(ranges_int) - merged_ranges = [[int2fl(x[0]-1), int2fl(x[1])] for x in merged_ranges] + merged_ranges = [[int2fl(x[0] - 1), int2fl(x[1])] for x in merged_ranges] return merged_ranges + def combine_int_overlaps(ranges): """ Merge the range pairs if there is overlap exists between the given ranges. @@ -618,24 +628,27 @@ def combine_int_overlaps(ranges): """ merged_list = reduce( lambda x, element: x[:-1:] + [(min(*x[-1], *element), max(*x[-1], *element))] - if x[-1][1] >= element[0] - 1 - else x + [element], + if x[-1][1] >= element[0] - 1 + else x + [element], ranges[1::], ranges[0:1], ) return merged_list + def fl2int(x, deci=2): """ Convert floating point number to integer. """ - return int(round(x*pow(10,deci))) + return int(round(x * pow(10, deci))) + def int2fl(x, deci=2): """ Convert integer to floating point number. 
""" - return round(float(x/pow(10,deci)), int(deci)) + return round(float(x / pow(10, deci)), int(deci)) + def getMergedRanges(label_list_A: List, label_list_B: List) -> List: """ @@ -656,10 +669,11 @@ def getMergedRanges(label_list_A: List, label_list_B: List) -> List: elif label_list_A != [] and label_list_B == []: return label_list_A else: - label_list_A = [ [fl2int(x[0]), fl2int(x[1])] for x in label_list_A] - label_list_B = [ [fl2int(x[0]), fl2int(x[1])] for x in label_list_B] + label_list_A = [[fl2int(x[0]), fl2int(x[1])] for x in label_list_A] + label_list_B = [[fl2int(x[0]), fl2int(x[1])] for x in label_list_B] combined = combine_int_overlaps(label_list_A + label_list_B) - return [ [int2fl(x[0]), int2fl(x[1])] for x in combined ] + return [[int2fl(x[0]), int2fl(x[1])] for x in combined] + def getMinMaxOfRangeList(ranges): """ @@ -669,6 +683,7 @@ def getMinMaxOfRangeList(ranges): _min = min([x[0] for x in ranges]) return _min, _max + def getSubRangeList(target_range, source_range_list) -> List: """ Get the ranges that has overlaps with the target range from the source_range_list. @@ -701,7 +716,8 @@ def getSubRangeList(target_range, source_range_list) -> List: if isOverlap(s_range, target_range): ovl_range = getOverlapRange(s_range, target_range) out_range.append(ovl_range) - return out_range + return out_range + def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id: bool = False, deci: int = 2) -> str: """ @@ -728,32 +744,30 @@ def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id vad_start_end_list_raw = [] for line in rttm_lines: start, dur = get_vad_out_from_rttm_line(line) - vad_start_end_list_raw.append([start, start+dur]) + vad_start_end_list_raw.append([start, start + dur]) vad_start_end_list = combine_float_overlaps(vad_start_end_list_raw) if len(vad_start_end_list) == 0: - logging.warning( - f"File ID: {uniq_id}: The VAD label is not containing any speech segments." 
- ) + logging.warning(f"File ID: {uniq_id}: The VAD label is not containing any speech segments.") elif duration == 0: - logging.warning( - f"File ID: {uniq_id}: The audio file has zero duration." - ) + logging.warning(f"File ID: {uniq_id}: The audio file has zero duration.") else: min_vad, max_vad = getMinMaxOfRangeList(vad_start_end_list) if max_vad > round(offset + duration, deci) or min_vad < offset: logging.warning("RTTM label has been truncated since start is greater than duration of audio file") - overlap_range_list = getSubRangeList(source_range_list=vad_start_end_list, - target_range=[offset, offset+duration]) + overlap_range_list = getSubRangeList( + source_range_list=vad_start_end_list, target_range=[offset, offset + duration] + ) write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id) return manifest_file + def segments_manifest_to_subsegments_manifest( segments_manifest_file: str, subsegments_manifest_file: str = None, window: float = 1.5, shift: float = 0.75, min_subsegment_duration: float = 0.05, - include_uniq_id:bool = False + include_uniq_id: bool = False, ): """ Generate subsegments manifest from segments manifest file @@ -787,7 +801,13 @@ def segments_manifest_to_subsegments_manifest( for subsegment in subsegments: start, dur = subsegment if dur > min_subsegment_duration: - meta = {"audio_filepath": audio, "offset": start, "duration": dur, "label": label, "uniq_id": uniq_id} + meta = { + "audio_filepath": audio, + "offset": start, + "duration": dur, + "label": label, + "uniq_id": uniq_id, + } json.dump(meta, subsegments_manifest) subsegments_manifest.write("\n") From 881a24ce9c863d53cbe357e7388d271812ae50de Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 9 May 2022 12:56:01 -0700 Subject: [PATCH 03/10] Added missing docstrings Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/speaker_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 330d6141ab63..20d2fd5434cf 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -466,6 +466,9 @@ def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ign def get_vad_out_from_rttm_line(rttm_line): + """ + Extract VAD timestamp from the given RTTM lines. + """ vad_out = rttm_line.strip().split() if len(vad_out) > 3: start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7] From 1051bf16e017005cf2371819cd050f2816805991 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 9 May 2022 17:01:24 -0700 Subject: [PATCH 04/10] Fixed docstrings Signed-off-by: Taejin Park --- .../asr/parts/utils/speaker_utils.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 20d2fd5434cf..565d81919ca6 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -69,10 +69,13 @@ def audio_rttm_map(manifest): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists + Args: + manifest (str): + Path to a file containing keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists - returns: - AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files + Returns: + AUDIO_RTTM_MAP (dict) : + A dictionary with keys of uniq_id, which is being used to map audio files and corresponding rttm files """ AUDIO_RTTM_MAP = {} @@ -459,7 +462,7 @@ def score_labels(AUDIO_RTTM_MAP, 
all_reference, all_hypothesis, collar=0.25, ign return metric, mapping_dict else: logging.warning( - "check if each ground truth RTTMs were present in provided manifest file. Skipping calculation of Diariazation Error Rate" + "Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diarization Error Rate" ) return None @@ -492,7 +495,7 @@ def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id): offset (float): The offset value that determines the beginning of the audio stream. duration (float): - The length of audio stream that is expected to be used. + The length of the audio stream that is expected to be used. """ audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] if AUDIO_RTTM_MAP[uniq_id].get('duration', None): @@ -616,7 +619,7 @@ def combine_float_overlaps(ranges): def combine_int_overlaps(ranges): """ - Merge the range pairs if there is overlap exists between the given ranges. + Merge the range pairs if there is overlap between the given ranges. Refer to the original code at https://stackoverflow.com/a/59378428 Args: @@ -689,7 +692,7 @@ def getMinMaxOfRangeList(ranges): def getSubRangeList(target_range, source_range_list) -> List: """ - Get the ranges that has overlaps with the target range from the source_range_list. + Get the ranges that have overlaps with the target range from the source_range_list. Example: source range: @@ -725,7 +728,7 @@ def getSubRangeList(target_range, source_range_list) -> List: def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id: bool = False, deci: int = 2) -> str: """ Write manifest file based on rttm files (or vad table out files). This manifest file would be used by - speaker diarizer to compute embeddings and cluster them. This function also takes care of overlapping time stamps. + speaker diarizer to compute embeddings and cluster them. This function also takes care of overlapping timestamps. 
Args: AUDIO_RTTM_MAP (dict): @@ -780,9 +783,11 @@ def segments_manifest_to_subsegments_manifest( window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value + include_uniq_id (bool): if True, add uniq_id variable into for every json dictionary. Returns: - returns path to subsegment manifest file + subsegments_manifest_file (str): + Path to subsegment manifest file """ if subsegments_manifest_file is None: pwd = os.getcwd() From 5f7699e6845fce1a1a6921ecbfec62a71862143e Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 9 May 2022 19:43:45 -0700 Subject: [PATCH 05/10] Added mandatory sorting for reduce alg Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/speaker_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 565d81919ca6..e6f111670151 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -632,6 +632,7 @@ def combine_int_overlaps(ranges): Example: [(102, 120)] """ + ranges = sorted(ranges, key=lambda x: x[0]) merged_list = reduce( lambda x, element: x[:-1:] + [(min(*x[-1], *element), max(*x[-1], *element))] if x[-1][1] >= element[0] - 1 From df62ecc02170d42c12e47e8ab3f6dc67d14ac066 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 9 May 2022 21:38:20 -0700 Subject: [PATCH 06/10] Added decimal and margin variable Signed-off-by: Taejin Park --- .../asr/parts/utils/speaker_utils.py | 68 +++++++++---------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index e6f111670151..1a80058529b2 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -46,7 
+46,7 @@ def get_uniqname_from_filepath(filepath): raise TypeError("input must be filepath string") -def get_uniq_id_with_dur(meta, deci=2): +def get_uniq_id_with_dur(meta, deci=3): """ Return basename with offset and end time labels """ @@ -69,13 +69,10 @@ def audio_rttm_map(manifest): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - Args: - manifest (str): - Path to a file containing keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists + Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists - Returns: - AUDIO_RTTM_MAP (dict) : - A dictionary with keys of uniq_id, which is being used to map audio files and corresponding rttm files + returns: + AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files """ AUDIO_RTTM_MAP = {} @@ -462,7 +459,7 @@ def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ign return metric, mapping_dict else: logging.warning( - "Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diarization Error Rate" + "check if each ground truth RTTMs were present in provided manifest file. Skipping calculation of Diariazation Error Rate" ) return None @@ -477,11 +474,11 @@ def get_vad_out_from_rttm_line(rttm_line): start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7] else: start, dur, _ = float(vad_out[0]), float(vad_out[1]), vad_out[2] - start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur)) + start, dur = float("{:}".format(start)), float("{:}".format(dur)) return start, dur -def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id): +def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, deci=3): """ Extract offset and duration information from AUDIO_RTTM_MAP dictionary. 
If duration information is not specified, a duration value is extracted from the audio file directly. @@ -495,12 +492,12 @@ def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id): offset (float): The offset value that determines the beginning of the audio stream. duration (float): - The length of the audio stream that is expected to be used. + The length of audio stream that is expected to be used. """ audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] if AUDIO_RTTM_MAP[uniq_id].get('duration', None): - duration = round(AUDIO_RTTM_MAP[uniq_id]['duration'], 2) - offset = round(AUDIO_RTTM_MAP[uniq_id]['offset'], 2) + duration = round(AUDIO_RTTM_MAP[uniq_id]['duration'], deci) + offset = round(AUDIO_RTTM_MAP[uniq_id]['offset'], deci) else: sound = sf.SoundFile(audio_path) duration = sound.frames / sound.samplerate @@ -508,7 +505,7 @@ def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id): return offset, duration -def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id): +def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id, deci): """ Write the json dictionary into the specified file. @@ -526,8 +523,8 @@ def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, for (stt, end) in overlap_range_list: meta = { "audio_filepath": audio_path, - "offset": round(stt, 2), - "duration": round(end - stt, 2), + "offset": round(stt, deci), + "duration": round(end - stt, deci), "label": 'UNK', "uniq_id": uniq_id, } @@ -591,7 +588,7 @@ def getOverlapRange(rangeA, rangeB): return [max(rangeA[0], rangeB[0]), min(rangeA[1], rangeB[1])] -def combine_float_overlaps(ranges): +def combine_float_overlaps(ranges, deci=5, margin=2): """ Args: ranges(list): @@ -603,23 +600,24 @@ def combine_float_overlaps(ranges): Example: [(10.2, 12.09)] Combine overlaps with floating point numbers. 
Since neighboring integers are considered as continuous range, - we need to add 1 to the starting range before merging then subtract 1 from the result range. + we need to add margin to the starting range before merging then subtract margin from the result range. """ ranges_int = [] for x in ranges: - stt, end = fl2int(x[0]) + 1, fl2int(x[1]) + stt, end = fl2int(x[0], deci) + margin, fl2int(x[1], deci) if stt == end: - logging.warning(f"The ragne {stt}:{end} is too short to be combined therefore skipped.") + logging.warning(f"The range {stt}:{end} is too short to be combined therefore skipped.") else: ranges_int.append([stt, end]) merged_ranges = combine_int_overlaps(ranges_int) - merged_ranges = [[int2fl(x[0] - 1), int2fl(x[1])] for x in merged_ranges] + merged_ranges = [[int2fl(x[0] - margin, deci), int2fl(x[1], deci)] for x in merged_ranges] return merged_ranges def combine_int_overlaps(ranges): """ - Merge the range pairs if there is overlap between the given ranges. + Merge the range pairs if there is overlap exists between the given ranges. + This algorithm needs a sorted range list in terms of the start time. Refer to the original code at https://stackoverflow.com/a/59378428 Args: @@ -643,21 +641,21 @@ def combine_int_overlaps(ranges): return merged_list -def fl2int(x, deci=2): +def fl2int(x, deci=3): """ Convert floating point number to integer. """ return int(round(x * pow(10, deci))) -def int2fl(x, deci=2): +def int2fl(x, deci=3): """ Convert integer to floating point number. """ return round(float(x / pow(10, deci)), int(deci)) -def getMergedRanges(label_list_A: List, label_list_B: List) -> List: +def getMergedRanges(label_list_A: List, label_list_B: List, deci: int = 3) -> List: """ Calculate the merged ranges between label_list_A and label_list_B. 
@@ -676,10 +674,10 @@ def getMergedRanges(label_list_A: List, label_list_B: List) -> List: elif label_list_A != [] and label_list_B == []: return label_list_A else: - label_list_A = [[fl2int(x[0]), fl2int(x[1])] for x in label_list_A] - label_list_B = [[fl2int(x[0]), fl2int(x[1])] for x in label_list_B] + label_list_A = [[fl2int(x[0]+1, deci), fl2int(x[1], deci)] for x in label_list_A] + label_list_B = [[fl2int(x[0]+1, deci), fl2int(x[1], deci)] for x in label_list_B] combined = combine_int_overlaps(label_list_A + label_list_B) - return [[int2fl(x[0]), int2fl(x[1])] for x in combined] + return [[int2fl(x[0]-1, deci), int2fl(x[1], deci)] for x in combined] def getMinMaxOfRangeList(ranges): @@ -693,7 +691,7 @@ def getMinMaxOfRangeList(ranges): def getSubRangeList(target_range, source_range_list) -> List: """ - Get the ranges that have overlaps with the target range from the source_range_list. + Get the ranges that has overlaps with the target range from the source_range_list. Example: source range: @@ -726,10 +724,10 @@ def getSubRangeList(target_range, source_range_list) -> List: return out_range -def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id: bool = False, deci: int = 2) -> str: +def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id: bool = False, deci: int = 5) -> str: """ Write manifest file based on rttm files (or vad table out files). This manifest file would be used by - speaker diarizer to compute embeddings and cluster them. This function also takes care of overlapping timestamps. + speaker diarizer to compute embeddings and cluster them. This function also takes care of overlapping time stamps. 
Args: AUDIO_RTTM_MAP (dict): @@ -747,7 +745,7 @@ def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id for uniq_id in AUDIO_RTTM_MAP: rttm_file_path = AUDIO_RTTM_MAP[uniq_id]['rttm_filepath'] rttm_lines = read_rttm_lines(rttm_file_path) - offset, duration = get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id) + offset, duration = get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, deci) vad_start_end_list_raw = [] for line in rttm_lines: start, dur = get_vad_out_from_rttm_line(line) @@ -764,7 +762,7 @@ def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id overlap_range_list = getSubRangeList( source_range_list=vad_start_end_list, target_range=[offset, offset + duration] ) - write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id) + write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id, deci) return manifest_file @@ -784,11 +782,9 @@ def segments_manifest_to_subsegments_manifest( window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value - include_uniq_id (bool): if True, add uniq_id variable into for every json dictionary. 
Returns: - subsegments_manifest_file (str): - Path to subsegment manifest file + returns path to subsegment manifest file """ if subsegments_manifest_file is None: pwd = os.getcwd() From 8d4e46e491d11367ee36280c241147623793b4cc Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Mon, 9 May 2022 23:46:24 -0700 Subject: [PATCH 07/10] Added docstring for margin Signed-off-by: Taejin Park --- .../asr/parts/utils/speaker_utils.py | 37 +++++++++++++++---- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 1a80058529b2..b827c3f340d1 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -505,7 +505,7 @@ def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, deci=3): return offset, duration -def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id, deci): +def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id, deci=5): """ Write the json dictionary into the specified file. @@ -590,23 +590,40 @@ def getOverlapRange(rangeA, rangeB): def combine_float_overlaps(ranges, deci=5, margin=2): """ + Combine overlaps with floating point numbers. Since neighboring integers are considered as continuous range, + we need to add margin to the starting range before merging then subtract margin from the result range. + Args: - ranges(list): + ranges (list): List containing ranges. Example: [(10.2, 10.83), (10.42, 10.91), (10.45, 12.09)] + deci (int): + Number of rounding decimals + margin (int): + margin for determining overlap of the two ranges when ranges are converted to integer ranges. + Default is margin=2 which follows the python index convention. 
+ + Examples: + If margin is 0: + [(1, 10), (10, 20)] -> [(1, 20)] + [(1, 10), (11, 20)] -> [(1, 20)] + If margin is 1: + [(1, 10), (10, 20)] -> [(1, 20)] + [(1, 10), (11, 20)] -> [(1, 10), (11, 20)] + If margin is 2: + [(1, 10), (10, 20)] -> [(1, 10), (10, 20)] + [(1, 10), (11, 20)] -> [(1, 10), (11, 20)] + Returns: merged_list (list): List containing the combined ranges. Example: [(10.2, 12.09)] - - Combine overlaps with floating point numbers. Since neighboring integers are considered as continuous range, - we need to add margin to the starting range before merging then subtract margin from the result range. """ ranges_int = [] for x in ranges: stt, end = fl2int(x[0], deci) + margin, fl2int(x[1], deci) if stt == end: - logging.warning(f"The range {stt}:{end} is too short to be combined therefore skipped.") + logging.warning(f"The range {stt}:{end} is too short to be combined thus skipped.") else: ranges_int.append([stt, end]) merged_ranges = combine_int_overlaps(ranges_int) @@ -618,12 +635,16 @@ def combine_int_overlaps(ranges): """ Merge the range pairs if there is overlap exists between the given ranges. This algorithm needs a sorted range list in terms of the start time. + Note that neighboring numbers lead to a merged range. + Example: + [(1, 10), (11, 20)] -> [(1, 20)] + Refer to the original code at https://stackoverflow.com/a/59378428 Args: ranges(list): List containing ranges. - Example: [(102, 108), (104, 109), (107, 120)] + Example: [(102, 103), (104, 109), (107, 120)] Returns: merged_list (list): List containing the combined ranges. 
@@ -750,7 +771,7 @@ def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id for line in rttm_lines: start, dur = get_vad_out_from_rttm_line(line) vad_start_end_list_raw.append([start, start + dur]) - vad_start_end_list = combine_float_overlaps(vad_start_end_list_raw) + vad_start_end_list = combine_float_overlaps(vad_start_end_list_raw, deci) if len(vad_start_end_list) == 0: logging.warning(f"File ID: {uniq_id}: The VAD label is not containing any speech segments.") elif duration == 0: From 9e24044df8a0094a0fdf43b991166896773f5d85 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Tue, 10 May 2022 00:03:00 -0700 Subject: [PATCH 08/10] style fix Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/speaker_utils.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index b827c3f340d1..460562c33c56 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -17,7 +17,7 @@ import os from copy import deepcopy from functools import reduce -from typing import Dict, List, Optional +from typing import List import numpy as np import omegaconf @@ -51,7 +51,7 @@ def get_uniq_id_with_dur(meta, deci=3): Return basename with offset and end time labels """ bare_uniq_id = meta['audio_filepath'].split('/')[-1].split('.wav')[0] - if meta['offset'] == None and meta['duration'] == None: + if meta['offset'] is None and meta['duration'] is None: return bare_uniq_id if meta['offset']: offset = str(int(round(meta['offset'], deci) * pow(10, deci))) @@ -478,7 +478,7 @@ def get_vad_out_from_rttm_line(rttm_line): return start, dur -def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, deci=3): +def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, deci=5): """ Extract offset and duration information from AUDIO_RTTM_MAP dictionary. 
If duration information is not specified, a duration value is extracted from the audio file directly. @@ -695,10 +695,10 @@ def getMergedRanges(label_list_A: List, label_list_B: List, deci: int = 3) -> Li elif label_list_A != [] and label_list_B == []: return label_list_A else: - label_list_A = [[fl2int(x[0]+1, deci), fl2int(x[1], deci)] for x in label_list_A] - label_list_B = [[fl2int(x[0]+1, deci), fl2int(x[1], deci)] for x in label_list_B] + label_list_A = [[fl2int(x[0] + 1, deci), fl2int(x[1], deci)] for x in label_list_A] + label_list_B = [[fl2int(x[0] + 1, deci), fl2int(x[1], deci)] for x in label_list_B] combined = combine_int_overlaps(label_list_A + label_list_B) - return [[int2fl(x[0]-1, deci), int2fl(x[1], deci)] for x in combined] + return [[int2fl(x[0] - 1, deci), int2fl(x[1], deci)] for x in combined] def getMinMaxOfRangeList(ranges): @@ -748,7 +748,8 @@ def getSubRangeList(target_range, source_range_list) -> List: def write_rttm2manifest(AUDIO_RTTM_MAP: str, manifest_file: str, include_uniq_id: bool = False, deci: int = 5) -> str: """ Write manifest file based on rttm files (or vad table out files). This manifest file would be used by - speaker diarizer to compute embeddings and cluster them. This function also takes care of overlapping time stamps. + speaker diarizer to compute embeddings and cluster them. This function takes care of overlapping VAD timestamps + and trimmed with the given offset and duration value. 
Args: AUDIO_RTTM_MAP (dict): From 3b27b0e0fa1321db25579a8164472d0a4e3e8074 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Wed, 18 May 2022 20:10:28 -0700 Subject: [PATCH 09/10] reflected review and Added use_single_scale Signed-off-by: Taejin Park --- .../asr/parts/utils/speaker_utils.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 460562c33c56..bd824da39096 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -40,8 +40,8 @@ def get_uniqname_from_filepath(filepath): Return base name from provided filepath """ if type(filepath) is str: - basename = os.path.basename(filepath).rsplit('.', 1)[0] - return basename + uniq_id = os.path.splitext(os.path.basename(filepath))[0] + return uniq_id else: raise TypeError("input must be filepath string") @@ -50,7 +50,7 @@ def get_uniq_id_with_dur(meta, deci=3): """ Return basename with offset and end time labels """ - bare_uniq_id = meta['audio_filepath'].split('/')[-1].split('.wav')[0] + bare_uniq_id = get_uniqname_from_filepath(meta['audio_filepath']) if meta['offset'] is None and meta['duration'] is None: return bare_uniq_id if meta['offset']: @@ -105,7 +105,6 @@ def audio_rttm_map(manifest): return AUDIO_RTTM_MAP - def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_weights): """ Check whether multiscale parameters are provided correctly. 
window_lengths_in_sec, shift_lengfhs_in_sec and @@ -161,7 +160,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ else: shift_length_check = window_lengths[0] > shift_lengths[0] - multiscale_args_dict = {} + multiscale_args_dict = {'use_single_scale_clustering' : False} if all([length_check, scale_order_check, shift_length_check]) == True: if len(window_lengths) > 1: multiscale_args_dict['scale_dict'] = { @@ -181,7 +180,6 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ else: return None - def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_args_dict): """ The embeddings and timestamps in multiscale_embeddings_and_timestamps dictionary are @@ -202,11 +200,18 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg uniq_id: {'multiscale_weights': [], 'scale_dict': {}} for uniq_id in multiscale_embeddings_and_timestamps[0][0].keys() } - for scale_idx in sorted(multiscale_args_dict['scale_dict'].keys()): + if multiscale_args_dict['use_single_scale_clustering']: + _multiscale_args_dict = deepcopy(multiscale_args_dict) + _multiscale_args_dict['scale_dict'] = { 0 : multiscale_args_dict['scale_dict'][0] } + _multiscale_args_dict['multiscale_weights'] = multiscale_args_dict['multiscale_weights'][:1] + else: + _multiscale_args_dict = multiscale_args_dict + + for scale_idx in sorted(_multiscale_args_dict['scale_dict'].keys()): embeddings, time_stamps = multiscale_embeddings_and_timestamps[scale_idx] for uniq_id in embeddings.keys(): embs_and_timestamps[uniq_id]['multiscale_weights'] = ( - torch.tensor(multiscale_args_dict['multiscale_weights']).unsqueeze(0).half() + torch.tensor(_multiscale_args_dict['multiscale_weights']).unsqueeze(0).half() ) assert len(embeddings[uniq_id]) == len(time_stamps[uniq_id]) embs_and_timestamps[uniq_id]['scale_dict'][scale_idx] = { @@ -216,7 +221,6 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, 
multiscale_arg return embs_and_timestamps - def get_contiguous_stamps(stamps): """ Return contiguous time stamps @@ -507,7 +511,7 @@ def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, deci=5): def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id, deci=5): """ - Write the json dictionary into the specified file. + Write the json dictionary into the specified manifest file. Args: outfile: From 2eda6be3014d96909466402617c79857e858f026 Mon Sep 17 00:00:00 2001 From: Taejin Park Date: Wed, 18 May 2022 23:31:47 -0700 Subject: [PATCH 10/10] Style fix Signed-off-by: Taejin Park --- nemo/collections/asr/parts/utils/speaker_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index bd824da39096..a4a88e825a43 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -105,6 +105,7 @@ def audio_rttm_map(manifest): return AUDIO_RTTM_MAP + def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_weights): """ Check whether multiscale parameters are provided correctly. 
window_lengths_in_sec, shift_lengfhs_in_sec and @@ -160,7 +161,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ else: shift_length_check = window_lengths[0] > shift_lengths[0] - multiscale_args_dict = {'use_single_scale_clustering' : False} + multiscale_args_dict = {'use_single_scale_clustering': False} if all([length_check, scale_order_check, shift_length_check]) == True: if len(window_lengths) > 1: multiscale_args_dict['scale_dict'] = { @@ -180,6 +181,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ else: return None + def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_args_dict): """ The embeddings and timestamps in multiscale_embeddings_and_timestamps dictionary are @@ -202,7 +204,7 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg } if multiscale_args_dict['use_single_scale_clustering']: _multiscale_args_dict = deepcopy(multiscale_args_dict) - _multiscale_args_dict['scale_dict'] = { 0 : multiscale_args_dict['scale_dict'][0] } + _multiscale_args_dict['scale_dict'] = {0: multiscale_args_dict['scale_dict'][0]} _multiscale_args_dict['multiscale_weights'] = multiscale_args_dict['multiscale_weights'][:1] else: _multiscale_args_dict = multiscale_args_dict @@ -221,6 +223,7 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg return embs_and_timestamps + def get_contiguous_stamps(stamps): """ Return contiguous time stamps