diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 566b26ad054f..a4a88e825a43 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import json import math import os from copy import deepcopy +from functools import reduce +from typing import List import numpy as np import omegaconf @@ -27,26 +30,46 @@ from nemo.collections.asr.parts.utils.nmesc_clustering import COSclustering from nemo.utils import logging - """ This file contains all the utility functions required for speaker embeddings part in diarization scripts """ def get_uniqname_from_filepath(filepath): - "return base name from provided filepath" + """ + Return base name from provided filepath + """ if type(filepath) is str: - basename = os.path.basename(filepath).rsplit('.', 1)[0] - return basename + uniq_id = os.path.splitext(os.path.basename(filepath))[0] + return uniq_id else: raise TypeError("input must be filepath string") +def get_uniq_id_with_dur(meta, deci=3): + """ + Return basename with offset and end time labels + """ + bare_uniq_id = get_uniqname_from_filepath(meta['audio_filepath']) + if meta['offset'] is None and meta['duration'] is None: + return bare_uniq_id + if meta['offset']: + offset = str(int(round(meta['offset'], deci) * pow(10, deci))) + else: + offset = 0 + if meta['duration']: + endtime = str(int(round(meta['offset'] + meta['duration'], deci) * pow(10, deci))) + else: + endtime = 'NULL' + uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" + return uniq_id + + def audio_rttm_map(manifest): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - input: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists + Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists returns: AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files @@ -63,6 +86,7 @@ def audio_rttm_map(manifest): meta = { 'audio_filepath': dic['audio_filepath'], 'rttm_filepath': dic.get('rttm_filepath', None), + 'offset': dic.get('offset', None), 'duration': dic.get('duration', None), 'text': dic.get('text', None), 'num_speakers': dic.get('num_speakers', None), @@ -98,6 +122,9 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ parameters.window_length_in_sec=[1.5,1.0,0.5] parameters.shift_length_in_sec=[0.75,0.5,0.25] parameters.multiscale_weights=[0.33,0.33,0.33] + + In addition, you can also specify session-by-session multiscale weight. In this case, each dictionary key + points to different weights. 
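For reference, a rough usage sketch of the get_uniq_id_with_dur helper added earlier in this hunk; the file path and timing values are hypothetical, and deci=3 is the default fixed-point precision:

from nemo.collections.asr.parts.utils.speaker_utils import get_uniq_id_with_dur

# offset and end time are encoded into the id as fixed-point integers (scaled by 10**deci)
meta = {"audio_filepath": "/data/session1.wav", "offset": 1.5, "duration": 30.0}
get_uniq_id_with_dur(meta)  # "session1_1500_31500"  (1.5 s -> 1500, 31.5 s -> 31500)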
""" checkFloatConfig = [type(var) == float for var in (window_lengths_in_sec, shift_lengths_in_sec)] checkListConfig = [ @@ -134,7 +161,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ else: shift_length_check = window_lengths[0] > shift_lengths[0] - multiscale_args_dict = {} + multiscale_args_dict = {'use_single_scale_clustering': False} if all([length_check, scale_order_check, shift_length_check]) == True: if len(window_lengths) > 1: multiscale_args_dict['scale_dict'] = { @@ -149,7 +176,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ elif any(checkListConfig): raise ValueError( - 'You must provide list config for all three parameters: window, shift and multiscale weights.' + 'You must provide a list config for all three parameters: window, shift and multiscale weights.' ) else: return None @@ -175,11 +202,18 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg uniq_id: {'multiscale_weights': [], 'scale_dict': {}} for uniq_id in multiscale_embeddings_and_timestamps[0][0].keys() } - for scale_idx in sorted(multiscale_args_dict['scale_dict'].keys()): + if multiscale_args_dict['use_single_scale_clustering']: + _multiscale_args_dict = deepcopy(multiscale_args_dict) + _multiscale_args_dict['scale_dict'] = {0: multiscale_args_dict['scale_dict'][0]} + _multiscale_args_dict['multiscale_weights'] = multiscale_args_dict['multiscale_weights'][:1] + else: + _multiscale_args_dict = multiscale_args_dict + + for scale_idx in sorted(_multiscale_args_dict['scale_dict'].keys()): embeddings, time_stamps = multiscale_embeddings_and_timestamps[scale_idx] for uniq_id in embeddings.keys(): embs_and_timestamps[uniq_id]['multiscale_weights'] = ( - torch.tensor(multiscale_args_dict['multiscale_weights']).unsqueeze(0).half() + torch.tensor(_multiscale_args_dict['multiscale_weights']).unsqueeze(0).half() ) assert len(embeddings[uniq_id]) == len(time_stamps[uniq_id]) embs_and_timestamps[uniq_id]['scale_dict'][scale_idx] = { @@ -212,7 +246,7 @@ def get_contiguous_stamps(stamps): def merge_stamps(lines): """ - merge time stamps of same speaker + Merge time stamps of the same speaker. 
""" stamps = deepcopy(lines) overlap_stamps = [] @@ -232,7 +266,7 @@ def merge_stamps(lines): def labels_to_pyannote_object(labels, uniq_name=''): """ - converts labels to pyannote object to calculate DER and for visualization + Convert the given labels to pyannote object to calculate DER and for visualization """ annotation = Annotation(uri=uniq_name) for label in labels: @@ -245,7 +279,7 @@ def labels_to_pyannote_object(labels, uniq_name=''): def uem_timeline_from_file(uem_file, uniq_name=''): """ - outputs pyannote timeline segments for uem file + Generate pyannote timeline segments for uem file file format UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME @@ -263,7 +297,7 @@ def uem_timeline_from_file(uem_file, uniq_name=''): def labels_to_rttmfile(labels, uniq_id, out_rttm_dir): """ - write rttm file with uniq_id name in out_rttm_dir with time_stamps in labels + Write rttm file with uniq_id name in out_rttm_dir with time_stamps in labels """ filename = os.path.join(out_rttm_dir, uniq_id + '.rttm') with open(filename, 'w') as f: @@ -280,7 +314,7 @@ def labels_to_rttmfile(labels, uniq_id, out_rttm_dir): def rttm_to_labels(rttm_filename): """ - prepares time stamps label list from rttm file + Prepare time stamps label list from rttm file """ labels = [] with open(rttm_filename, 'r') as f: @@ -293,7 +327,6 @@ def rttm_to_labels(rttm_filename): def write_cluster_labels(base_scale_idx, lines_cluster_labels, out_rttm_dir): """ - Write cluster labels that are generated from clustering into a file. Args: base_scale_idx (int): The base scale index which is the highest scale index. @@ -310,7 +343,7 @@ def write_cluster_labels(base_scale_idx, lines_cluster_labels, out_rttm_dir): def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, clustering_params): """ - performs spectral clustering on embeddings with time stamps generated from VAD output + Performs spectral clustering on embeddings with time stamps generated from VAD output Args: embs_and_timestamps (dict): This dictionary contains the following items indexed by unique IDs. @@ -334,7 +367,7 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste cuda = True if not torch.cuda.is_available(): - logging.warning("cuda=False, using CPU for Eigen decompostion. This might slow down the clustering process.") + logging.warning("cuda=False, using CPU for Eigen decomposition. 
This might slow down the clustering process.") cuda = False for uniq_id, value in tqdm(AUDIO_RTTM_MAP.items()): @@ -363,6 +396,7 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste lines[idx] += tag a = get_contiguous_stamps(lines) labels = merge_stamps(a) + if out_rttm_dir: labels_to_rttmfile(labels, uniq_id, out_rttm_dir) lines_cluster_labels.extend([f'{uniq_id} {seg_line}\n' for seg_line in lines]) @@ -386,16 +420,16 @@ def perform_clustering(embs_and_timestamps, AUDIO_RTTM_MAP, out_rttm_dir, cluste def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True): """ - calculates DER, CER, FA and MISS + Calculates DER, CER, FA and MISS Args: - AUDIO_RTTM_MAP : Dictionary containing information provided from manifestpath - all_reference (list[uniq_name,Annotation]): reference annotations for score calculation - all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation + AUDIO_RTTM_MAP : Dictionary containing information provided from manifestpath + all_reference (list[uniq_name,Annotation]): reference annotations for score calculation + all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. - mapping (dict): Mapping dict containing the mapping speaker label for each audio input + metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. + mapping (dict): Mapping dict containing the mapping speaker label for each audio input < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of @@ -438,88 +472,326 @@ def score_labels(AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ign return None -def write_rttm2manifest(AUDIO_RTTM_MAP, manifest_file): +def get_vad_out_from_rttm_line(rttm_line): + """ + Extract VAD timestamp from the given RTTM lines. + """ + vad_out = rttm_line.strip().split() + if len(vad_out) > 3: + start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7] + else: + start, dur, _ = float(vad_out[0]), float(vad_out[1]), vad_out[2] + start, dur = float("{:}".format(start)), float("{:}".format(dur)) + return start, dur + + +def get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, deci=5): + """ + Extract offset and duration information from AUDIO_RTTM_MAP dictionary. + If duration information is not specified, a duration value is extracted from the audio file directly. + + Args: + AUDIO_RTTM_MAP (dict): + Dictionary containing RTTM file information, which is indexed by unique file id. + uniq_id (str): + Unique file id + Returns: + offset (float): + The offset value that determines the beginning of the audio stream. + duration (float): + The length of audio stream that is expected to be used. + """ + audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] + if AUDIO_RTTM_MAP[uniq_id].get('duration', None): + duration = round(AUDIO_RTTM_MAP[uniq_id]['duration'], deci) + offset = round(AUDIO_RTTM_MAP[uniq_id]['offset'], deci) + else: + sound = sf.SoundFile(audio_path) + duration = sound.frames / sound.samplerate + offset = 0.0 + return offset, duration + + +def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id, deci=5): + """ + Write the json dictionary into the specified manifest file. 
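A quick, hedged sketch of the new get_vad_out_from_rttm_line helper defined above; the RTTM fields and values are invented:

from nemo.collections.asr.parts.utils.speaker_utils import get_vad_out_from_rttm_line

# full RTTM line: start is field 3, duration is field 4
get_vad_out_from_rttm_line("SPEAKER session1 1 10.25 4.30 <NA> <NA> speaker_0 <NA> <NA>")  # (10.25, 4.3)
# bare "start duration label" line, e.g. from a VAD table output file
get_vad_out_from_rttm_line("10.25 4.30 speech")  # (10.25, 4.3)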
+ + Args: + outfile: + File pointer that indicates output file path. + AUDIO_RTTM_MAP (dict): + Dictionary containing the input manifest information + uniq_id (str): + Unique file id + overlap_range_list (list): + List containing overlapping ranges between target and source. + """ + audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] + for (stt, end) in overlap_range_list: + meta = { + "audio_filepath": audio_path, + "offset": round(stt, deci), + "duration": round(end - stt, deci), + "label": 'UNK', + "uniq_id": uniq_id, + } + json.dump(meta, outfile) + outfile.write("\n") + + +def read_rttm_lines(rttm_file_path): + """ + Read rttm files and return the rttm information lines. + + Args: + rttm_file_path (str): + + Returns: + lines (list): + List containing the strings from the RTTM file. + """ + if rttm_file_path and os.path.exists(rttm_file_path): + f = open(rttm_file_path, 'r') + else: + raise FileNotFoundError( + "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( + rttm_file_path + ) + ) + lines = f.readlines() + return lines + + +def isOverlap(rangeA, rangeB): + """ + Check whether two ranges have overlap. + Args: + rangeA (list, tuple): + List or tuple containing start and end value in float. + rangeB (list, tuple): + List or tuple containing start and end value in float. + Returns: + (bool): + Boolean that indicates whether the input ranges have overlap. + """ + start1, end1 = rangeA + start2, end2 = rangeB + return end1 > start2 and end2 > start1 + + +def getOverlapRange(rangeA, rangeB): + """ + Calculate the overlapping range between rangeA and rangeB. + Args: + rangeA (list, tuple): + List or tuple containing start and end value in float. + rangeB (list, tuple): + List or tuple containing start and end value in float. + Returns: + (list): + List containing the overlapping range between rangeA and rangeB. + """ + assert isOverlap(rangeA, rangeB), f"There is no overlap between rangeA:{rangeA} and rangeB:{rangeB}" + return [max(rangeA[0], rangeB[0]), min(rangeA[1], rangeB[1])] + + +def combine_float_overlaps(ranges, deci=5, margin=2): + """ + Combine overlaps with floating point numbers. Since neighboring integers are considered as continuous range, + we need to add margin to the starting range before merging then subtract margin from the result range. + + Args: + ranges (list): + List containing ranges. + Example: [(10.2, 10.83), (10.42, 10.91), (10.45, 12.09)] + deci (int): + Number of rounding decimals + margin (int): + margin for determining overlap of the two ranges when ranges are converted to integer ranges. + Default is margin=2 which follows the python index convention. + + Examples: + If margin is 0: + [(1, 10), (10, 20)] -> [(1, 20)] + [(1, 10), (11, 20)] -> [(1, 20)] + If margin is 1: + [(1, 10), (10, 20)] -> [(1, 20)] + [(1, 10), (11, 20)] -> [(1, 10), (11, 20)] + If margin is 2: + [(1, 10), (10, 20)] -> [(1, 10), (10, 20)] + [(1, 10), (11, 20)] -> [(1, 10), (11, 20)] + + Returns: + merged_list (list): + List containing the combined ranges. 
+ Example: [(10.2, 12.09)] + """ + ranges_int = [] + for x in ranges: + stt, end = fl2int(x[0], deci) + margin, fl2int(x[1], deci) + if stt == end: + logging.warning(f"The range {stt}:{end} is too short to be combined thus skipped.") + else: + ranges_int.append([stt, end]) + merged_ranges = combine_int_overlaps(ranges_int) + merged_ranges = [[int2fl(x[0] - margin, deci), int2fl(x[1], deci)] for x in merged_ranges] + return merged_ranges + + +def combine_int_overlaps(ranges): """ - writes manifest file based on rttm files (or vad table out files). This manifest file would be used by - speaker diarizer to compute embeddings and cluster them. This function also takes care of overlap time stamps + Merge the range pairs if overlap exists between the given ranges. + This algorithm requires the range list to be sorted by start time. + Note that neighboring numbers lead to a merged range. + Example: + [(1, 10), (11, 20)] -> [(1, 20)] + + Refer to the original code at https://stackoverflow.com/a/59378428 Args: - AUDIO_RTTM_MAP: dict containing keys to uniqnames, that contains audio filepath and rttm_filepath as its contents, - these are used to extract oracle vad timestamps. - manifest (str): path to write manifest file + ranges (list): + List containing ranges. + Example: [(102, 103), (104, 109), (107, 120)] + Returns: + merged_list (list): + List containing the combined ranges. + Example: [(102, 120)] + + """ + ranges = sorted(ranges, key=lambda x: x[0]) + merged_list = reduce( + lambda x, element: x[:-1:] + [(min(*x[-1], *element), max(*x[-1], *element))] + if x[-1][1] >= element[0] - 1 + else x + [element], + ranges[1::], + ranges[0:1], + ) + return merged_list + +def fl2int(x, deci=3): + """ + Convert floating point number to integer. + """ + return int(round(x * pow(10, deci))) + + +def int2fl(x, deci=3): + """ + Convert integer to floating point number. + """ + return round(float(x / pow(10, deci)), int(deci)) + + +def getMergedRanges(label_list_A: List, label_list_B: List, deci: int = 3) -> List: + """ + Calculate the merged ranges between label_list_A and label_list_B. + + Args: + label_list_A (list): + List containing ranges (start and end values) + label_list_B (list): + List containing ranges (start and end values) Returns: - manifest (str): path to write manifest file + (list): + List containing the merged ranges + + """ + if label_list_A == [] and label_list_B != []: + return label_list_B + elif label_list_A != [] and label_list_B == []: + return label_list_A + else: + label_list_A = [[fl2int(x[0] + 1, deci), fl2int(x[1], deci)] for x in label_list_A] + label_list_B = [[fl2int(x[0] + 1, deci), fl2int(x[1], deci)] for x in label_list_B] + combined = combine_int_overlaps(label_list_A + label_list_B) + return [[int2fl(x[0] - 1, deci), int2fl(x[1], deci)] for x in combined] + + +def getMinMaxOfRangeList(ranges): + """ + Get the min and max of a given range list. + """ + _max = max([x[1] for x in ranges]) + _min = min([x[0] for x in ranges]) + return _min, _max
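A minimal sketch showing how the range helpers above fit together; the numbers are arbitrary, and the last call mirrors the docstring example:

from nemo.collections.asr.parts.utils.speaker_utils import (
    isOverlap, getOverlapRange, fl2int, int2fl, combine_int_overlaps,
)

isOverlap((1.0, 3.0), (2.5, 4.0))        # True: end1 > start2 and end2 > start1
getOverlapRange((1.0, 3.0), (2.5, 4.0))  # [2.5, 3.0]
fl2int(10.2, deci=3)                     # 10200
int2fl(10200, deci=3)                    # 10.2
combine_int_overlaps([(102, 103), (104, 109), (107, 120)])  # [(102, 120)]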
+ + +def getSubRangeList(target_range, source_range_list) -> List: + """ + Get the ranges that have overlaps with the target range from the source_range_list. + + Example: + source range: + |===--======---=====---====--| + target range: + |--------================----| + out_range: + |--------===---=====---==----| + + Args: + target_range (list): + A range (a start and end value pair) that defines the target range we want to select. + target_range = [(start, end)] + source_range_list (list): + List containing the subranges that need to be selected. + source_range = [(start0, end0), (start1, end1), ...] + Returns: + out_range (list): + List containing the overlap between target_range and + source_range_list. + """ + if target_range == []: + return [] + else: + out_range = [] + for s_range in source_range_list: + if isOverlap(s_range, target_range): + ovl_range = getOverlapRange(s_range, target_range) + out_range.append(ovl_range) + return out_range + + +def write_rttm2manifest(AUDIO_RTTM_MAP: dict, manifest_file: str, include_uniq_id: bool = False, deci: int = 5) -> str: + """ + Write a manifest file based on RTTM files (or VAD table output files). This manifest file is used by the + speaker diarizer to compute embeddings and cluster them. This function also merges overlapping VAD timestamps + and trims them to the given offset and duration. + + Args: + AUDIO_RTTM_MAP (dict): + Dictionary keyed by unique file ids, whose entries contain the audio filepath and rttm_filepath + used to extract oracle VAD timestamps. + manifest_file (str): + The path to the output manifest file. + include_uniq_id (bool): + Whether to include the unique file id in each manifest entry. + deci (int): + Number of rounding decimals. + + Returns: + manifest_file (str): + The path to the output manifest file. """ with open(manifest_file, 'w') as outfile: - for key in AUDIO_RTTM_MAP: - rttm_filename = AUDIO_RTTM_MAP[key]['rttm_filepath'] - if rttm_filename and os.path.exists(rttm_filename): - f = open(rttm_filename, 'r') + for uniq_id in AUDIO_RTTM_MAP: + rttm_file_path = AUDIO_RTTM_MAP[uniq_id]['rttm_filepath'] + rttm_lines = read_rttm_lines(rttm_file_path) + offset, duration = get_offset_and_duration(AUDIO_RTTM_MAP, uniq_id, deci) + vad_start_end_list_raw = [] + for line in rttm_lines: + start, dur = get_vad_out_from_rttm_line(line) + vad_start_end_list_raw.append([start, start + dur]) + vad_start_end_list = combine_float_overlaps(vad_start_end_list_raw, deci) + if len(vad_start_end_list) == 0: + logging.warning(f"File ID: {uniq_id}: The VAD label does not contain any speech segments.") + elif duration == 0: + logging.warning(f"File ID: {uniq_id}: The audio file has zero duration.") + else: - raise FileNotFoundError( - "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( - rttm_filename - ) + min_vad, max_vad = getMinMaxOfRangeList(vad_start_end_list) + if max_vad > round(offset + duration, deci) or min_vad < offset: + logging.warning("RTTM label has been truncated since the VAD range is outside the given offset and duration of the audio file") + overlap_range_list = getSubRangeList( + source_range_list=vad_start_end_list, target_range=[offset, offset + duration] ) - - audio_path = AUDIO_RTTM_MAP[key]['audio_filepath'] - if AUDIO_RTTM_MAP[key].get('duration', None): - max_duration = AUDIO_RTTM_MAP[key]['duration'] - else: - sound = sf.SoundFile(audio_path) - max_duration = sound.frames / sound.samplerate - - lines = f.readlines() - time_tup = (-1, -1) - for line in lines: - vad_out = line.strip().split() - if len(vad_out) > 3: - start, dur, _ = float(vad_out[3]), float(vad_out[4]), vad_out[7] - else: - start, dur, _ = float(vad_out[0]), float(vad_out[1]), vad_out[2] - start, dur = float("{:.3f}".format(start)), float("{:.3f}".format(dur)) - - if start == 0 and dur == 0: # No speech segments - continue - else: - - if time_tup[0] >= 0 and start > time_tup[1]: - dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0])) - if time_tup[0] < max_duration and dur2 > 0: - meta = { - "audio_filepath": audio_path, - "offset": time_tup[0], -
"duration": dur2, - "label": 'UNK', - } - json.dump(meta, outfile) - outfile.write("\n") - else: - logging.warning( - "RTTM label has been truncated since start is greater than duration of audio file" - ) - time_tup = (start, start + dur) - else: - if time_tup[0] == -1: - end_time = start + dur - if end_time > max_duration: - end_time = max_duration - time_tup = (start, end_time) - else: - end_time = max(time_tup[1], start + dur) - if end_time > max_duration: - end_time = max_duration - time_tup = (min(time_tup[0], start), end_time) - dur2 = float("{:.3f}".format(time_tup[1] - time_tup[0])) - if time_tup[0] < max_duration and dur2 > 0: - meta = {"audio_filepath": audio_path, "offset": time_tup[0], "duration": dur2, "label": 'UNK'} - json.dump(meta, outfile) - outfile.write("\n") - else: - logging.warning("RTTM label has been truncated since start is greater than duration of audio file") - f.close() + write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, include_uniq_id, deci) return manifest_file @@ -529,18 +801,18 @@ def segments_manifest_to_subsegments_manifest( window: float = 1.5, shift: float = 0.75, min_subsegment_duration: float = 0.05, + include_uniq_id: bool = False, ): """ Generate subsegments manifest from segments manifest file - Args - input: + Args: segments_manifest file (str): path to segments manifest file, typically from VAD output subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value - output: + Returns: returns path to subsegment manifest file """ if subsegments_manifest_file is None: @@ -556,11 +828,21 @@ def segments_manifest_to_subsegments_manifest( dic = json.loads(segment) audio, offset, duration, label = dic['audio_filepath'], dic['offset'], dic['duration'], dic['label'] subsegments = get_subsegments(offset=offset, window=window, shift=shift, duration=duration) - + if include_uniq_id and 'uniq_id' in dic: + uniq_id = dic['uniq_id'] + else: + uniq_id = None for subsegment in subsegments: start, dur = subsegment if dur > min_subsegment_duration: - meta = {"audio_filepath": audio, "offset": start, "duration": dur, "label": label} + meta = { + "audio_filepath": audio, + "offset": start, + "duration": dur, + "label": label, + "uniq_id": uniq_id, + } + json.dump(meta, subsegments_manifest) subsegments_manifest.write("\n") @@ -569,14 +851,13 @@ def segments_manifest_to_subsegments_manifest( def get_subsegments(offset: float, window: float, shift: float, duration: float): """ - return subsegments from a segment of audio file - Args - input: + Return subsegments from a segment of audio file + Args: offset (float): start time of audio segment window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift duration (float): duration of segment - output: + Returns: subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment """ subsegments = [] @@ -596,10 +877,10 @@ def get_subsegments(offset: float, window: float, shift: float, duration: float) def embedding_normalize(embs, use_std=False, eps=1e-10): """ - mean and l2 length normalize the input speaker embeddings - input: + Mean and l2 length normalize the input speaker embeddings + Args: embs: embeddings of 
shape (Batch,emb_size) - output: + Returns: embs: normalized embeddings of shape (Batch,emb_size) """ embs = embs - embs.mean(axis=0)
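A hedged usage sketch of get_subsegments, assuming the sliding-window behavior described in its docstring (exact boundary rounding may differ):

from nemo.collections.asr.parts.utils.speaker_utils import get_subsegments

get_subsegments(offset=0.0, window=1.5, shift=0.75, duration=2.0)
# expected, roughly: [(0.0, 1.5), (0.75, 1.25)]
# 1.5 s windows slid every 0.75 s, with the final window clipped to the segment end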
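And a rough sketch of embedding_normalize; shapes and values are arbitrary:

import numpy as np
from nemo.collections.asr.parts.utils.speaker_utils import embedding_normalize

embs = np.random.rand(16, 192).astype(np.float32)  # (Batch, emb_size)
normed = embedding_normalize(embs)
np.linalg.norm(normed, axis=1)  # ~1.0 per row: mean-subtracted, then scaled to unit L2 length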