diff --git a/basics/base_binarizer.py b/basics/base_binarizer.py index ddad6e02e..397bd8305 100644 --- a/basics/base_binarizer.py +++ b/basics/base_binarizer.py @@ -13,9 +13,8 @@ from utils.hparams import hparams from utils.indexed_datasets import IndexedDatasetBuilder from utils.multiprocess_utils import chunked_multiprocess_run -from utils.phoneme_utils import build_phoneme_list, locate_dictionary +from utils.phoneme_utils import load_phoneme_dictionary from utils.plot import distribution_to_figure -from utils.text_encoder import TokenTextEncoder class BinarizationError(Exception): @@ -44,13 +43,11 @@ class BaseBinarizer: the phoneme set. """ - def __init__(self, data_dir=None, data_attrs=None): - if data_dir is None: - data_dir = hparams['raw_data_dir'] - if not isinstance(data_dir, list): - data_dir = [data_dir] - - self.raw_data_dirs = [pathlib.Path(d) for d in data_dir] + def __init__(self, datasets=None, data_attrs=None): + if datasets is None: + datasets = hparams['datasets'] + self.datasets = datasets + self.raw_data_dirs = [pathlib.Path(ds['raw_data_dir']) for ds in self.datasets] self.binary_data_dir = pathlib.Path(hparams['binary_data_dir']) self.data_attrs = [] if data_attrs is None else data_attrs @@ -58,59 +55,76 @@ def __init__(self, data_dir=None, data_attrs=None): self.augmentation_args = hparams.get('augmentation_args', {}) self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - self.spk_map = None - self.spk_ids = hparams['spk_ids'] - self.speakers = hparams['speakers'] + self.spk_map = {} + self.spk_ids = None self.build_spk_map() + self.lang_map = {} + self.dictionaries = hparams['dictionaries'] + self.build_lang_map() + self.items = {} self.item_names: list = None self._train_item_names: list = None self._valid_item_names: list = None - self.phone_encoder = TokenTextEncoder(vocab_list=build_phoneme_list()) + self.phoneme_dictionary = load_phoneme_dictionary() self.timestep = hparams['hop_size'] / hparams['audio_sample_rate'] def build_spk_map(self): - assert isinstance(self.speakers, list), 'Speakers must be a list' - assert len(self.speakers) == len(self.raw_data_dirs), \ - 'Number of raw data dirs must equal number of speaker names!' - if len(self.spk_ids) == 0: - self.spk_ids = list(range(len(self.raw_data_dirs))) - else: - assert len(self.spk_ids) == len(self.raw_data_dirs), \ - 'Length of explicitly given spk_ids must equal the number of raw datasets.' - assert max(self.spk_ids) < hparams['num_spk'], \ - f'Index in spk_id sequence {self.spk_ids} is out of range. All values should be smaller than num_spk.' - - self.spk_map = {} - for spk_name, spk_id in zip(self.speakers, self.spk_ids): + spk_ids = [ds.get('spk_id') for ds in self.datasets] + assigned_spk_ids = {spk_id for spk_id in spk_ids if spk_id is not None} + idx = 0 + for i in range(len(spk_ids)): + if spk_ids[i] is not None: + continue + while idx in assigned_spk_ids: + idx += 1 + spk_ids[i] = idx + assigned_spk_ids.add(idx) + assert max(spk_ids) < hparams['num_spk'], \ + f'Index in spk_id sequence {spk_ids} is out of range. All values should be smaller than num_spk.' + + for spk_id, dataset in zip(spk_ids, self.datasets): + spk_name = dataset['speaker'] if spk_name in self.spk_map and self.spk_map[spk_name] != spk_id: raise ValueError(f'Invalid speaker ID assignment. 
Name \'{spk_name}\' is assigned ' f'with different speaker IDs: {self.spk_map[spk_name]} and {spk_id}.') self.spk_map[spk_name] = spk_id + self.spk_ids = spk_ids print("| spk_map: ", self.spk_map) - def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): + def build_lang_map(self): + assert len(self.dictionaries.keys()) <= hparams['num_lang'], \ + 'Number of languages must not be greater than num_lang!' + for dataset in self.datasets: + assert dataset['language'] in self.dictionaries, f'Unrecognized language name: {dataset["language"]}' + + for lang_id, lang_name in enumerate(sorted(self.dictionaries.keys()), start=1): + self.lang_map[lang_name] = lang_id + + print("| lang_map: ", self.lang_map) + + def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang) -> dict: raise NotImplementedError() - def split_train_valid_set(self, item_names): + def split_train_valid_set(self, prefixes: list): """ Split the dataset into training set and validation set. :return: train_item_names, valid_item_names """ - prefixes = {str(pr): 1 for pr in hparams['test_prefixes']} + prefixes = {str(pr): 1 for pr in prefixes} valid_item_names = {} # Add prefixes that specified speaker index and matches exactly item name to test set for prefix in deepcopy(prefixes): - if prefix in item_names: + if prefix in self.item_names: valid_item_names[prefix] = 1 prefixes.pop(prefix) # Add prefixes that exactly matches item name without speaker id to test set for prefix in deepcopy(prefixes): matched = False - for name in item_names: + for name in self.item_names: if name.split(':')[-1] == prefix: valid_item_names[name] = 1 matched = True @@ -119,7 +133,7 @@ def split_train_valid_set(self, item_names): # Add names with one of the remaining prefixes to test set for prefix in deepcopy(prefixes): matched = False - for name in item_names: + for name in self.item_names: if name.startswith(prefix): valid_item_names[name] = 1 matched = True @@ -127,7 +141,7 @@ def split_train_valid_set(self, item_names): prefixes.pop(prefix) for prefix in deepcopy(prefixes): matched = False - for name in item_names: + for name in self.item_names: if name.split(':')[-1].startswith(prefix): valid_item_names[name] = 1 matched = True @@ -143,7 +157,7 @@ def split_train_valid_set(self, item_names): valid_item_names = list(valid_item_names.keys()) assert len(valid_item_names) > 0, 'Validation set is empty!' - train_item_names = [x for x in item_names if x not in set(valid_item_names)] + train_item_names = [x for x in self.item_names if x not in set(valid_item_names)] assert len(train_item_names) > 0, 'Training set is empty!' 
return train_item_names, valid_item_names @@ -167,21 +181,34 @@ def meta_data_iterator(self, prefix): def process(self): # load each dataset - for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs): - self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id) + test_prefixes = [] + for ds_id, dataset in enumerate(self.datasets): + items = self.load_meta_data( + pathlib.Path(dataset['raw_data_dir']), + ds_id=ds_id, spk=dataset['speaker'], lang=dataset['language'] + ) + self.items.update(items) + test_prefixes.extend( + f'{ds_id}:{prefix}' + for prefix in dataset.get('test_prefixes', []) + ) self.item_names = sorted(list(self.items.keys())) - self._train_item_names, self._valid_item_names = self.split_train_valid_set(self.item_names) + self._train_item_names, self._valid_item_names = self.split_train_valid_set(test_prefixes) if self.binarization_args['shuffle']: random.shuffle(self.item_names) self.binary_data_dir.mkdir(parents=True, exist_ok=True) - # Copy spk_map and dictionary to binary data dir + # Copy spk_map, lang_map and dictionary to binary data dir spk_map_fn = self.binary_data_dir / 'spk_map.json' with open(spk_map_fn, 'w', encoding='utf-8') as f: - json.dump(self.spk_map, f) - shutil.copy(locate_dictionary(), self.binary_data_dir / 'dictionary.txt') + json.dump(self.spk_map, f, ensure_ascii=False) + lang_map_fn = self.binary_data_dir / 'lang_map.json' + with open(lang_map_fn, 'w', encoding='utf-8') as f: + json.dump(self.lang_map, f, ensure_ascii=False) + for lang, dict_path in hparams['dictionaries'].items(): + shutil.copy(dict_path, self.binary_data_dir / f'dictionary-{lang}.txt') self.check_coverage() # Process valid set and train set @@ -197,40 +224,47 @@ def process(self): def check_coverage(self): # Group by phonemes in the dictionary. - ph_required = set(build_phoneme_list()) - phoneme_map = {} - for ph in ph_required: - phoneme_map[ph] = 0 - ph_occurred = [] + ph_idx_required = set(range(1, len(self.phoneme_dictionary))) + ph_idx_occurred = set() + ph_idx_count_map = { + idx: 0 + for idx in ph_idx_required + } # Load and count those phones that appear in the actual data for item_name in self.items: - ph_occurred += self.items[item_name]['ph_seq'] - if len(ph_occurred) == 0: - raise BinarizationError(f'Empty tokens in {item_name}.') - for ph in ph_occurred: - if ph not in ph_required: - continue - phoneme_map[ph] += 1 - ph_occurred = set(ph_occurred) + ph_idx_occurred.update(self.items[item_name]['ph_seq']) + for idx in self.items[item_name]['ph_seq']: + ph_idx_count_map[idx] += 1 + ph_count_map = { + self.phoneme_dictionary.decode_one(idx, scalar=False): count + for idx, count in ph_idx_count_map.items() + } + + def display_phoneme(phoneme): + if isinstance(phoneme, tuple): + return f'({", ".join(phoneme)})' + return phoneme print('===== Phoneme Distribution Summary =====') - for i, key in enumerate(sorted(phoneme_map.keys())): - if i == len(ph_required) - 1: + keys = sorted(ph_count_map.keys(), key=lambda v: v[0] if isinstance(v, tuple) else v) + for i, key in enumerate(keys): + if i == len(ph_count_map) - 1: end = '\n' elif i % 10 == 9: end = ',\n' else: end = ', ' - print(f'\'{key}\': {phoneme_map[key]}', end=end) + key_disp = display_phoneme(key) + print(f'{key_disp}: {ph_count_map[key]}', end=end) # Draw graph. 
- x = sorted(phoneme_map.keys()) - values = [phoneme_map[k] for k in x] + xs = [display_phoneme(k) for k in keys] + ys = [ph_count_map[k] for k in keys] plt = distribution_to_figure( title='Phoneme Distribution Summary', x_label='Phoneme', y_label='Number of occurrences', - items=x, values=values + items=xs, values=ys, rotate=len(self.dictionaries) > 1 ) filename = self.binary_data_dir / 'phoneme_distribution.jpg' plt.savefig(fname=filename, @@ -239,19 +273,21 @@ def check_coverage(self): print(f'| save summary to \'{filename}\'') # Check unrecognizable or missing phonemes - if ph_occurred != ph_required: - unrecognizable_phones = ph_occurred.difference(ph_required) - missing_phones = ph_required.difference(ph_occurred) - raise BinarizationError('transcriptions and dictionary mismatch.\n' - f' (+) {sorted(unrecognizable_phones)}\n' - f' (-) {sorted(missing_phones)}') + if ph_idx_occurred != ph_idx_required: + missing_phones = sorted({ + self.phoneme_dictionary.decode_one(idx, scalar=False) + for idx in ph_idx_required.difference(ph_idx_occurred) + }, key=lambda v: v[0] if isinstance(v, tuple) else v) + raise BinarizationError( + f'The following phonemes are not covered in transcriptions: {missing_phones}' + ) def process_dataset(self, prefix, num_workers=0, apply_augmentation=False): args = [] builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs) total_sec = {k: 0.0 for k in self.spk_map} total_raw_sec = {k: 0.0 for k in self.spk_map} - extra_info = {'names': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}} + extra_info = {'names': {}, 'ph_texts': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}} max_no = -1 for item_name, meta_data in self.meta_data_iterator(prefix): @@ -271,6 +307,7 @@ def postprocess(_item): extra_info[k] = {} extra_info[k][item_no] = v.shape[0] extra_info['names'][item_no] = _item['name'].split(':', 1)[-1] + extra_info['ph_texts'][item_no] = _item['ph_text'] extra_info['spk_ids'][item_no] = _item['spk_id'] extra_info['spk_names'][item_no] = _item['spk_name'] extra_info['lengths'][item_no] = _item['length'] @@ -287,6 +324,7 @@ def postprocess(_item): extra_info[k] = {} extra_info[k][aug_item_no] = v.shape[0] extra_info['names'][aug_item_no] = aug_item['name'].split(':', 1)[-1] + extra_info['ph_texts'][aug_item_no] = aug_item['ph_text'] extra_info['spk_ids'][aug_item_no] = aug_item['spk_id'] extra_info['spk_names'][aug_item_no] = aug_item['spk_name'] extra_info['lengths'][aug_item_no] = aug_item['length'] @@ -315,6 +353,7 @@ def postprocess(_item): builder.finalize() if prefix == "train": extra_info.pop("names") + extra_info.pop('ph_texts') extra_info.pop("spk_names") with open(self.binary_data_dir / f"{prefix}.meta", "wb") as f: # noinspection PyTypeChecker diff --git a/basics/base_exporter.py b/basics/base_exporter.py index cc016004a..77e5805a8 100644 --- a/basics/base_exporter.py +++ b/basics/base_exporter.py @@ -1,4 +1,6 @@ import json +import pathlib +import shutil from pathlib import Path from typing import Union @@ -31,6 +33,18 @@ def build_spk_map(self) -> dict: else: return {} + # noinspection PyMethodMayBeStatic + def build_lang_map(self) -> dict: + lang_map_fn = pathlib.Path(hparams['work_dir']) / 'lang_map.json' + if lang_map_fn.exists(): + with open(lang_map_fn, 'r', encoding='utf8') as f: + lang_map = json.load(f) + assert isinstance(lang_map, dict) and len(lang_map) > 0, 'Invalid or empty language map!' + assert len(lang_map) == len(set(lang_map.values())), 'Duplicate language id in language map!' 
+ return lang_map + else: + return {} + def build_model(self) -> nn.Module: """ Creates an instance of nn.Module and load its state dict on the target device. @@ -44,6 +58,19 @@ def export_model(self, path: Path): """ raise NotImplementedError() + # noinspection PyMethodMayBeStatic + def export_dictionaries(self, path: Path): + dicts = hparams.get('dictionaries') + if dicts is not None: + for lang in dicts.keys(): + fn = f'dictionary-{lang}.txt' + shutil.copy(pathlib.Path(hparams['work_dir']) / fn, path) + print(f'| export dictionary => {path / fn}') + else: + fn = 'dictionary.txt' + shutil.copy(pathlib.Path(hparams['work_dir']) / fn, path) + print(f'| export dictionary => {path / fn}') + def export_attachments(self, path: Path): """ Exports related files and configs (e.g. the dictionary) to the target directory. diff --git a/basics/base_svs_infer.py b/basics/base_svs_infer.py index e040993a7..2b23d0112 100644 --- a/basics/base_svs_infer.py +++ b/basics/base_svs_infer.py @@ -29,6 +29,7 @@ def __init__(self, device=None): self.device = device self.timestep = hparams['hop_size'] / hparams['audio_sample_rate'] self.spk_map = {} + self.lang_map = {} self.model: torch.nn.Module = None def build_model(self, ckpt_steps=None) -> torch.nn.Module: @@ -50,7 +51,11 @@ def load_speaker_mix(self, param_src: dict, summary_dst: dict, spk_mix_map = param_src.get(param_key) # { spk_name: value } or { spk_name: "value value value ..." } dynamic = False if spk_mix_map is None: - # Get the first speaker + assert len(self.spk_map) == 1, ( + "This is a multi-speaker model. " + "Please specify a speaker or speaker mix by --spk option." + ) + # Get the only speaker for name in self.spk_map.keys(): spk_mix_map = {name: 1.0} break diff --git a/basics/base_task.py b/basics/base_task.py index 768f8e311..065f8273a 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -8,7 +8,6 @@ import matplotlib import utils -from utils.text_encoder import TokenTextEncoder matplotlib.use('Agg') @@ -24,7 +23,7 @@ DsBatchSampler, DsTensorBoardLogger, get_latest_checkpoint_path, get_strategy ) -from utils.phoneme_utils import locate_dictionary, build_phoneme_list +from utils.phoneme_utils import load_phoneme_dictionary torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system')) @@ -71,7 +70,7 @@ def __init__(self, *args, **kwargs): self.skip_immediate_validation = False self.skip_immediate_ckpt_save = False - self.phone_encoder = self.build_phone_encoder() + self.phoneme_dictionary = load_phoneme_dictionary() self.build_model() self.valid_losses: Dict[str, Metric] = {} @@ -165,11 +164,6 @@ def load_pre_train_model(self): else: raise RuntimeError("") - @staticmethod - def build_phone_encoder(): - phone_list = build_phoneme_list() - return TokenTextEncoder(vocab_list=phone_list) - def _build_model(self): raise NotImplementedError() @@ -448,21 +442,21 @@ def start(cls): if not hparams['infer']: # train @rank_zero_only def train_payload_copy(): - # Copy spk_map.json and dictionary.txt to work dir + # Copy files to work_dir binary_dir = pathlib.Path(hparams['binary_data_dir']) - spk_map = work_dir / 'spk_map.json' + spk_map_dst = work_dir / 'spk_map.json' spk_map_src = binary_dir / 'spk_map.json' - if not spk_map.exists() and spk_map_src.exists(): - shutil.copy(spk_map_src, spk_map) - print(f'| Copied spk map to {spk_map}.') - dictionary = work_dir / 'dictionary.txt' - dict_src = binary_dir / 'dictionary.txt' - if not dictionary.exists(): - if dict_src.exists(): - shutil.copy(dict_src, dictionary) - 
else: - shutil.copy(locate_dictionary(), dictionary) - print(f'| Copied dictionary to {dictionary}.') + shutil.copy(spk_map_src, spk_map_dst) + print(f'| Copied spk map to {spk_map_dst}.') + lang_map_dst = work_dir / 'lang_map.json' + lang_map_src = binary_dir / 'lang_map.json' + shutil.copy(lang_map_src, lang_map_dst) + print(f'| Copied lang map to {lang_map_dst}.') + for lang in hparams['dictionaries'].keys(): + dict_dst = work_dir / f'dictionary-{lang}.txt' + dict_src = binary_dir / f'dictionary-{lang}.txt' + shutil.copy(dict_src, dict_dst) + print(f'| Copied dictionary for language \'{lang}\' to {dict_dst}.') train_payload_copy() trainer.fit(task, ckpt_path=get_latest_checkpoint_path(work_dir)) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 61847b0fb..9f27733f7 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -2,17 +2,11 @@ base_config: - configs/base.yaml task_cls: training.acoustic_task.AcousticTask -num_spk: 1 -speakers: - - opencpop -spk_ids: [] -test_prefixes: [ - '2044', - '2086', - '2092', - '2093', - '2100', -] + +dictionaries: {} +extra_phonemes: [] +merged_phoneme_groups: [] +datasets: [] vocoder: NsfHifiGAN vocoder_ckpt: checkpoints/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt @@ -41,10 +35,8 @@ augmentation_args: range: [0.5, 2.] scale: 0.75 -raw_data_dir: 'data/opencpop/raw' binary_data_dir: 'data/opencpop/binary' binarizer_cls: preprocessing.acoustic_binarizer.AcousticBinarizer -dictionary: dictionaries/opencpop-extension.txt spec_min: [-12] spec_max: [0] mel_vmin: -14. @@ -55,7 +47,10 @@ breathiness_smooth_width: 0.12 voicing_smooth_width: 0.12 tension_smooth_width: 0.12 +use_lang_id: false +num_lang: 1 use_spk_id: false +num_spk: 1 use_energy_embed: false use_breathiness_embed: false use_voicing_embed: false diff --git a/configs/base.yaml b/configs/base.yaml index b2e610f95..ab33c5541 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -5,7 +5,7 @@ task_cls: null # dataset ############# sort_by_len: true -raw_data_dir: null +datasets: [] binary_data_dir: null binarizer_cls: null binarization_args: diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 1cf5235b1..59778df99 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -1,19 +1,33 @@ -base_config: configs/acoustic.yaml +base_config: + - configs/acoustic.yaml + +dictionaries: + zh: dictionaries/opencpop-extension.txt +extra_phonemes: [] +merged_phoneme_groups: [] + +datasets: + - raw_data_dir: data/xxx1/raw + speaker: speaker1 + spk_id: 0 + language: zh + test_prefixes: + - wav1 + - wav2 + - wav3 + - wav4 + - wav5 + - raw_data_dir: data/xxx2/raw + speaker: speaker2 + spk_id: 1 + language: zh + test_prefixes: + - wav1 + - wav2 + - wav3 + - wav4 + - wav5 -raw_data_dir: - - data/xxx1/raw - - data/xxx2/raw -speakers: - - speaker1 - - speaker2 -spk_ids: [] -test_prefixes: - - wav1 - - wav2 - - wav3 - - wav4 - - wav5 -dictionary: dictionaries/opencpop-extension.txt binary_data_dir: data/xxx/binary binarization_args: num_workers: 0 @@ -24,6 +38,8 @@ hnsep_ckpt: 'checkpoints/vr/model.pt' vocoder: NsfHifiGAN vocoder_ckpt: checkpoints/nsf_hifigan_44.1k_hop512_128bin_2024.02/model.ckpt +use_lang_id: false +num_lang: 1 use_spk_id: false num_spk: 1 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index cf5163cc2..7d5b211aa 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -1,29 +1,43 @@ base_config: - 
configs/variance.yaml -raw_data_dir: - - data/xxx1/raw - - data/xxx2/raw -speakers: - - speaker1 - - speaker2 -spk_ids: [] -test_prefixes: - - wav1 - - wav2 - - wav3 - - wav4 - - wav5 -dictionary: dictionaries/opencpop-extension.txt +dictionaries: + zh: dictionaries/opencpop-extension.txt +extra_phonemes: [] +merged_phoneme_groups: [] + +datasets: + - raw_data_dir: data/xxx1/raw + speaker: speaker1 + spk_id: 0 + language: zh + test_prefixes: + - wav1 + - wav2 + - wav3 + - wav4 + - wav5 + - raw_data_dir: data/xxx2/raw + speaker: speaker2 + spk_id: 1 + language: zh + test_prefixes: + - wav1 + - wav2 + - wav3 + - wav4 + - wav5 + binary_data_dir: data/xxx/binary binarization_args: num_workers: 0 - pe: parselmouth pe_ckpt: 'checkpoints/rmvpe/model.pt' hnsep: vr hnsep_ckpt: 'checkpoints/vr/model.pt' +use_lang_id: false +num_lang: 1 use_spk_id: false num_spk: 1 # NOTICE: before enabling variance modules, please read the docs at diff --git a/configs/variance.yaml b/configs/variance.yaml index d38f4752a..61c508a1b 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -2,17 +2,11 @@ base_config: - configs/base.yaml task_cls: training.variance_task.VarianceTask -num_spk: 1 -speakers: - - opencpop -spk_ids: [] -test_prefixes: [ - '2044', - '2086', - '2092', - '2093', - '2100', -] + +dictionaries: {} +extra_phonemes: [] +merged_phoneme_groups: [] +datasets: [] audio_sample_rate: 44100 hop_size: 512 # Hop size. @@ -25,17 +19,13 @@ binarization_args: num_workers: 0 prefer_ds: false -raw_data_dir: 'data/opencpop_variance/raw' binary_data_dir: 'data/opencpop_variance/binary' binarizer_cls: preprocessing.variance_binarizer.VarianceBinarizer -dictionary: dictionaries/opencpop-extension.txt +use_lang_id: false +num_lang: 1 use_spk_id: false - -enc_ffn_kernel_size: 3 -use_rope: true -rel_pos: true -hidden_size: 256 +num_spk: 1 predict_dur: true predict_pitch: true @@ -44,6 +34,11 @@ predict_breathiness: false predict_voicing: false predict_tension: false +enc_ffn_kernel_size: 3 +use_rope: true +rel_pos: true +hidden_size: 256 + dur_prediction_args: arch: fs2 hidden_size: 512 diff --git a/deployment/exporters/acoustic_exporter.py b/deployment/exporters/acoustic_exporter.py index 4f0f533e2..849dae5db 100644 --- a/deployment/exporters/acoustic_exporter.py +++ b/deployment/exporters/acoustic_exporter.py @@ -1,4 +1,4 @@ -import shutil +import json from pathlib import Path from typing import List, Union, Tuple, Dict @@ -12,8 +12,7 @@ from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST from utils import load_ckpt, onnx_helper, remove_suffix from utils.hparams import hparams -from utils.phoneme_utils import locate_dictionary, build_phoneme_list -from utils.text_encoder import TokenTextEncoder +from utils.phoneme_utils import load_phoneme_dictionary class DiffSingerAcousticExporter(BaseExporter): @@ -32,7 +31,9 @@ def __init__( self.model_name: str = hparams['exp_name'] self.ckpt_steps: int = ckpt_steps self.spk_map: dict = self.build_spk_map() - self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list()) + self.lang_map: dict = self.build_lang_map() + self.phoneme_dictionary = load_phoneme_dictionary() + self.use_lang_id = hparams.get('use_lang_id', False) and len(self.phoneme_dictionary.cross_lingual_phonemes) > 0 self.model = self.build_model() self.fs2_aux_cache_path = self.cache_dir / ( 'fs2_aux.onnx' if self.model.use_shallow_diffusion else 'fs2.onnx' @@ -80,8 +81,12 @@ def __init__( def build_model(self) -> DiffSingerAcousticONNX: model = DiffSingerAcousticONNX( - 
vocab_size=len(self.vocab), - out_dims=hparams['audio_num_mel_bins'] + vocab_size=len(self.phoneme_dictionary), + out_dims=hparams['audio_num_mel_bins'], + cross_lingual_token_idx=sorted({ + self.phoneme_dictionary.encode_one(p) + for p in self.phoneme_dictionary.cross_lingual_phonemes + }) ).eval().to(self.device) load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, prefix_in_ckpt='model', strict=True, device=self.device) @@ -111,15 +116,17 @@ def export_attachments(self, path: Path): path / f'{self.model_name}.{spk[0]}.emb', self._perform_spk_mix(spk[1]) ) - self._export_dictionary(path / 'dictionary.txt') - self._export_phonemes(path / f'{self.model_name}.phonemes.txt') + self.export_dictionaries(path) + self._export_phonemes(path) model_name = self.model_name if self.freeze_spk is not None: model_name += '.' + self.freeze_spk[0] dsconfig = { # basic configs - 'phonemes': f'{self.model_name}.phonemes.txt', + 'phonemes': f'{self.model_name}.phonemes.json', + 'languages': f'{self.model_name}.languages.json', + 'use_lang_id': self.use_lang_id, 'acoustic': f'{model_name}.onnx', 'hidden_size': hparams['hidden_size'], 'vocoder': 'nsf_hifigan_44.1k_hop512_128bin_2024.02', @@ -211,6 +218,12 @@ def _torch_export_model(self): dynamix_axes['spk_embed'] = { 1: 'n_frames' } + if self.use_lang_id: + kwargs['languages'] = torch.zeros_like(tokens) + input_names.append('languages') + dynamix_axes['languages'] = { + 1: 'n_tokens' + } dynamix_axes['condition'] = { 1: 'n_frames' } @@ -334,6 +347,10 @@ def _optimize_fs2_aux_graph(self, fs2: onnx.ModelProto) -> onnx.ModelProto: print(f'Running ONNX Simplifier on {self.fs2_aux_class_name}...') fs2, check = onnxsim.simplify(fs2, include_subgraph=True) assert check, 'Simplified ONNX model could not be validated' + onnx_helper.model_reorder_io_list( + fs2, 'input', + target_name='languages', insert_after_name='tokens' + ) print(f'| optimize graph: {self.fs2_aux_class_name}') return fs2 @@ -395,11 +412,11 @@ def _export_spk_embed(self, path: Path, spk_embed: torch.Tensor): f.write(spk_embed.cpu().numpy().tobytes()) print(f'| export spk embed => {path}') - # noinspection PyMethodMayBeStatic - def _export_dictionary(self, path: Path): - print(f'| export dictionary => {path}') - shutil.copy(locate_dictionary(), path) - def _export_phonemes(self, path: Path): - self.vocab.store_to_file(path) - print(f'| export phonemes => {path}') + ph_path = path / f'{self.model_name}.phonemes.json' + self.phoneme_dictionary.dump(ph_path) + print(f'| export phonemes => {ph_path}') + lang_path = path / f'{self.model_name}.languages.json' + with open(lang_path, 'w', encoding='utf8') as f: + json.dump(self.lang_map, f, ensure_ascii=False, indent=2) + print(f'| export languages => {lang_path}') diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 1af433ae4..82808ec08 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -1,4 +1,4 @@ -import shutil +import json from pathlib import Path from typing import Union, List, Tuple, Dict @@ -12,8 +12,7 @@ from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST from utils import load_ckpt, onnx_helper, remove_suffix from utils.hparams import hparams -from utils.phoneme_utils import locate_dictionary, build_phoneme_list -from utils.text_encoder import TokenTextEncoder +from utils.phoneme_utils import load_phoneme_dictionary class DiffSingerVarianceExporter(BaseExporter): @@ -32,7 +31,9 @@ def __init__( self.model_name: str = 
hparams['exp_name'] self.ckpt_steps: int = ckpt_steps self.spk_map: dict = self.build_spk_map() - self.vocab = TokenTextEncoder(vocab_list=build_phoneme_list()) + self.lang_map: dict = self.build_lang_map() + self.phoneme_dictionary = load_phoneme_dictionary() + self.use_lang_id = hparams.get('use_lang_id', False) and len(self.phoneme_dictionary.cross_lingual_phonemes) > 0 self.model = self.build_model() self.linguistic_encoder_cache_path = self.cache_dir / 'linguistic.onnx' self.dur_predictor_cache_path = self.cache_dir / 'dur.onnx' @@ -83,7 +84,11 @@ def __init__( def build_model(self) -> DiffSingerVarianceONNX: model = DiffSingerVarianceONNX( - vocab_size=len(self.vocab) + vocab_size=len(self.phoneme_dictionary), + cross_lingual_token_idx=sorted({ + self.phoneme_dictionary.encode_one(p) + for p in self.phoneme_dictionary.cross_lingual_phonemes + }) ).eval().to(self.device) load_ckpt(model, hparams['work_dir'], ckpt_steps=self.ckpt_steps, prefix_in_ckpt='model', strict=True, device=self.device) @@ -142,15 +147,17 @@ def export_attachments(self, path: Path): path / f'{self.model_name}.{spk[0]}.emb', self._perform_spk_mix(spk[1]) ) - self._export_dictionary(path / 'dictionary.txt') - self._export_phonemes((path / f'{self.model_name}.phonemes.txt')) + self.export_dictionaries(path) + self._export_phonemes(path) model_name = self.model_name if self.freeze_spk is not None: model_name += '.' + self.freeze_spk[0] dsconfig = { # basic configs - 'phonemes': f'{self.model_name}.phonemes.txt', + 'phonemes': f'{self.model_name}.phonemes.json', + 'languages': f'{self.model_name}.languages.json', + 'use_lang_id': self.use_lang_id, 'linguistic': f'{model_name}.linguistic.onnx', 'hidden_size': self.model.hidden_size, 'predict_dur': self.model.predict_dur, @@ -186,6 +193,7 @@ def _torch_export_model(self): ph_dur = torch.LongTensor([[3, 5, 2, 1, 4]]).to(self.device) word_div = torch.LongTensor([[2, 2, 1]]).to(self.device) word_dur = torch.LongTensor([[8, 3, 4]]).to(self.device) + languages = torch.LongTensor([[0] * 5]).to(self.device) encoder_out = torch.rand(1, 5, hparams['hidden_size'], dtype=torch.float32, device=self.device) x_masks = tokens == 0 ph_midi = torch.LongTensor([[60] * 5]).to(self.device) @@ -198,6 +206,7 @@ def _torch_export_model(self): 1: 'n_tokens' } } + input_lang_id = self.use_lang_id input_spk_embed = hparams['use_spk_id'] and not self.freeze_spk print(f'Exporting {self.fs2_class_name}...') @@ -207,13 +216,15 @@ def _torch_export_model(self): ( tokens, word_div, - word_dur + word_dur, + *([languages] if input_lang_id else []) ), self.linguistic_encoder_cache_path, input_names=[ 'tokens', 'word_div', - 'word_dur' + 'word_dur', + *(['languages'] if input_lang_id else []) ], output_names=encoder_output_names, dynamic_axes={ @@ -226,7 +237,8 @@ def _torch_export_model(self): 'word_dur': { 1: 'n_words' }, - **encoder_common_axes + **encoder_common_axes, + **({'languages': {1: 'n_tokens'}} if input_lang_id else {}) }, opset_version=15 ) @@ -270,12 +282,14 @@ def _torch_export_model(self): self.model.view_as_linguistic_encoder(), ( tokens, - ph_dur + ph_dur, + *([languages] if input_lang_id else []) ), self.linguistic_encoder_cache_path, input_names=[ 'tokens', - 'ph_dur' + 'ph_dur', + *(['languages'] if input_lang_id else []) ], output_names=encoder_output_names, dynamic_axes={ @@ -285,7 +299,8 @@ def _torch_export_model(self): 'ph_dur': { 1: 'n_tokens' }, - **encoder_common_axes + **encoder_common_axes, + **({'languages': {1: 'n_tokens'}} if input_lang_id else {}) }, opset_version=15 ) 
@@ -637,6 +652,10 @@ def _optimize_linguistic_graph(self, linguistic: onnx.ModelProto) -> onnx.ModelP print(f'Running ONNX Simplifier on {self.fs2_class_name}...') linguistic, check = onnxsim.simplify(linguistic, include_subgraph=True) assert check, 'Simplified ONNX model could not be validated' + onnx_helper.model_reorder_io_list( + linguistic, 'input', + target_name='languages', insert_after_name='tokens' + ) print(f'| optimize graph: {self.fs2_class_name}') return linguistic @@ -771,11 +790,11 @@ def _export_spk_embed(self, path: Path, spk_embed: torch.Tensor): f.write(spk_embed.cpu().numpy().tobytes()) print(f'| export spk embed => {path}') - # noinspection PyMethodMayBeStatic - def _export_dictionary(self, path: Path): - print(f'| export dictionary => {path}') - shutil.copy(locate_dictionary(), path) - def _export_phonemes(self, path: Path): - self.vocab.store_to_file(path) - print(f'| export phonemes => {path}') + ph_path = path / f'{self.model_name}.phonemes.json' + self.phoneme_dictionary.dump(ph_path) + print(f'| export phonemes => {ph_path}') + lang_path = path / f'{self.model_name}.languages.json' + with open(lang_path, 'w', encoding='utf8') as fw: + json.dump(self.lang_map, fw, ensure_ascii=False, indent=2) + print(f'| export languages => {lang_path}') diff --git a/deployment/modules/fastspeech2.py b/deployment/modules/fastspeech2.py index d0a3c7b5a..20dfdb0d7 100644 --- a/deployment/modules/fastspeech2.py +++ b/deployment/modules/fastspeech2.py @@ -9,7 +9,7 @@ from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic from modules.fastspeech.variance_encoder import FastSpeech2Variance from utils.hparams import hparams -from utils.text_encoder import PAD_INDEX +from utils.phoneme_utils import PAD_INDEX f0_bin = 256 f0_max = 1100.0 @@ -41,8 +41,15 @@ def forward(self, dur): class FastSpeech2AcousticONNX(FastSpeech2Acoustic): - def __init__(self, vocab_size): + def __init__(self, vocab_size, cross_lingual_token_idx=None): super().__init__(vocab_size=vocab_size) + self.register_buffer( + 'cross_lingual_token_idx', + torch.LongTensor(cross_lingual_token_idx), + persistent=False + ) # [N,] + if len(cross_lingual_token_idx) == 0: + self.use_lang_id = False # for temporary compatibility; will be completely removed in the future self.f0_embed_type = hparams.get('f0_embed_type', 'continuous') @@ -56,14 +63,29 @@ def __init__(self, vocab_size): self.speed_min, self.speed_max = hparams['augmentation_args']['random_time_stretching']['range'] # noinspection PyMethodOverriding - def forward(self, tokens, durations, f0, variances: dict, gender=None, velocity=None, spk_embed=None): + def forward( + self, tokens, durations, + f0, variances: dict, + gender=None, velocity=None, + spk_embed=None, + languages=None + ): txt_embed = self.txt_embed(tokens) durations = durations * (tokens > 0) mel2ph = self.lr(durations) f0 = f0 * (mel2ph > 0) mel2ph = mel2ph[..., None].repeat((1, 1, hparams['hidden_size'])) dur_embed = self.dur_embed(durations.float()[:, :, None]) - encoded = self.encoder(txt_embed, dur_embed, tokens == PAD_INDEX) + if self.use_lang_id: + lang_mask = torch.any( + tokens[..., None] == self.cross_lingual_token_idx[None, None], + dim=-1 + ) + lang_embed = self.lang_embed(languages * lang_mask) + extra_embed = dur_embed + lang_embed + else: + extra_embed = dur_embed + encoded = self.encoder(txt_embed, extra_embed, tokens == PAD_INDEX) encoded = F.pad(encoded, (0, 0, 1, 0)) condition = torch.gather(encoded, 1, mel2ph) @@ -109,25 +131,49 @@ def forward(self, tokens, durations, f0, 
variances: dict, gender=None, velocity= class FastSpeech2VarianceONNX(FastSpeech2Variance): - def __init__(self, vocab_size): + def __init__(self, vocab_size, cross_lingual_token_idx=None): super().__init__(vocab_size=vocab_size) + self.register_buffer( + 'cross_lingual_token_idx', + torch.LongTensor(cross_lingual_token_idx), + persistent=False + ) + if len(cross_lingual_token_idx) == 0: + self.use_lang_id = False self.lr = LengthRegulator() - def forward_encoder_word(self, tokens, word_div, word_dur): + def forward_encoder_word(self, tokens, word_div, word_dur, languages=None): txt_embed = self.txt_embed(tokens) ph2word = self.lr(word_div) onset = ph2word > F.pad(ph2word, [1, -1]) onset_embed = self.onset_embed(onset.long()) ph_word_dur = torch.gather(F.pad(word_dur, [1, 0]), 1, ph2word) word_dur_embed = self.word_dur_embed(ph_word_dur.float()[:, :, None]) + extra_embed = onset_embed + word_dur_embed + if self.use_lang_id: + lang_mask = torch.any( + tokens[..., None] == self.cross_lingual_token_idx[None, None], + dim=-1 + ) + lang_embed = self.lang_embed(languages * lang_mask) + extra_embed += lang_embed x_masks = tokens == PAD_INDEX - return self.encoder(txt_embed, onset_embed + word_dur_embed, x_masks), x_masks + return self.encoder(txt_embed, extra_embed, x_masks), x_masks - def forward_encoder_phoneme(self, tokens, ph_dur): + def forward_encoder_phoneme(self, tokens, ph_dur, languages=None): txt_embed = self.txt_embed(tokens) ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None]) + if self.use_lang_id: + lang_mask = torch.any( + tokens[..., None] == self.cross_lingual_token_idx[None, None], + dim=-1 + ) + lang_embed = self.lang_embed(languages * lang_mask) + extra_embed = ph_dur_embed + lang_embed + else: + extra_embed = ph_dur_embed x_masks = tokens == PAD_INDEX - return self.encoder(txt_embed, ph_dur_embed, x_masks), x_masks + return self.encoder(txt_embed, extra_embed, x_masks), x_masks def forward_dur_predictor(self, encoder_out, x_masks, ph_midi, spk_embed=None): midi_embed = self.midi_embed(ph_midi) diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index e358f25a0..90ade235d 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -18,12 +18,13 @@ class DiffSingerAcousticONNX(DiffSingerAcoustic): - def __init__(self, vocab_size, out_dims): + def __init__(self, vocab_size, out_dims, cross_lingual_token_idx=None): super().__init__(vocab_size, out_dims) del self.fs2 del self.diffusion self.fs2 = FastSpeech2AcousticONNX( - vocab_size=vocab_size + vocab_size=vocab_size, + cross_lingual_token_idx=cross_lingual_token_idx ) if self.diffusion_type == 'ddpm': self.diffusion = GaussianDiffusionONNX( @@ -65,11 +66,13 @@ def forward_fs2_aux( variances: dict, gender: Tensor = None, velocity: Tensor = None, - spk_embed: Tensor = None + spk_embed: Tensor = None, + languages: Tensor = None ): condition = self.fs2( tokens, durations, f0, variances=variances, - gender=gender, velocity=velocity, spk_embed=spk_embed + gender=gender, velocity=velocity, spk_embed=spk_embed, + languages=languages ) if self.use_shallow_diffusion: aux_mel_pred = self.aux_decoder(condition, infer=True) @@ -127,11 +130,12 @@ def view_as_reflow(self) -> nn.Module: class DiffSingerVarianceONNX(DiffSingerVariance): - def __init__(self, vocab_size): + def __init__(self, vocab_size, cross_lingual_token_idx=None): super().__init__(vocab_size=vocab_size) del self.fs2 self.fs2 = FastSpeech2VarianceONNX( - vocab_size=vocab_size + vocab_size=vocab_size, + 
cross_lingual_token_idx=cross_lingual_token_idx ) self.hidden_size = hparams['hidden_size'] if self.predict_pitch: @@ -194,13 +198,13 @@ def embed_frozen_spk(self, encoder_out): encoder_out += self.frozen_spk_embed return encoder_out - def forward_linguistic_encoder_word(self, tokens, word_div, word_dur): - encoder_out, x_masks = self.fs2.forward_encoder_word(tokens, word_div, word_dur) + def forward_linguistic_encoder_word(self, tokens, word_div, word_dur, languages=None): + encoder_out, x_masks = self.fs2.forward_encoder_word(tokens, word_div, word_dur, languages=languages) encoder_out = self.embed_frozen_spk(encoder_out) return encoder_out, x_masks - def forward_linguistic_encoder_phoneme(self, tokens, ph_dur): - encoder_out, x_masks = self.fs2.forward_encoder_phoneme(tokens, ph_dur) + def forward_linguistic_encoder_phoneme(self, tokens, ph_dur, languages=None): + encoder_out, x_masks = self.fs2.forward_encoder_phoneme(tokens, ph_dur, languages=languages) encoder_out = self.embed_frozen_spk(encoder_out) return encoder_out, x_masks diff --git a/docs/BestPractices.md b/docs/BestPractices.md index 04426b836..cc9c26dd9 100644 --- a/docs/BestPractices.md +++ b/docs/BestPractices.md @@ -1,42 +1,126 @@ # Best Practices -## Materials for training and using models +## Fundamental concepts and materials -### Datasets +### Configuration files -A dataset mainly includes recordings and transcriptions, which is called a _raw dataset_. Raw datasets should be organized as the following folder structure: +A configuration file is a YAML file that defines enabled features, model hyperparameters and controls the behavior of the binarizer, trainer and inference. Almost all settings and controls in this repository, including the practices in this guidance, are achieved through configuration files. -- my_raw_data/ - - wavs/ - - 001.wav - - 002.wav - - ... (more recording files) - - transcriptions.csv +For more information of the configuration system and configurable attributes, see [Configuration Schemas](ConfigurationSchemas.md). -In the example above, the _my_raw_data_ folder is the root directory of a raw dataset. +### Languages -The _transcriptions.csv_ file contains all labels of the recordings. The common column of the CSV file is `name`, which represents all recording items by their filenames **without extension**. Elements of sequence attributes should be split by `space`. Other required columns may vary according to the category of the model you are training, and will be introduced in the following sections. +Each language you are dealing with should have a unique tag in the configuration file. **We highly recommend using ISO 639 language codes as language tags.** For example, `zh` and `zho` stands for Chinese (`cmn` specifically for Mandarin Chinese), `ja` and `jpn` for Japanese, `en` and `eng` for English, `yue` for Cantonese (Yue). You can download a complete language code table from https://iso639-3.sil.org/code_tables/download_tables. + +### Phonemes + +Phonemes are the fundamental part of dictionaries and labels. There are two types of phonemes: language-specific phonemes and global phonemes. + +**Language-specific phonemes:** If there are multiple languages, all language-specific phonemes will be prefixed with its language name. For example: `zh/a`, `ja/o`, `en/eh`. These are called the **full name** of the phonemes, while `a`, `o`, `eh` are called the **short name** which has definite meaning only in a specific language context. 
If there is only one language, the short names can be used to determine each phoneme. + +**Global phonemes:** Some phonemes do not belong to any language. There are two reserved global phoneme tags: `SP` for space, and `AP` for aspiration. There can also be other user-defined tags (`EP`, `GS`, `VF`, etc.). These tags will not be prefixed with language, and are prior when identifying phoneme names. + +Extra phonemes, including user-defined global phonemes and additional language-specific phonemes that are not present in the dictionaries, can be defined in a list in the configuration file (full names should be used): + +```yaml +extra_phonemes: ['EP', 'ja/cl'] +``` + +The phoneme set expands rapidly with the number of languages. There are actually many similar phonemes that can be merged. Define the merging groups in your configuration file (full names should be used): + +```yaml +merged_phoneme_groups: + - [zh/i, ja/i, en/iy] + - [zh/s, ja/s, en/s] + - [ja/cl, SP] # global phonemes can also be merged + # ... (other groups omitted for brevity) +use_lang_id: true # whether to use language embedding; only take effects if there are cross-lingual phonemes +``` + +Merging phonemes does not mean that they are exactly the same for the dictionary. For those cross-lingual merged phonemes, Setting `use_lang_id` to true will still distinguish them by language IDs. + +#### Phoneme naming principles + +- Short names of language-specific phonemes should not conflict with global phoneme names, including reserved ones. +- `/` cannot be used because it is already used for splitting the language tag and the short name. +- `-` and `+` cannot be used because they are defined as slur tags in most singing voice synthesis editors. +- Other special characters, including but not limited to `@`, `#`, `&`, `|`, `<`, `>`, is not recommended because they may be used as special tags in the future format changes. +- ASCII characters are preferred for the best encoding compatibility, but all UTF-8 characters are acceptable. ### Dictionaries -A dictionary is a .txt file, in which each line represents a mapping rule from one syllable to its phoneme sequence. The syllable and the phonemes are split by `tab`, and the phonemes are split by `space`: +Each language should have a corresponding dictionary. Define languages and dictionaries in your configuration file: + +```yaml +dictionaries: + zh: dictionaries/opencpop-extension.txt + ja: dictionaries/japanese_dict_full.txt + en: dictionaries/ds_cmudict-07b.txt +num_lang: 3 # number of languages; should be >= number of defined languages +``` + +Each dictionary is a *.txt* file, in which each line represents a mapping rule from one syllable to its phoneme sequence. The syllable and the phonemes are split by `tab`, and the phonemes are split by `space`: ``` ... ``` -Syllable names and phoneme names can be customized, but with the following limitations/suggestions: +#### Syllable naming principles -- `SP` (rest), `AP` (breath) and `` (padding) cannot be used because they are reserved. +- Try to use a standard writing or pronouncing system. For example, pinyin for Mandarin Chinese, romaji for Japanese and English words for English. +- `AP` and `SP` cannot be used because they are reserved tags when using DiffSinger in editors. +- `/` cannot be used because it is already used for splitting the language tag and the short name. - `-` and `+` cannot be used because they are defined as slur tags in most singing voice synthesis editors. 
-- Special characters including but not limited to `@`, `#`, `&`, `|`, `/`, `<`, `>`, etc. should be avoided because they may be used as special tags in the future format changes. Using them now is okay, and all modifications will be notified in advance. +- Syllable names is not recommended to start with `.` because this may have special meanings in the future editors. +- Other special characters, including but not limited to `@`, `#`, `&`, `|`, `<`, `>`, is not recommended because they may be used as special tags in the future format changes. - ASCII characters are preferred for the best encoding compatibility, but all UTF-8 characters are acceptable. -There are some preset dictionaries in the [dictionaries/](../dictionaries) folder. For the guidance of using a custom dictionary, see [Using custom dictionaries](#using-custom-dictionaries). +There are some example dictionaries in the [dictionaries/](../dictionaries) folder. -### Configuration files +### Datasets -A configuration file is a YAML file that defines enabled features, model hyperparameters and controls the behavior of the binarizer, trainer and inference. For more information of the configuration system and configurable attributes, see [Configuration Schemas](ConfigurationSchemas.md). +A dataset mainly includes recordings and transcriptions, which is called a _raw dataset_. Raw datasets should be organized as the following folder structure: + +- my_raw_data/ + - wavs/ + - 001.wav + - 002.wav + - ... (more recording files) + - transcriptions.csv + +In the example above, the _my_raw_data_ directory is the root directory of a raw dataset. + +The _transcriptions.csv_ file contains all labels of the recordings. The common column of the CSV file is `name`, which represents all recording items by their filenames **without extension**. Elements of sequence attributes should be split by `space`. Other required columns may vary according to the category of the model you are training, and will be introduced in the following sections. + +Each dataset should have a main language. If you have many recordings in multiple languages, it is recommended to separate them by language (you can merge their speaker IDs in the configuration). In each dataset, the main language is set as the language context, and phoneme labels in transcriptions.csv do not need a prefix (short name). It is also valid if there are phonemes from other languages, but all of them should be prefixed with their actual language (full name). Global phonemes should not be prefixed in any datasets. + +You can define your datasets in the configuration file like this: + +```yaml +datasets: # define all raw datasets + - raw_data_dir: data/spk1-zh/raw # path to the root of a raw dataset + speaker: speaker1 # speaker name + spk_id: 0 # optional; use this to merge two datasets; otherwise automatically assigned + language: zh # language tag (main language) of this dataset + test_prefixes: # optional; validation samples from this dataset + - wav1 + - wav2 + - raw_data_dir: data/spk1-en/raw + speaker: speaker1 + spk_id: 0 # specify the same speaker ID to merge into the previous one + language: en + test_prefixes: + - wav1 + - wav2 + - raw_data_dir: data/spk2/raw + speaker: speaker2 + language: ja + test_prefixes: + - wav1 + - wav2 + # ... 
(other datasets omitted for brevity) +num_spk: 2 # number of languages; should be > maximum speaker ID +``` ### DS files @@ -54,7 +138,7 @@ The [DiffSinger Community Vocoders Project](https://openvpi.github.io/vocoders) The pre-trained vocoder can be fine-tuned on your target dataset. It is highly recommended to do so because fine-tuned vocoder can generate much better results on specific (seen) datasets while does not need much computing resources. See the [vocoder training and fine-tuning repository](https://github.com/openvpi/SingingVocoders) for detailed instructions. After you get the fine-tuned vocoder checkpoint, you can configure it by `vocoder_ckpt` key in your configuration file. The fine-tuned NSF-HiFiGAN vocoder checkpoints can be exported to ONNX format like other DiffSinger user models for further production purposes. -Another unrecommended option: train a ultra-lightweight [DDSP vocoder](https://github.com/yxlllc/pc-ddsp) first by yourself, then configure it according to the relevant [instructions](https://github.com/yxlllc/pc-ddsp/blob/master/DiffSinger.md). +Another unrecommended option: train an ultra-lightweight [DDSP vocoder](https://github.com/yxlllc/pc-ddsp) first by yourself, then configure it according to the relevant [instructions](https://github.com/yxlllc/pc-ddsp/blob/master/DiffSinger.md). #### Feature extractors or auxiliary models @@ -108,57 +192,6 @@ Functionalities of variance models are defined by their outputs. There are three There may be some mutual influence between the modules above when they are enabled together. See [mutual influence between variance modules](#mutual-influence-between-variance-modules) for more details. -## Using custom dictionaries - -This section is about using a custom grapheme-to-phoneme dictionary for any language(s). - -### Add a dictionary - -Assume that you have made a dictionary file named `my_dict.txt`. Edit your configuration file: - -```yaml -dictionary: my_dict.txt -``` - -Then you can binarize your data as normal. The phonemes in your dataset must cover, and must only cover the phonemes appeared in your dictionary. Otherwise, the binarizer will raise an error: - -``` -AssertionError: transcriptions and dictionary mismatch. - (+) ['E', 'En', 'i0', 'ir'] - (-) ['AP', 'SP'] -``` - -This means there are 4 unexpected symbols in the data labels (`ir`, `i0`, `E`, `En`) and 2 missing phonemes that are not covered by the data labels (`AP`, `SP`). - -Once the coverage checks passed, a phoneme distribution summary will be saved into your binary data directory. Below is an example. - -![phoneme-distribution](resources/phoneme-distribution.jpg) - -During the binarization process, each phoneme will be assigned with a unique phoneme ID according the order of their names. There are one padding index (marked as `defaultlengths -### dictionary +### datasets -Path to the word-phoneme mapping dictionary file. Training data must fully cover phonemes in the dictionary. +List of dataset configs for preprocessing. - + +
+visibility: acoustic, variance
+scope: preprocessing
+customizability: normal
+type: List[dict]
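For orientation, a minimal sketch of one `datasets` entry under this new schema; the paths, speaker names and item names below are placeholders modeled on the templates elsewhere in this diff:

```yaml
datasets:
  - raw_data_dir: data/xxx1/raw   # placeholder path
    speaker: speaker1             # placeholder speaker name
    spk_id: 0                     # optional; auto-assigned if omitted
    language: zh                  # must be a key of the dictionaries map
    test_prefixes:                # optional; validation items from this dataset
      - wav1
      - wav2
```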
+### datasets[].language
+
+Language context of this dataset. Must be a key of [dictionaries](#dictionaries).
+
+visibility: acoustic, variance
+scope: preprocessing
+customizability: required
+type: str
+### datasets[].raw_data_dir
+
+Path to the root directory of this dataset, containing wave files, transcriptions, etc.
+
+visibility: all
+scope: preprocessing
+customizability: required
+type: str
+### datasets[].speaker
+
+The name of the speaker of this dataset. Speaker names are mapped to speaker indexes and stored into spk_map.json when preprocessing.
+
+visibility: acoustic, variance
+scope: preprocessing
+customizability: required
+type: str
+### datasets[].spk_id
+
+The speaker ID assigned to this dataset. It is assigned automatically if not given. IDs may be duplicate or discontinuous; assigning the same ID to multiple datasets merges them into one speaker.
+
+visibility: acoustic, variance
+scope: preprocessing
+customizability: normal
+type: int
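A sketch of the merging behavior, mirroring the example in the BestPractices changes above (paths and the speaker name are placeholders): assigning the same `spk_id` to two datasets folds them into one speaker.

```yaml
datasets:
  - raw_data_dir: data/spk1-zh/raw  # placeholder path
    speaker: speaker1
    spk_id: 0
    language: zh
  - raw_data_dir: data/spk1-en/raw  # placeholder path
    speaker: speaker1
    spk_id: 0                       # same ID: merged into the speaker above
    language: en
```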
+### datasets[].test_prefixes
+
+List of data item names or name prefixes in this dataset for the validation set. For each string `s` in the list:
+
+- If `s` equals an actual item name, that item is added to the validation set.
+- If `s` does not equal any item name, all items whose names start with `s` are added to the validation set.
+
+visibility: all
+scope: preprocessing
+customizability: required
+type: list
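A small illustration of the two matching rules; the item names are hypothetical:

```yaml
test_prefixes:
  - item042   # if an item is named exactly "item042", only that item is added
  - chorus_   # if no item has this exact name, every item starting with "chorus_" is added
```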
+### dictionaries
+
+Map from language names to their corresponding dictionary file paths. The phonemes in these dictionaries are combined into the final phoneme set and assigned their phoneme IDs. Training data must fully cover all phoneme IDs.
+
+visibility: acoustic, variance
+scope: preprocessing
+customizability: required
+type: Dict[str, str]
+default: {}
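For example, the trilingual setup used in the BestPractices changes above maps each language tag to a dictionary file:

```yaml
dictionaries:
  zh: dictionaries/opencpop-extension.txt
  ja: dictionaries/japanese_dict_full.txt
  en: dictionaries/ds_cmudict-07b.txt
num_lang: 3  # should be >= the number of defined languages
```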
### diff_accelerator

DDPM sampling acceleration method. The following methods are currently available:

@@ -655,6 +724,18 @@
Length of sinusoidal smoothing convolution kernel (in seconds) on extracted ener

default: 0.12

+### extra_phonemes
+
+Extra phonemes to be added to the phoneme set. This list can be used to define custom global phoneme tags besides `AP` and `SP`, or to contain phonemes that are not present in any of the dictionaries.
+
+visibility: acoustic, variance
+scope: preprocessing
+customizability: normal
+type: list
+default: []
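Reusing the example from the BestPractices changes above (full names are required):

```yaml
extra_phonemes:
  - EP     # a user-defined global phoneme tag
  - ja/cl  # a language-specific phoneme not present in the ja dictionary
```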
### f0_max

Maximum base frequency (F0) in Hz for pitch extraction.

@@ -1122,6 +1203,18 @@
Arguments for melody encoder. Available sub-keys: `hidden_size`, `enc_layers`, `

type: dict

+### merged_phoneme_groups
+
+Phoneme groups to merge. Each group is a list of phoneme names. The merged phonemes share the same ID and thus the same phoneme embedding.
+
+visibility: acoustic, variance
+scope: preprocessing
+customizability: required
+type: list
+default: []
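Again from the BestPractices changes above; each merged group shares one ID and one embedding:

```yaml
merged_phoneme_groups:
  - [zh/i, ja/i, en/iy]
  - [zh/s, ja/s, en/s]
  - [ja/cl, SP]  # global phonemes can also be merged
```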
### midi_smooth_width

Length of sinusoidal smoothing convolution kernel (in seconds) on the step function representing MIDI sequence for base pitch calculation.

@@ -1170,6 +1263,17 @@
The number of attention heads of `torch.nn.MultiheadAttention` in FastSpeech2 en

default: 2

+### num_lang
+
+Number of languages. This value is used to allocate language embeddings in the linguistic encoder.
+
+visibility: acoustic, variance
+scope: nn
+customizability: required
+type: int
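As the `acoustic_encoder.py` hunk in this diff shows, the embedding table reserves index 0 as padding, so it holds `num_lang + 1` rows; a minimal sketch with assumed example values:

```python
import torch.nn as nn

num_lang, hidden_size = 3, 256  # assumed example values
# Index 0 is the padding slot for tokens that carry no language ID,
# matching Embedding(num_lang + 1, hidden_size, padding_idx=0) in this diff.
lang_embed = nn.Embedding(num_lang + 1, hidden_size, padding_idx=0)
```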
### num_sanity_val_steps

Number of sanity validation steps at the beginning.

@@ -1499,17 +1603,6 @@
Whether to enable voicing prediction.

default: true

-### raw_data_dir
-
-Path(s) to the raw dataset including wave files, transcriptions, etc.
-
-visibility: all
-scope: preprocessing
-customizability: required
-type: str, List[str]
### rel_pos

Whether to use relative positional encoding in FastSpeech2 module.

@@ -1674,29 +1767,6 @@
Whether to apply the _sorting by similar length_ algorithm described in [sampler

default: true

-### speakers
-
-The names of speakers in a multi-speaker model. Speaker names are mapped to speaker indexes and stored into spk_map.json when preprocessing.
-
-visibility: acoustic, variance
-scope: preprocessing
-customizability: required
-type: list
-### spk_ids
-
-The IDs of speakers in a multi-speaker model. If an empty list is given, speaker IDs will be automatically generated as $0,1,2,...,N_{spk}-1$. IDs can be duplicate or discontinuous.
-
-visibility: acoustic, variance
-scope: preprocessing
-customizability: required
-type: List[int]
-default: []
### spec_min

Minimum mel spectrogram value used for normalization to [-1, 1]. Different mel bins can have different minimum values.

@@ -1801,22 +1871,6 @@
Length of sinusoidal smoothing convolution kernel (in seconds) on extracted tens

default: 0.12

-### test_prefixes
-
-List of data item names or name prefixes for the validation set. For each string `s` in the list:
-
-- If `s` equals to an actual item name, add that item to validation set.
-- If `s` does not equal to any item names, add all items whose names start with `s` to validation set.
-
-For multi-speaker combined datasets, "ds_id:name_prefix" can be used to apply the rules above within data from a specific sub-dataset, where ds_id represents the dataset index.
-
-visibility: all
-scope: preprocessing
-customizability: required
-type: list
### time_scale_factor

The scale factor that will be multiplied on the time $t$ of Rectified Flow before embedding into the model.

@@ -1891,6 +1945,18 @@
Whether to embed key shifting values introduced by random pitch shifting augment

constraints: Must be true if random pitch shifting is enabled.

+### use_lang_id
+
+Whether to embed the language ID from a multilingual dataset. This option only takes effect for those cross-lingual phonemes in the merged groups.
+
+visibility: acoustic, variance
+scope: nn, preprocessing, inference
+customizability: recommended
+type: bool
+default: false
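The encoder changes in this diff gate the embedding with a token mask so that only cross-lingual phonemes receive a nonzero language ID; a self-contained sketch of that masking step, with placeholder tensor values:

```python
import torch
import torch.nn as nn

num_lang, hidden_size = 3, 256                       # assumed example values
lang_embed_layer = nn.Embedding(num_lang + 1, hidden_size, padding_idx=0)

tokens = torch.LongTensor([[5, 12, 7, 0]])           # [B, T] phoneme IDs (placeholders)
languages = torch.LongTensor([[1, 1, 2, 0]])         # [B, T] language IDs
cross_lingual_token_idx = torch.LongTensor([7, 12])  # IDs of phonemes merged across languages

# Tokens outside the cross-lingual set are masked to language ID 0 (the padding
# index), so only merged phonemes actually receive a language embedding.
lang_mask = torch.any(tokens[..., None] == cross_lingual_token_idx[None, None], dim=-1)
lang_embed = lang_embed_layer(languages * lang_mask)  # [B, T, hidden_size]
```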
+ ### use_melody_encoder Whether to enable melody encoder for the pitch predictor. @@ -1941,7 +2007,7 @@ Whether to embed speed values introduced by random time stretching augmentation. ### use_spk_id -Whether embed the speaker id from a multi-speaker dataset. +Whether to embed the speaker ID from a multi-speaker dataset. diff --git a/docs/GettingStarted.md b/docs/GettingStarted.md index a3422c3f0..92ddb395f 100644 --- a/docs/GettingStarted.md +++ b/docs/GettingStarted.md @@ -14,9 +14,9 @@ DiffSinger requires Python 3.8 or later. We strongly recommend you create a virt pip install -r requirements.txt ``` -### Materials and assets +### Concepts and materials -Some essential materials and assets are needed before continuing with this repository. See [materials for training and using models](BestPractices.md#materials-for-training-and-using-models) for detailed instructions. +Before you proceed, it is necessary to understand some fundamental concepts in this repository and prepare some materials and assets. See [fundamental concepts and materials](BestPractices.md#fundamental-concepts-and-materials) for detailed information. ## Configuration diff --git a/inference/ds_acoustic.py b/inference/ds_acoustic.py index a67f5b166..8b139f62f 100644 --- a/inference/ds_acoustic.py +++ b/inference/ds_acoustic.py @@ -1,12 +1,11 @@ -from collections import OrderedDict - -import tqdm import json import pathlib +from collections import OrderedDict +from typing import Dict import numpy as np import torch -from typing import Dict +import tqdm from basics.base_svs_infer import BaseSVSInfer from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST @@ -16,8 +15,7 @@ from utils import load_ckpt from utils.hparams import hparams from utils.infer_utils import cross_fade, resample_align_curve, save_wav -from utils.phoneme_utils import build_phoneme_list -from utils.text_encoder import TokenTextEncoder +from utils.phoneme_utils import load_phoneme_dictionary class DiffSingerAcousticInfer(BaseSVSInfer): @@ -37,12 +35,16 @@ def __init__(self, device=None, load_model=True, load_vocoder=True, ckpt_steps=N if hparams.get('use_tension_embed', False): self.variances_to_embed.add('tension') - self.ph_encoder = TokenTextEncoder(vocab_list=build_phoneme_list()) + self.phoneme_dictionary = load_phoneme_dictionary() if hparams['use_spk_id']: with open(pathlib.Path(hparams['work_dir']) / 'spk_map.json', 'r', encoding='utf8') as f: self.spk_map = json.load(f) assert isinstance(self.spk_map, dict) and len(self.spk_map) > 0, 'Invalid or empty speaker map!' assert len(self.spk_map) == len(set(self.spk_map.values())), 'Duplicate speaker id in speaker map!' 
+ lang_map_fn = pathlib.Path(hparams['work_dir']) / 'lang_map.json' + if lang_map_fn.exists(): + with open(lang_map_fn, 'r', encoding='utf8') as f: + self.lang_map = json.load(f) self.model = self.build_model(ckpt_steps=ckpt_steps) self.lr = LengthRegulator().to(self.device) if load_vocoder: @@ -50,7 +52,7 @@ def __init__(self, device=None, load_model=True, load_vocoder=True, ckpt_steps=N def build_model(self, ckpt_steps=None): model = DiffSingerAcoustic( - vocab_size=len(self.ph_encoder), + vocab_size=len(self.phoneme_dictionary), out_dims=hparams['audio_num_mel_bins'] ).eval().to(self.device) load_ckpt(model, hparams['work_dir'], ckpt_steps=ckpt_steps, @@ -73,7 +75,28 @@ def preprocess_input(self, param, idx=0): """ batch = {} summary = OrderedDict() - txt_tokens = torch.LongTensor([self.ph_encoder.encode(param['ph_seq'])]).to(self.device) # => [B, T_txt] + + lang = param.get('lang') + if lang is None: + assert len(self.lang_map) <= 1, ( + "This is a multilingual model. " + "Please specify a language by --lang option." + ) + else: + assert lang in self.lang_map, f'Unrecognized language name: \'{lang}\'.' + if hparams.get('use_lang_id', False): + languages = torch.LongTensor([ + ( + self.lang_map[lang if '/' not in p else p.split('/', maxsplit=1)[0]] + if self.phoneme_dictionary.is_cross_lingual(p) + else 0 + ) + for p in param['ph_seq'].split() + ]).to(self.device) # => [B, T_txt] + batch['languages'] = languages + txt_tokens = torch.LongTensor([ + self.phoneme_dictionary.encode(param['ph_seq'], lang=lang) + ]).to(self.device) # => [B, T_txt] batch['tokens'] = txt_tokens ph_dur = torch.from_numpy(np.array(param['ph_dur'].split(), np.float32)).to(self.device) @@ -175,9 +198,11 @@ def forward_model(self, sample): else: spk_mix_embed = None mel_pred: ShallowDiffusionOutput = self.model( - txt_tokens, mel2ph=sample['mel2ph'], f0=sample['f0'], **variances, + txt_tokens, languages=sample.get('languages'), + mel2ph=sample['mel2ph'], f0=sample['f0'], **variances, key_shift=sample.get('key_shift'), speed=sample.get('speed'), - spk_mix_embed=spk_mix_embed, infer=True + spk_mix_embed=spk_mix_embed, + infer=True ) return mel_pred.diff_out diff --git a/inference/ds_variance.py b/inference/ds_variance.py index c8a9b090a..aa74dcabd 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -1,31 +1,29 @@ import copy import json - -import tqdm import pathlib from collections import OrderedDict +from typing import List, Tuple import librosa import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import tqdm from scipy import interpolate -from typing import List, Tuple from basics.base_svs_infer import BaseSVSInfer +from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST from modules.fastspeech.tts_modules import ( LengthRegulator, RhythmRegulator, mel2ph_to_dur ) -from modules.fastspeech.param_adaptor import VARIANCE_CHECKLIST from modules.toplevel import DiffSingerVariance from utils import load_ckpt from utils.hparams import hparams from utils.infer_utils import resample_align_curve -from utils.phoneme_utils import build_phoneme_list +from utils.phoneme_utils import load_phoneme_dictionary from utils.pitch_utils import interp_f0 -from utils.text_encoder import TokenTextEncoder class DiffSingerVarianceInfer(BaseSVSInfer): @@ -34,12 +32,16 @@ def __init__( predictions: set = None ): super().__init__(device=device) - self.ph_encoder = TokenTextEncoder(vocab_list=build_phoneme_list()) + self.phoneme_dictionary = load_phoneme_dictionary() if 
hparams['use_spk_id']: with open(pathlib.Path(hparams['work_dir']) / 'spk_map.json', 'r', encoding='utf8') as f: self.spk_map = json.load(f) assert isinstance(self.spk_map, dict) and len(self.spk_map) > 0, 'Invalid or empty speaker map!' assert len(self.spk_map) == len(set(self.spk_map.values())), 'Duplicate speaker id in speaker map!' + lang_map_fn = pathlib.Path(hparams['work_dir']) / 'lang_map.json' + if lang_map_fn.exists(): + with open(lang_map_fn, 'r', encoding='utf8') as f: + self.lang_map = json.load(f) self.model: DiffSingerVariance = self.build_model(ckpt_steps=ckpt_steps) self.lr = LengthRegulator() self.rr = RhythmRegulator() @@ -76,7 +78,7 @@ def __init__( def build_model(self, ckpt_steps=None): model = DiffSingerVariance( - vocab_size=len(self.ph_encoder) + vocab_size=len(self.phoneme_dictionary) ).eval().to(self.device) load_ckpt(model, hparams['work_dir'], ckpt_steps=ckpt_steps, prefix_in_ckpt='model', strict=True, device=self.device) @@ -97,7 +99,28 @@ def preprocess_input( """ batch = {} summary = OrderedDict() - txt_tokens = torch.LongTensor([self.ph_encoder.encode(param['ph_seq'].split())]).to(self.device) # [B=1, T_ph] + + lang = param.get('lang') + if lang is None: + assert len(self.lang_map) <= 1, ( + "This is a multilingual model. " + "Please specify a language by --lang option." + ) + else: + assert lang in self.lang_map, f'Unrecognized language name: \'{lang}\'.' + if hparams.get('use_lang_id', False): + languages = torch.LongTensor([ + ( + self.lang_map[lang if '/' not in p else p.split('/', maxsplit=1)[0]] + if self.phoneme_dictionary.is_cross_lingual(p) + else 0 + ) + for p in param['ph_seq'].split() + ]).to(self.device) # [B=1, T_ph] + batch['languages'] = languages + txt_tokens = torch.LongTensor([ + self.phoneme_dictionary.encode(param['ph_seq'], lang=lang) + ]).to(self.device) # [B=1, T_ph] T_ph = txt_tokens.shape[1] batch['tokens'] = txt_tokens ph_num = torch.from_numpy(np.array([param['ph_num'].split()], np.int64)).to(self.device) # [B=1, T_w] @@ -305,7 +328,8 @@ def forward_model(self, sample): ph_spk_mix_embed = spk_mix_embed = None dur_pred, pitch_pred, variance_pred = self.model( - txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph, + txt_tokens, languages=sample.get('languages'), + midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph, note_midi=note_midi, note_rest=note_rest, note_dur=note_dur, note_glide=note_glide, mel2note=mel2note, base_pitch=base_pitch, pitch=pitch, pitch_expr=expr, ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed, diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py index a639e52f3..b6f986bb0 100644 --- a/modules/fastspeech/acoustic_encoder.py +++ b/modules/fastspeech/acoustic_encoder.py @@ -8,13 +8,16 @@ ) from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur from utils.hparams import hparams -from utils.text_encoder import PAD_INDEX +from utils.phoneme_utils import PAD_INDEX class FastSpeech2Acoustic(nn.Module): def __init__(self, vocab_size): super().__init__() self.txt_embed = Embedding(vocab_size, hparams['hidden_size'], PAD_INDEX) + self.use_lang_id = hparams.get('use_lang_id', False) + if self.use_lang_id: + self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0) self.dur_embed = Linear(1, hparams['hidden_size']) self.encoder = FastSpeech2Encoder( hidden_size=hparams['hidden_size'], num_layers=hparams['enc_layers'], @@ -79,12 +82,18 @@ def 
forward_variance_embedding(self, condition, key_shift=None, speed=None, **va def forward( self, txt_tokens, mel2ph, f0, key_shift=None, speed=None, - spk_embed_id=None, **kwargs + spk_embed_id=None, languages=None, + **kwargs ): txt_embed = self.txt_embed(txt_tokens) dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float() dur_embed = self.dur_embed(dur[:, :, None]) - encoder_out = self.encoder(txt_embed, dur_embed, txt_tokens == 0) + if self.use_lang_id: + lang_embed = self.lang_embed(languages) + extra_embed = dur_embed + lang_embed + else: + extra_embed = dur_embed + encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0) encoder_out = F.pad(encoder_out, [0, 0, 1, 0]) mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index 03f274caa..deab9ee84 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -8,7 +8,7 @@ ) from modules.fastspeech.tts_modules import FastSpeech2Encoder, DurationPredictor from utils.hparams import hparams -from utils.text_encoder import PAD_INDEX +from utils.phoneme_utils import PAD_INDEX class FastSpeech2Variance(nn.Module): @@ -16,8 +16,11 @@ def __init__(self, vocab_size): super().__init__() self.predict_dur = hparams['predict_dur'] self.linguistic_mode = 'word' if hparams['predict_dur'] else 'phoneme' + self.use_lang_id = hparams['use_lang_id'] self.txt_embed = Embedding(vocab_size, hparams['hidden_size'], PAD_INDEX) + if self.use_lang_id: + self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0) if self.predict_dur: self.onset_embed = Embedding(2, hparams['hidden_size']) @@ -46,7 +49,12 @@ def __init__(self, vocab_size): dur_loss_type=dur_hparams['loss_type'] ) - def forward(self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, spk_embed=None, infer=True): + def forward( + self, txt_tokens, midi, ph2word, + ph_dur=None, word_dur=None, + spk_embed=None, languages=None, + infer=True + ): """ :param txt_tokens: (train, infer) [B, T_ph] :param midi: (train, infer) [B, T_ph] @@ -54,6 +62,7 @@ def forward(self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, spk_emb :param ph_dur: (train, [infer]) [B, T_ph] :param word_dur: (infer) [B, T_w] :param spk_embed: (train) [B, T_ph, H] + :param languages (train, infer) [B, T_ph] :param infer: whether inference :return: encoder_out, ph_dur_pred """ @@ -69,11 +78,14 @@ def forward(self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, spk_emb )[:, 1:] # [B, T_ph] => [B, T_w] word_dur = torch.gather(F.pad(word_dur, [1, 0], value=0), 1, ph2word) # [B, T_w] => [B, T_ph] word_dur_embed = self.word_dur_embed(word_dur.float()[:, :, None]) - - encoder_out = self.encoder(txt_embed, onset_embed + word_dur_embed, txt_tokens == 0) + extra_embed = onset_embed + word_dur_embed else: ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None]) - encoder_out = self.encoder(txt_embed, ph_dur_embed, txt_tokens == 0) + extra_embed = ph_dur_embed + if self.use_lang_id: + lang_embed = self.lang_embed(languages) + extra_embed += lang_embed + encoder_out = self.encoder(txt_embed, extra_embed, txt_tokens == 0) if self.predict_dur: midi_embed = self.midi_embed(midi) # => [B, T_ph, H] diff --git a/modules/toplevel.py b/modules/toplevel.py index 3a8ae06c3..aceff1f70 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -83,11 +83,12 @@ def __init__(self, vocab_size, out_dims): def forward( self, txt_tokens, mel2ph, f0, 
key_shift=None, speed=None, - spk_embed_id=None, gt_mel=None, infer=True, **kwargs + spk_embed_id=None, languages=None, gt_mel=None, infer=True, **kwargs ) -> ShallowDiffusionOutput: condition = self.fs2( txt_tokens, mel2ph, f0, key_shift=key_shift, speed=speed, - spk_embed_id=spk_embed_id, **kwargs + spk_embed_id=spk_embed_id, languages=languages, + **kwargs ) if infer: if self.use_shallow_diffusion: @@ -199,7 +200,8 @@ def forward( note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None, base_pitch=None, pitch=None, pitch_expr=None, pitch_retake=None, variance_retake: Dict[str, Tensor] = None, - spk_id=None, infer=True, **kwargs + spk_id=None, languages=None, + infer=True, **kwargs ): if self.use_spk_id: ph_spk_mix_embed = kwargs.get('ph_spk_mix_embed') @@ -215,7 +217,8 @@ def forward( encoder_out, dur_pred_out = self.fs2( txt_tokens, midi=midi, ph2word=ph2word, ph_dur=ph_dur, word_dur=word_dur, - spk_embed=ph_spk_embed, infer=infer + spk_embed=ph_spk_embed, languages=languages, + infer=infer ) if not self.predict_pitch and not self.predict_variances: diff --git a/preprocessing/acoustic_binarizer.py b/preprocessing/acoustic_binarizer.py index b61c88f88..0455c4f94 100644 --- a/preprocessing/acoustic_binarizer.py +++ b/preprocessing/acoustic_binarizer.py @@ -36,6 +36,7 @@ ACOUSTIC_ITEM_ATTRIBUTES = [ 'spk_id', 'mel', + 'languages', 'tokens', 'mel2ph', 'f0', @@ -67,17 +68,26 @@ def __init__(self): "See https://github.com/openvpi/DiffSinger/releases/tag/v2.3.0 for more details." ) - def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): + def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang): meta_data_dict = {} with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf-8') as f: for utterance_label in csv.DictReader(f): item_name = utterance_label['name'] temp_dict = { 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'), - 'ph_seq': utterance_label['ph_seq'].split(), + 'spk_id': self.spk_map[spk], + 'spk_name': spk, + 'lang_seq': [ + ( + self.lang_map[lang if '/' not in p else p.split('/', maxsplit=1)[0]] + if self.phoneme_dictionary.is_cross_lingual(p) + else 0 + ) + for p in utterance_label['ph_seq'].split() + ], + 'ph_seq': self.phoneme_dictionary.encode(utterance_label['ph_seq'], lang=lang), 'ph_dur': [float(x) for x in utterance_label['ph_dur'].split()], - 'spk_id': spk_id, - 'spk_name': self.speakers[ds_id], + 'ph_text': utterance_label['ph_seq'], } assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \ f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.' @@ -85,7 +95,7 @@ def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): f'Negative ph_dur found in \'{item_name}\'.' 
meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict - self.items.update(meta_data_dict) + return meta_data_dict @torch.no_grad() def process_item(self, item_name, meta_data, binarization_args): @@ -106,8 +116,10 @@ def process_item(self, item_name, meta_data, binarization_args): 'seconds': seconds, 'length': length, 'mel': mel, - 'tokens': np.array(self.phone_encoder.encode(meta_data['ph_seq']), dtype=np.int64), + 'languages': np.array(meta_data['lang_seq'], dtype=np.int64), + 'tokens': np.array(meta_data['ph_seq'], dtype=np.int64), 'ph_dur': np.array(meta_data['ph_dur']).astype(np.float32), + 'ph_text': meta_data['ph_text'], } # get ground truth dur diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index b4bef4fec..84d9ea499 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -30,6 +30,7 @@ os.environ["OMP_NUM_THREADS"] = "1" VARIANCE_ITEM_ATTRIBUTES = [ 'spk_id', # index number of dataset/speaker, int64 + 'languages', # index numbers of phoneme languages, int64[T_ph,] 'tokens', # index numbers of phonemes, int64[T_ph,] 'ph_dur', # durations of phonemes, in number of frames, int64[T_ph,] 'midi', # phoneme-level mean MIDI pitch, int64[T_ph,] @@ -108,7 +109,7 @@ def load_attr_from_ds(self, ds_id, name, attr, idx=0): ds = ds[idx] return ds.get(attr) - def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): + def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang): meta_data_dict = {} with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') as f: @@ -130,11 +131,22 @@ def require(attr, optional=False): temp_dict = { 'ds_idx': item_idx, - 'spk_id': spk_id, - 'spk_name': self.speakers[ds_id], + 'spk_id': self.spk_map[spk], + 'spk_name': spk, + 'language_id': self.lang_map[lang], + 'language_name': lang, 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'), - 'ph_seq': require('ph_seq').split(), - 'ph_dur': [float(x) for x in require('ph_dur').split()] + 'lang_seq': [ + ( + self.lang_map[lang if '/' not in p else p.split('/', maxsplit=1)[0]] + if self.phoneme_dictionary.is_cross_lingual(p) + else 0 + ) + for p in utterance_label['ph_seq'].split() + ], + 'ph_seq': self.phoneme_dictionary.encode(require('ph_seq'), lang=lang), + 'ph_dur': [float(x) for x in require('ph_dur').split()], + 'ph_text': require('ph_seq'), } assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \ @@ -170,7 +182,7 @@ def require(attr, optional=False): meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict - self.items.update(meta_data_dict) + return meta_data_dict def check_coverage(self): super().check_coverage() @@ -258,7 +270,9 @@ def process_item(self, item_name, meta_data, binarization_args): 'spk_name': meta_data['spk_name'], 'seconds': seconds, 'length': length, - 'tokens': np.array(self.phone_encoder.encode(meta_data['ph_seq']), dtype=np.int64) + 'languages': np.array(meta_data['lang_seq'], dtype=np.int64), + 'tokens': np.array(meta_data['ph_seq'], dtype=np.int64), + 'ph_text': meta_data['ph_text'], } ph_dur_sec = torch.FloatTensor(meta_data['ph_dur']).to(self.device) diff --git a/scripts/infer.py b/scripts/infer.py index 83a5cabb7..ae08f5d12 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -61,6 +61,11 @@ def main(): required=False, help='Speaker name or mixture of speakers' ) +@click.option( + '--lang', type=click.STRING, + required=False, + help='Default language name' +) @click.option( '--out', type=click.Path( file_okay=False, dir_okay=True, path_type=pathlib.Path @@ -112,6 
+117,7 @@ def acoustic( exp: str, ckpt: int, spk: str, + lang: str, out: pathlib.Path, title: str, num: int, @@ -195,9 +201,10 @@ def acoustic( for param in params: if gender is not None and hparams['use_key_shift_embed']: param['gender'] = gender - if spk_mix is not None: param['spk_mix'] = spk_mix + if lang is not None: + param['lang'] = lang from inference.ds_acoustic import DiffSingerAcousticInfer infer_ins = DiffSingerAcousticInfer(load_vocoder=not mel, ckpt_steps=ckpt) @@ -241,6 +248,11 @@ def acoustic( required=False, help='Speaker name or mixture of speakers' ) +@click.option( + '--lang', type=click.STRING, + required=False, + help='Default language name' +) @click.option( '--out', type=click.Path( file_okay=False, dir_okay=True, path_type=pathlib.Path @@ -282,6 +294,7 @@ def variance( exp: str, ckpt: int, spk: str, + lang: str, predict: Tuple[str], out: pathlib.Path, title: str, @@ -344,11 +357,12 @@ def variance( for param in params: if expr is not None: param['expr'] = expr - if spk_mix is not None: param['ph_spk_mix_backup'] = param.get('ph_spk_mix') param['spk_mix_backup'] = param.get('spk_mix') param['ph_spk_mix'] = param['spk_mix'] = spk_mix + if lang is not None: + param['lang'] = lang from inference.ds_variance import DiffSingerVarianceInfer infer_ins = DiffSingerVarianceInfer(ckpt_steps=ckpt, predictions=set(predict)) diff --git a/training/acoustic_task.py b/training/acoustic_task.py index de6a9adb5..ca6a71c65 100644 --- a/training/acoustic_task.py +++ b/training/acoustic_task.py @@ -35,6 +35,7 @@ def __init__(self, prefix, preload=False): self.need_key_shift = hparams['use_key_shift_embed'] self.need_speed = hparams['use_speed_embed'] self.need_spk_id = hparams['use_spk_id'] + self.need_lang_id = hparams['use_lang_id'] def collater(self, samples): batch = super().collater(samples) @@ -60,6 +61,9 @@ def collater(self, samples): if self.need_spk_id: spk_ids = torch.LongTensor([s['spk_id'] for s in samples]) batch['spk_ids'] = spk_ids + if self.need_lang_id: + languages = utils.collate_nd([s['languages'] for s in samples], 0) + batch['languages'] = languages return batch @@ -92,7 +96,7 @@ def __init__(self): def _build_model(self): return DiffSingerAcoustic( - vocab_size=len(self.phone_encoder), + vocab_size=len(self.phoneme_dictionary), out_dims=hparams['audio_num_mel_bins'] ) @@ -128,9 +132,14 @@ def run_model(self, sample, infer=False): spk_embed_id = sample['spk_ids'] else: spk_embed_id = None + if hparams['use_lang_id']: + languages = sample['languages'] + else: + languages = None output: ShallowDiffusionOutput = self.model( txt_tokens, mel2ph=mel2ph, f0=f0, **variances, - key_shift=key_shift, speed=speed, spk_embed_id=spk_embed_id, + key_shift=key_shift, speed=speed, + spk_embed_id=spk_embed_id, languages=languages, gt_mel=target, infer=infer ) diff --git a/training/variance_task.py b/training/variance_task.py index 88a844952..2fdc599f6 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -41,6 +41,8 @@ def collater(self, samples): if hparams['use_spk_id']: batch['spk_ids'] = torch.LongTensor([s['spk_id'] for s in samples]) + if hparams['use_lang_id']: + batch['languages'] = utils.collate_nd([s['languages'] for s in samples], 0) if hparams['predict_dur']: batch['ph2word'] = utils.collate_nd([s['ph2word'] for s in samples], 0) batch['midi'] = utils.collate_nd([s['midi'] for s in samples], 0) @@ -85,6 +87,7 @@ def __init__(self): self.diffusion_type = hparams['diffusion_type'] self.use_spk_id = hparams['use_spk_id'] + self.use_lang_id = 
hparams['use_lang_id'] self.predict_dur = hparams['predict_dur'] if self.predict_dur: @@ -113,7 +116,7 @@ def __init__(self): def _build_model(self): return DiffSingerVariance( - vocab_size=len(self.phone_encoder), + vocab_size=len(self.phoneme_dictionary), ) # noinspection PyAttributeOutsideInit @@ -154,6 +157,7 @@ def build_losses_and_metrics(self): def run_model(self, sample, infer=False): spk_ids = sample['spk_ids'] if self.use_spk_id else None # [B,] + languages = sample['languages'] if self.use_lang_id else None # [B,] txt_tokens = sample['tokens'] # [B, T_ph] ph_dur = sample['ph_dur'] # [B, T_ph] ph2word = sample.get('ph2word') # [B, T_ph] @@ -188,7 +192,8 @@ def run_model(self, sample, infer=False): } output = self.model( - txt_tokens, midi=midi, ph2word=ph2word, + txt_tokens, languages=languages, + midi=midi, ph2word=ph2word, ph_dur=ph_dur, mel2ph=mel2ph, note_midi=note_midi, note_rest=note_rest, note_dur=note_dur, note_glide=note_glide, mel2note=mel2note, @@ -262,7 +267,10 @@ def sample_get(key, idx, abs_idx): self.valid_metrics['ph_dur_acc'].update( pdur_pred=pred_dur, pdur_target=gt_dur, ph2word=ph2word, mask=mask ) - self.plot_dur(data_idx, gt_dur, pred_dur, tokens) + self.plot_dur( + data_idx, gt_dur, pred_dur, + txt=self.valid_dataset.metadata['ph_texts'][data_idx].split() + ) if pitch_preds is not None: pitch_len = self.valid_dataset.metadata['pitch'][data_idx] pred_pitch = sample_get('base_pitch', i, data_idx) + pitch_preds[i][:pitch_len].unsqueeze(0) @@ -295,7 +303,6 @@ def sample_get(key, idx, abs_idx): def plot_dur(self, data_idx, gt_dur, pred_dur, txt=None): gt_dur = gt_dur[0].cpu().numpy() pred_dur = pred_dur[0].cpu().numpy() - txt = self.phone_encoder.decode(txt[0].cpu().numpy()).split() title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}" self.logger.all_rank_experiment.add_figure(f'dur_{data_idx}', dur_to_figure( gt_dur, pred_dur, txt, title_text diff --git a/utils/onnx_helper.py b/utils/onnx_helper.py index 176df56dc..1470e47d6 100644 --- a/utils/onnx_helper.py +++ b/utils/onnx_helper.py @@ -1,5 +1,5 @@ import re -from typing import Dict, Tuple, Union +from typing import Dict, Tuple, Union, Literal import onnx from google.protobuf.internal.containers import RepeatedCompositeFieldContainer @@ -51,6 +51,42 @@ def _override_shapes( _override_shapes(model.graph.output, output_shapes) +def model_reorder_io_list( + model: ModelProto, + input_or_output: Literal['input', 'output'], + target_name: str, + insert_after_name: str, +): + """ + Reorder the input of the model graph by moving the target input after the specified input (in-place operation). + If the given names are not found, the operation will be ignored. 
+ :param model: model to perform the operation on + :param input_or_output: 'input' or 'output' to specify the list to reorder + :param target_name: the name of the input to be reordered + :param insert_after_name: the name of the input to be inserted after (None for the first) + """ + def _reorder_input(input_list: RepeatedCompositeFieldContainer[ValueInfoProto]): + nonlocal input_or_output + target_idx = -1 + insert_after_idx = -1 + for i, value_info in enumerate(input_list): + if value_info.name == target_name: + target_idx = i + if value_info.name == insert_after_name: + insert_after_idx = i + if target_idx != -1 and insert_after_idx != -1: + target = input_list.pop(target_idx) + input_list.insert(insert_after_idx + 1, target) + _verbose(f'| reorder {input_or_output}: \'{target_name}\' after \'{insert_after_name}\'') + + if input_or_output == 'input': + _reorder_input(model.graph.input) + elif input_or_output == 'output': + _reorder_input(model.graph.output) + else: + raise ValueError('Argument \'input_or_output\' should be either \'input\' or \'output\'.') + + def model_add_prefixes( model: ModelProto, initializer_prefix=None, @@ -97,7 +133,7 @@ def _add_prefixes_recursive(subgraph): new_name = initializer_prefix + initializer.name _verbose('| add prefix:', initializer.name, '->', new_name) initializer.name = new_name - + for value_info in subgraph.value_info: if dim_prefix is not None: for dim in value_info.type.tensor_type.shape.dim: @@ -114,7 +150,7 @@ def _add_prefixes_recursive(subgraph): new_name = value_info_prefix + value_info.name _verbose('| add prefix:', value_info.name, '->', new_name) value_info.name = new_name - + if node_prefix is not None: for node in subgraph.node: if ignored_pattern is not None and re.match(ignored_pattern, node.name): @@ -122,7 +158,7 @@ def _add_prefixes_recursive(subgraph): new_name = node_prefix + node.name _verbose('| add prefix:', node.name, '->', new_name) node.name = new_name - + for node in subgraph.node: # For 'If' and 'Loop' nodes, add prefixes recursively if node.op_type == 'If': @@ -134,7 +170,7 @@ def _add_prefixes_recursive(subgraph): if attr.name == 'body': body = onnx.helper.get_attribute_value(attr) _add_prefixes_recursive(body) - + # For each node, rename its inputs and outputs for io_list in [node.input, node.output]: for i, io_value in enumerate(io_list): diff --git a/utils/phoneme_utils.py b/utils/phoneme_utils.py index 269122a6d..ca1af6203 100644 --- a/utils/phoneme_utils.py +++ b/utils/phoneme_utils.py @@ -1,99 +1,211 @@ +import json import pathlib - -try: - from lightning.pytorch.utilities.rank_zero import rank_zero_info -except ModuleNotFoundError: - rank_zero_info = print +from typing import Dict, List, Union from utils.hparams import hparams -_initialized = False -_ALL_CONSONANTS_SET = set() -_ALL_VOWELS_SET = set() -_dictionary = { - 'AP': ['AP'], - 'SP': ['SP'] -} -_phoneme_list: list - - -def locate_dictionary(): - """ - Search and locate the dictionary file. - Order: - 1. hparams['dictionary'] - 2. hparams['g2p_dictionary'] - 3. 'dictionary.txt' in hparams['work_dir'] - 4. file with same name as hparams['g2p_dictionary'] in hparams['work_dir'] - :return: pathlib.Path of the dictionary file - """ - assert 'dictionary' in hparams or 'g2p_dictionary' in hparams, \ - 'Please specify a dictionary file in your config.' 
- config_dict_path = pathlib.Path(hparams['dictionary']) - if config_dict_path.exists(): - return config_dict_path - work_dir = pathlib.Path(hparams['work_dir']) - ckpt_dict_path = work_dir / config_dict_path.name - if ckpt_dict_path.exists(): - return ckpt_dict_path - ckpt_dict_path = work_dir / 'dictionary.txt' - if ckpt_dict_path.exists(): - return ckpt_dict_path - raise FileNotFoundError('Unable to locate the dictionary file. ' - 'Please specify the right dictionary in your config.') - - -def _build_dict_and_list(): - global _dictionary, _phoneme_list - - _set = set() - with open(locate_dictionary(), 'r', encoding='utf8') as _df: - _lines = _df.readlines() - for _line in _lines: - _pinyin, _ph_str = _line.strip().split('\t') - _dictionary[_pinyin] = _ph_str.split() - for _list in _dictionary.values(): - [_set.add(ph) for ph in _list] - _phoneme_list = sorted(list(_set)) - rank_zero_info('| load phoneme set: ' + str(_phoneme_list)) - - -def _initialize_consonants_and_vowels(): - # Currently we only support two-part consonant-vowel phoneme systems. - for _ph_list in _dictionary.values(): - _ph_count = len(_ph_list) - if _ph_count == 0 or _ph_list[0] in ['AP', 'SP']: - continue - elif len(_ph_list) == 1: - _ALL_VOWELS_SET.add(_ph_list[0]) +PAD_INDEX = 0 + + +class PhonemeDictionary: + def __init__( + self, + dictionaries: Dict[str, pathlib.Path], + extra_phonemes: List[str] = None, + merged_groups: List[List[str]] = None + ): + # Step 1: Collect all phonemes + all_phonemes = {'AP', 'SP'} + if extra_phonemes: + for ph in extra_phonemes: + if '/' in ph: + lang, name = ph.split('/', maxsplit=1) + if lang not in dictionaries: + raise ValueError( + f"Invalid phoneme tag '{ph}' in extra phonemes: " + f"unrecognized language name '{lang}'." + ) + if name in all_phonemes: + raise ValueError( + f"Invalid phoneme tag '{ph}' in extra phonemes: " + f"short name conflicts with existing tag." + ) + all_phonemes.add(ph) + self._multi_langs = len(dictionaries) > 1 + for lang, dict_path in dictionaries.items(): + with open(dict_path, 'r', encoding='utf8') as dict_file: + for line in dict_file: + _, phonemes = line.strip().split('\t') + phonemes = phonemes.split() + for phoneme in phonemes: + if '/' in phoneme: + raise ValueError( + f"Invalid phoneme tag '{phoneme}' in dictionary '{dict_path}': " + f"should not contain the reserved character '/'." + ) + if phoneme in all_phonemes: + continue + if self._multi_langs: + all_phonemes.add(f'{lang}/{phoneme}') + else: + all_phonemes.add(phoneme) + # Step 2: Parse merged phoneme groups + if merged_groups is None: + merged_groups = [] else: - _ALL_CONSONANTS_SET.add(_ph_list[0]) - _ALL_VOWELS_SET.add(_ph_list[1]) - - -def _initialize(): - global _initialized - if not _initialized: - _build_dict_and_list() - _initialize_consonants_and_vowels() - _initialized = True - - -def get_all_consonants(): - _initialize() - return sorted(_ALL_CONSONANTS_SET) - - -def get_all_vowels(): - _initialize() - return sorted(_ALL_VOWELS_SET) - - -def build_dictionary() -> dict: - _initialize() - return _dictionary - - -def build_phoneme_list() -> list: - _initialize() - return _phoneme_list + _merged_groups = [] + for group in merged_groups: + _group = [] + for phoneme in group: + if '/' in phoneme: + lang, name = phoneme.split('/', maxsplit=1) + if lang not in dictionaries: + raise ValueError( + f"Invalid phoneme tag '{phoneme}' in merged group: " + f"unrecognized language name '{lang}'." 
+ ) + if self._multi_langs: + element = phoneme + else: + element = name + else: + element = phoneme + if element not in all_phonemes: + raise ValueError( + f"Invalid phoneme tag '{phoneme}' in merged group: " + f"not found in phoneme set." + ) + _group.append(element) + _merged_groups.append(_group) + merged_groups = [set(phones) for phones in _merged_groups if len(phones) > 1] + # Step 3: Build phoneme index + merged_phonemes_inverted_index = {} + for idx, group in enumerate(merged_groups): + other_idx = None + for phoneme in group: + if phoneme in merged_phonemes_inverted_index: + other_idx = merged_phonemes_inverted_index[phoneme] + break + target_idx = idx if other_idx is None else other_idx + for phoneme in group: + merged_phonemes_inverted_index[phoneme] = target_idx + if other_idx is not None: + merged_groups[other_idx] |= group + group.clear() + phone_to_id = {} + id_to_phone = [] + cross_lingual_phonemes = set() + idx = 1 + for phoneme in sorted(all_phonemes): + if phoneme in merged_phonemes_inverted_index: + has_assigned = True + for alias in merged_groups[merged_phonemes_inverted_index[phoneme]]: + if alias not in phone_to_id: + has_assigned = False + phone_to_id[alias] = idx + if not has_assigned: + merged_group = sorted(merged_groups[merged_phonemes_inverted_index[phoneme]]) + merged_from_langs = { + alias.split('/', maxsplit=1)[0] + for alias in merged_group + if '/' in alias + } + id_to_phone.append(tuple(merged_group)) + idx += 1 + if len(merged_from_langs) > 1: + cross_lingual_phonemes.update(ph for ph in merged_group if '/' in ph) + else: + phone_to_id[phoneme] = idx + id_to_phone.append(phoneme) + idx += 1 + self._phone_to_id: Dict[str, int] = phone_to_id + self._id_to_phone: List[Union[str, tuple]] = id_to_phone + self._cross_lingual_phonemes = frozenset(cross_lingual_phonemes) + + @property + def vocab_size(self): + return len(self._id_to_phone) + 1 + + def __len__(self): + return self.vocab_size + + @property + def cross_lingual_phonemes(self): + return self._cross_lingual_phonemes + + def is_cross_lingual(self, phone): + return phone in self._cross_lingual_phonemes + + def encode_one(self, phone, lang=None): + if '/' in phone: + lang, phone = phone.split('/', maxsplit=1) + if lang is None or not self._multi_langs or phone in self._phone_to_id: + return self._phone_to_id[phone] + if '/' not in phone: + phone = f'{lang}/{phone}' + return self._phone_to_id[phone] + + def encode(self, sentence, lang=None): + phones = sentence.strip().split() if isinstance(sentence, str) else sentence + return [self.encode_one(phone, lang=lang) for phone in phones] + + def decode_one(self, idx, lang=None, scalar=True): + if idx <= 0: + return None + phone = self._id_to_phone[idx - 1] + if not scalar or isinstance(phone, str): + return phone + if lang is None or not self._multi_langs: + return phone[0] + for alias in phone: + if alias.startswith(f'{lang}/'): + return alias + return phone[0] + + def decode(self, ids, lang=None, scalar=True): + ids = list(ids) + return ' '.join([ + self.decode_one(i, lang=lang, scalar=scalar) + for i in ids + if i >= 1 + ]) + + def dump(self, filename): + with open(filename, 'w', encoding='utf8') as fp: + json.dump(self._phone_to_id, fp, ensure_ascii=False, indent=2) + + +_dictionary = None + + +def load_phoneme_dictionary() -> PhonemeDictionary: + if _dictionary is not None: + return _dictionary + config_dicts = hparams.get('dictionaries') + if config_dicts is not None: + dicts = {} + for lang, config_dict_path in config_dicts.items(): + dict_path = 
pathlib.Path(hparams['work_dir']) / f'dictionary-{lang}.txt' + if not dict_path.exists(): + dict_path = pathlib.Path(config_dict_path) + if not dict_path.exists(): + raise FileNotFoundError( + f"Could not locate dictionary for language '{lang}'." + ) + dicts[lang] = dict_path + else: + dict_path = pathlib.Path(hparams['work_dir']) / 'dictionary.txt' + if not dict_path.exists(): + dict_path = pathlib.Path(hparams['dictionary']) + if not dict_path.exists(): + raise FileNotFoundError( + f"Could not locate dictionary file." + ) + dicts = { + 'default': dict_path + } + return PhonemeDictionary( + dictionaries=dicts, + extra_phonemes=hparams.get('extra_phonemes'), + merged_groups=hparams.get('merged_phoneme_groups') + ) diff --git a/utils/plot.py b/utils/plot.py index b76e0726c..48cb9c430 100644 --- a/utils/plot.py +++ b/utils/plot.py @@ -106,7 +106,7 @@ def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None, title return fig -def distribution_to_figure(title, x_label, y_label, items: list, values: list, zoom=0.8): +def distribution_to_figure(title, x_label, y_label, items: list, values: list, zoom=0.8, rotate=False): fig = plt.figure(figsize=(int(len(items) * zoom), 10)) plt.bar(x=items, height=values) plt.tick_params(labelsize=15) @@ -117,4 +117,6 @@ def distribution_to_figure(title, x_label, y_label, items: list, values: list, z plt.title(title, fontsize=30) plt.xlabel(x_label, fontsize=20) plt.ylabel(y_label, fontsize=20) + if rotate: + fig.autofmt_xdate(rotation=45) return fig diff --git a/utils/text_encoder.py b/utils/text_encoder.py deleted file mode 100644 index 4b7815c46..000000000 --- a/utils/text_encoder.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np - -PAD = '' -PAD_INDEX = 0 - - -class TokenTextEncoder: - """Encoder based on a user-supplied vocabulary (file or list).""" - - def __init__(self, vocab_list): - """Initialize from a file or list, one token per line. - - Handling of reserved tokens works as follows: - - When initializing from a list, we add reserved tokens to the vocab. - - Args: - vocab_list: If not None, a list of elements of the vocabulary. - """ - self.vocab_list = sorted(vocab_list) - - def encode(self, sentence): - """Converts a space-separated string of phones to a list of ids.""" - phones = sentence.strip().split() if isinstance(sentence, str) else sentence - return [self.vocab_list.index(ph) + 1 if ph != PAD else PAD_INDEX for ph in phones] - - def decode(self, ids, strip_padding=False): - if strip_padding: - ids = np.trim_zeros(ids) - ids = list(ids) - return ' '.join([ - self.vocab_list[_id - 1] if _id >= 1 else PAD - for _id in ids - ]) - - @property - def vocab_size(self): - return len(self.vocab_list) + 1 - - def __len__(self): - return self.vocab_size - - def store_to_file(self, filename): - """Write vocab file to disk. - - Vocab files have one token per line. The file ends in a newline. Reserved - tokens are written to the vocab file as well. - - Args: - filename: Full path of the file to store the vocab to. - """ - with open(filename, 'w', encoding='utf8') as f: - print(PAD, file=f) - [print(tok, file=f) for tok in self.vocab_list]
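To make the new `utils/phoneme_utils.py` API concrete, here is a minimal usage sketch of `PhonemeDictionary`. The dictionary paths, their contents, and the merged group are assumptions for illustration; each dictionary file maps a syllable to space-separated phonemes, one tab-separated entry per line.

```python
# Usage sketch of the new PhonemeDictionary. Paths, dictionary contents,
# and the merged group below are assumed for illustration only.
import pathlib
from utils.phoneme_utils import PhonemeDictionary

pd = PhonemeDictionary(
    dictionaries={
        'zh': pathlib.Path('dictionaries/opencpop-extension.txt'),
        'en': pathlib.Path('dictionaries/arpabet.txt'),
    },
    extra_phonemes=['en/GS'],             # language-tagged extra phoneme
    merged_groups=[['zh/a', 'en/AA0']],   # share one phoneme ID across languages
)

# Explicit 'lang/phoneme' tags win over the default lang argument.
tokens = pd.encode('SP zh/a n i', lang='zh')
print(pd.decode(tokens, lang='zh'))       # prefers 'zh/...' aliases when merged
print(pd.is_cross_lingual('en/AA0'))      # True: merged across zh and en
```

Because the group spans two languages, both `zh/a` and `en/AA0` become cross-lingual phonemes and therefore receive nonzero language IDs during preprocessing.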
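Finally, a hedged sketch of the new `model_reorder_io_list` helper added to `utils/onnx_helper.py`. The file names here are hypothetical; the call moves, for instance, the new `languages` input directly after `tokens` in an exported graph, and silently does nothing if either name is absent.

```python
# Usage sketch for model_reorder_io_list (file names are assumptions).
import onnx
from utils.onnx_helper import model_reorder_io_list

model = onnx.load('acoustic.onnx')
model_reorder_io_list(
    model, 'input',
    target_name='languages',
    insert_after_name='tokens',
)
onnx.save(model, 'acoustic.reordered.onnx')
```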