diff --git a/basics/base_binarizer.py b/basics/base_binarizer.py index c66b6059f..e84c5333a 100644 --- a/basics/base_binarizer.py +++ b/basics/base_binarizer.py @@ -1,5 +1,6 @@ import json import pathlib +import pickle import random import shutil import warnings @@ -94,55 +95,55 @@ def build_spk_map(self): def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): raise NotImplementedError() - def split_train_valid_set(self): + def split_train_valid_set(self, item_names): """ Split the dataset into training set and validation set. :return: train_item_names, valid_item_names """ - prefixes = set([str(pr) for pr in hparams['test_prefixes']]) - valid_item_names = set() + prefixes = {str(pr): 1 for pr in hparams['test_prefixes']} + valid_item_names = {} # Add prefixes that specified speaker index and matches exactly item name to test set for prefix in deepcopy(prefixes): - if prefix in self.item_names: - valid_item_names.add(prefix) - prefixes.remove(prefix) + if prefix in item_names: + valid_item_names[prefix] = 1 + prefixes.pop(prefix) # Add prefixes that exactly matches item name without speaker id to test set for prefix in deepcopy(prefixes): matched = False - for name in self.item_names: + for name in item_names: if name.split(':')[-1] == prefix: - valid_item_names.add(name) + valid_item_names[name] = 1 matched = True if matched: - prefixes.remove(prefix) + prefixes.pop(prefix) # Add names with one of the remaining prefixes to test set for prefix in deepcopy(prefixes): matched = False - for name in self.item_names: + for name in item_names: if name.startswith(prefix): - valid_item_names.add(name) + valid_item_names[name] = 1 matched = True if matched: - prefixes.remove(prefix) + prefixes.pop(prefix) for prefix in deepcopy(prefixes): matched = False - for name in self.item_names: + for name in item_names: if name.split(':')[-1].startswith(prefix): - valid_item_names.add(name) + valid_item_names[name] = 1 matched = True if matched: - prefixes.remove(prefix) + prefixes.pop(prefix) if len(prefixes) != 0: warnings.warn( - f'The following rules in test_prefixes have no matching names in the dataset: {sorted(prefixes)}', + f'The following rules in test_prefixes have no matching names in the dataset: {", ".join(prefixes.keys())}', category=UserWarning ) warnings.filterwarnings('default') - valid_item_names = sorted(list(valid_item_names)) + valid_item_names = list(valid_item_names.keys()) assert len(valid_item_names) > 0, 'Validation set is empty!' - train_item_names = [x for x in self.item_names if x not in set(valid_item_names)] + train_item_names = [x for x in item_names if x not in set(valid_item_names)] assert len(train_item_names) > 0, 'Training set is empty!' return train_item_names, valid_item_names @@ -169,7 +170,7 @@ def process(self): for ds_id, spk_id, data_dir in zip(range(len(self.raw_data_dirs)), self.spk_ids, self.raw_data_dirs): self.load_meta_data(pathlib.Path(data_dir), ds_id=ds_id, spk_id=spk_id) self.item_names = sorted(list(self.items.keys())) - self._train_item_names, self._valid_item_names = self.split_train_valid_set() + self._train_item_names, self._valid_item_names = self.split_train_valid_set(self.item_names) if self.binarization_args['shuffle']: random.seed(hparams['seed']) @@ -249,9 +250,9 @@ def check_coverage(self): def process_dataset(self, prefix, num_workers=0, apply_augmentation=False): args = [] builder = IndexedDatasetBuilder(self.binary_data_dir, prefix=prefix, allowed_attr=self.data_attrs) - lengths = [] - total_sec = 0 - total_raw_sec = 0 + total_sec = {k: 0.0 for k in self.spk_map} + total_raw_sec = {k: 0.0 for k in self.spk_map} + extra_info = {'names': {}, 'spk_ids': {}, 'spk_names': {}, 'lengths': {}} for item_name, meta_data in self.meta_data_iterator(prefix): args.append([item_name, meta_data, self.binarization_args]) @@ -259,19 +260,35 @@ def process_dataset(self, prefix, num_workers=0, apply_augmentation=False): aug_map = self.arrange_data_augmentation(self.meta_data_iterator(prefix)) if apply_augmentation else {} def postprocess(_item): - nonlocal total_sec, total_raw_sec + nonlocal total_sec, total_raw_sec, extra_info if _item is None: return - builder.add_item(_item) - lengths.append(_item['length']) - total_sec += _item['seconds'] - total_raw_sec += _item['seconds'] + item_no = builder.add_item(_item) + for k, v in _item.items(): + if isinstance(v, np.ndarray): + if k not in extra_info: + extra_info[k] = {} + extra_info[k][item_no] = v.shape[0] + extra_info['names'][item_no] = _item['name'].split(':', 1)[-1] + extra_info['spk_ids'][item_no] = _item['spk_id'] + extra_info['spk_names'][item_no] = _item['spk_name'] + extra_info['lengths'][item_no] = _item['length'] + total_raw_sec[_item['spk_name']] += _item['seconds'] + total_sec[_item['spk_name']] += _item['seconds'] for task in aug_map.get(_item['name'], []): aug_item = task['func'](_item, **task['kwargs']) - builder.add_item(aug_item) - lengths.append(aug_item['length']) - total_sec += aug_item['seconds'] + aug_item_no = builder.add_item(aug_item) + for k, v in aug_item.items(): + if isinstance(v, np.ndarray): + if k not in extra_info: + extra_info[k] = {} + extra_info[k][aug_item_no] = v.shape[0] + extra_info['names'][aug_item_no] = aug_item['name'].split(':', 1)[-1] + extra_info['spk_ids'][aug_item_no] = aug_item['spk_id'] + extra_info['spk_names'][aug_item_no] = aug_item['spk_name'] + extra_info['lengths'][aug_item_no] = aug_item['length'] + total_sec[aug_item['spk_name']] += aug_item['seconds'] try: if num_workers > 0: @@ -286,21 +303,38 @@ def postprocess(_item): for a in tqdm(args): item = self.process_item(*a) postprocess(item) + for k in extra_info: + for item_no in range(len(args)): + assert item_no in extra_info[k], f'Item numbering is not consecutive.' + extra_info[k] = list(map(lambda x: x[1], sorted(extra_info[k].items(), key=lambda x: x[0]))) except KeyboardInterrupt: builder.finalize() raise builder.finalize() - with open(self.binary_data_dir / f'{prefix}.lengths', 'wb') as f: + if prefix == "train": + extra_info.pop("names") + extra_info.pop("spk_names") + with open(self.binary_data_dir / f"{prefix}.meta", "wb") as f: # noinspection PyTypeChecker - np.save(f, lengths) - + pickle.dump(extra_info, f) if apply_augmentation: - print(f'| {prefix} total duration (before augmentation): {total_raw_sec:.2f}s') + print(f"| {prefix} total duration (before augmentation): {sum(total_raw_sec.values()):.2f}s") + print( + f"| {prefix} respective duration (before augmentation): " + + ', '.join(f'{k}={v:.2f}s' for k, v in total_raw_sec.items()) + ) print( - f'| {prefix} total duration (after augmentation): {total_sec:.2f}s ({total_sec / total_raw_sec:.2f}x)') + f"| {prefix} total duration (after augmentation): " + f"{sum(total_sec.values()):.2f}s ({sum(total_sec.values()) / sum(total_raw_sec.values()):.2f}x)" + ) + print( + f"| {prefix} respective duration (after augmentation): " + + ', '.join(f'{k}={v:.2f}s' for k, v in total_sec.items()) + ) else: - print(f'| {prefix} total duration: {total_raw_sec:.2f}s') + print(f"| {prefix} total duration: {sum(total_raw_sec.values()):.2f}s") + print(f"| {prefix} respective duration: " + ', '.join(f'{k}={v:.2f}s' for k, v in total_raw_sec.items())) def arrange_data_augmentation(self, data_iterator): """ diff --git a/basics/base_dataset.py b/basics/base_dataset.py index 5aeab9ae5..72c64d766 100644 --- a/basics/base_dataset.py +++ b/basics/base_dataset.py @@ -1,6 +1,7 @@ import os +import pickle -import numpy as np +import torch from torch.utils.data import Dataset from utils.hparams import hparams @@ -22,32 +23,36 @@ class BaseDataset(Dataset): the index function. """ - def __init__(self, prefix): + def __init__(self, prefix, size_key='lengths', preload=False): super().__init__() self.prefix = prefix self.data_dir = hparams['binary_data_dir'] - self.sizes = np.load(os.path.join(self.data_dir, f'{self.prefix}.lengths')) - self.indexed_ds = IndexedDataset(self.data_dir, self.prefix) - - @property - def _sizes(self): - return self.sizes + with open(os.path.join(self.data_dir, f'{self.prefix}.meta'), 'rb') as f: + self.metadata = pickle.load(f) + self.sizes = self.metadata[size_key] + self._indexed_ds = IndexedDataset(self.data_dir, self.prefix) + if preload: + self.indexed_ds = [self._indexed_ds[i] for i in range(len(self._indexed_ds))] + del self._indexed_ds + else: + self.indexed_ds = self._indexed_ds def __getitem__(self, index): - return self.indexed_ds[index] + return {'_idx': index, **self.indexed_ds[index]} def __len__(self): - return len(self._sizes) + return len(self.sizes) def num_frames(self, index): - return self.size(index) + return self.sizes[index] def size(self, index): """Return an example's size as a float or tuple. This value is used when filtering a dataset with ``--max-positions``.""" - return self._sizes[index] + return self.sizes[index] def collater(self, samples): return { - 'size': len(samples) + 'size': len(samples), + 'indices': torch.LongTensor([s['_idx'] for s in samples]) } diff --git a/basics/base_task.py b/basics/base_task.py index 70d3d692c..5e6890eef 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -17,14 +17,13 @@ from torchmetrics import Metric, MeanMetric import lightning.pytorch as pl from lightning.pytorch.callbacks import LearningRateMonitor -from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_only from basics.base_module import CategorizedModule from utils.hparams import hparams from utils.training_utils import ( DsModelCheckpoint, DsTQDMProgressBar, - DsBatchSampler, DsEvalBatchSampler, + DsBatchSampler, DsTensorBoardLogger, get_latest_checkpoint_path, get_strategy ) from utils.phoneme_utils import locate_dictionary, build_phoneme_list @@ -60,9 +59,7 @@ class BaseTask(pl.LightningModule): """ def __init__(self, *args, **kwargs): - # dataset configs super().__init__(*args, **kwargs) - self.dataset_cls = None self.max_batch_frames = hparams['max_batch_frames'] self.max_batch_size = hparams['max_batch_size'] self.max_val_batch_frames = hparams['max_val_batch_frames'] @@ -73,31 +70,26 @@ def __init__(self, *args, **kwargs): hparams['max_val_batch_size'] = self.max_val_batch_size = self.max_batch_size self.training_sampler = None - self.model = None self.skip_immediate_validation = False self.skip_immediate_ckpt_save = False - self.valid_losses: Dict[str, Metric] = { - 'total_loss': MeanMetric() - } - self.valid_metric_names = set() + self.phone_encoder = self.build_phone_encoder() + self.build_model() + + self.valid_losses: Dict[str, Metric] = {} + self.valid_metrics: Dict[str, Metric] = {} + + def _finish_init(self): + self.register_validation_loss('total_loss') + self.build_losses_and_metrics() ########### # Training, validation and testing ########### def setup(self, stage): - self.phone_encoder = self.build_phone_encoder() - self.model = self.build_model() - # utils.load_warp(self) - self.unfreeze_all_params() - if hparams['freezing_enabled']: - self.freeze_params() - if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None: - self.load_finetune_ckpt(self.load_pre_train_model()) - self.print_arch() - self.build_losses_and_metrics() self.train_dataset = self.dataset_cls(hparams['train_set_name']) self.valid_dataset = self.dataset_cls(hparams['valid_set_name']) + self.num_replicas = (self.trainer.distributed_sampler_kwargs or {}).get('num_replicas', 1) def get_need_freeze_state_dict_key(self, model_state_dict) -> list: key_list = [] @@ -123,7 +115,6 @@ def unfreeze_all_params(self) -> None: def load_finetune_ckpt( self, state_dict ) -> None: - adapt_shapes = hparams['finetune_strict_shapes'] if not adapt_shapes: cur_model_state_dict = self.state_dict() @@ -139,7 +130,6 @@ def load_finetune_ckpt( self.load_state_dict(state_dict, strict=False) def load_pre_train_model(self): - pre_train_ckpt_path = hparams.get('finetune_ckpt_path') blacklist = hparams.get('finetune_ignored_params') # whitelist=hparams.get('pre_train_whitelist') @@ -181,9 +171,19 @@ def build_phone_encoder(): phone_list = build_phoneme_list() return TokenTextEncoder(vocab_list=phone_list) - def build_model(self): + def _build_model(self): raise NotImplementedError() + def build_model(self): + self.model = self._build_model() + # utils.load_warp(self) + self.unfreeze_all_params() + if hparams['freezing_enabled']: + self.freeze_params() + if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None: + self.load_finetune_ckpt(self.load_pre_train_model()) + self.print_arch() + @rank_zero_only def print_arch(self): utils.print_arch(self.model) @@ -191,10 +191,13 @@ def print_arch(self): def build_losses_and_metrics(self): raise NotImplementedError() - def register_metric(self, name: str, metric: Metric): + def register_validation_metric(self, name: str, metric: Metric): assert isinstance(metric, Metric) - setattr(self, name, metric) - self.valid_metric_names.add(name) + self.valid_metrics[name] = metric + + def register_validation_loss(self, name: str, Aggregator: Metric = MeanMetric): + assert issubclass(Aggregator, Metric) + self.valid_losses[name] = Aggregator() def run_model(self, sample, infer=False): """ @@ -216,7 +219,7 @@ def _training_step(self, sample): total_loss = sum(losses.values()) return total_loss, {**losses, 'batch_size': float(sample['size'])} - def training_step(self, sample, batch_idx, optimizer_idx=-1): + def training_step(self, sample, batch_idx): total_loss, log_outputs = self._training_step(sample) # logs to progress bar @@ -237,10 +240,16 @@ def _on_validation_start(self): pass def on_validation_start(self): + if self.skip_immediate_validation: + rank_zero_debug("Skip validation") + return self._on_validation_start() for metric in self.valid_losses.values(): metric.to(self.device) metric.reset() + for metric in self.valid_metrics.values(): + metric.to(self.device) + metric.reset() def _validation_step(self, sample, batch_idx): """ @@ -258,34 +267,32 @@ def validation_step(self, sample, batch_idx): :param batch_idx: """ if self.skip_immediate_validation: - rank_zero_debug(f"Skip validation {batch_idx}") - return {} - with torch.autocast(self.device.type, enabled=False): - losses, weight = self._validation_step(sample, batch_idx) - losses = { - 'total_loss': sum(losses.values()), - **losses - } - for k, v in losses.items(): - if k not in self.valid_losses: - self.valid_losses[k] = MeanMetric().to(self.device) - self.valid_losses[k].update(v, weight=weight) - return losses + rank_zero_debug("Skip validation") + return + if sample['size'] > 0: + with torch.autocast(self.device.type, enabled=False): + losses, weight = self._validation_step(sample, batch_idx) + losses = { + 'total_loss': sum(losses.values()), + **losses + } + for k, v in losses.items(): + self.valid_losses[k].update(v, weight=weight) + + def _on_validation_epoch_end(self): + pass def on_validation_epoch_end(self): if self.skip_immediate_validation: self.skip_immediate_validation = False self.skip_immediate_ckpt_save = True return + self._on_validation_epoch_end() loss_vals = {k: v.compute() for k, v in self.valid_losses.items()} + metric_vals = {k: v.compute() for k, v in self.valid_metrics.items()} self.log('val_loss', loss_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True) self.logger.log_metrics({f'validation/{k}': v for k, v in loss_vals.items()}, step=self.global_step) - for metric in self.valid_losses.values(): - metric.reset() - metric_vals = {k: getattr(self, k).compute() for k in self.valid_metric_names} self.logger.log_metrics({f'metrics/{k}': v for k, v in metric_vals.items()}, step=self.global_step) - for metric_name in self.valid_metric_names: - getattr(self, metric_name).reset() # noinspection PyMethodMayBeStatic def build_scheduler(self, optimizer): @@ -331,36 +338,45 @@ def train_dataloader(self): self.train_dataset, max_batch_frames=self.max_batch_frames, max_batch_size=self.max_batch_size, - num_replicas=(self.trainer.distributed_sampler_kwargs or {}).get('num_replicas', 1), - rank=(self.trainer.distributed_sampler_kwargs or {}).get('rank', 0), + num_replicas=self.num_replicas, + rank=self.global_rank, sort_by_similar_size=hparams['sort_by_len'], + size_reversed=True, required_batch_count_multiple=hparams['accumulate_grad_batches'], shuffle_sample=True, - shuffle_batch=False, + shuffle_batch=True, seed=hparams['seed'] ) - return torch.utils.data.DataLoader(self.train_dataset, - collate_fn=self.train_dataset.collater, - batch_sampler=self.training_sampler, - num_workers=hparams['ds_workers'], - prefetch_factor=hparams['dataloader_prefetch_factor'], - pin_memory=True, - persistent_workers=True) + return torch.utils.data.DataLoader( + self.train_dataset, + collate_fn=self.train_dataset.collater, + batch_sampler=self.training_sampler, + num_workers=hparams['ds_workers'], + prefetch_factor=hparams['dataloader_prefetch_factor'], + pin_memory=True, + persistent_workers=True + ) def val_dataloader(self): - sampler = DsEvalBatchSampler( + sampler = DsBatchSampler( self.valid_dataset, max_batch_frames=self.max_val_batch_frames, max_batch_size=self.max_val_batch_size, - rank=(self.trainer.distributed_sampler_kwargs or {}).get('rank', 0), - batch_by_size=False + num_replicas=self.num_replicas, + rank=self.global_rank, + shuffle_sample=False, + shuffle_batch=False, + disallow_empty_batch=False, + pad_batch_assignment=False + ) + return torch.utils.data.DataLoader( + self.valid_dataset, + collate_fn=self.valid_dataset.collater, + batch_sampler=sampler, + num_workers=hparams['ds_workers'], + prefetch_factor=hparams['dataloader_prefetch_factor'], + persistent_workers=True ) - return torch.utils.data.DataLoader(self.valid_dataset, - collate_fn=self.valid_dataset.collater, - batch_sampler=sampler, - num_workers=hparams['ds_workers'], - prefetch_factor=hparams['dataloader_prefetch_factor'], - shuffle=False) def test_dataloader(self): return self.val_dataloader() @@ -392,13 +408,7 @@ def start(cls): accelerator=hparams['pl_trainer_accelerator'], devices=hparams['pl_trainer_devices'], num_nodes=hparams['pl_trainer_num_nodes'], - strategy=get_strategy( - accelerator=hparams['pl_trainer_accelerator'], - devices=hparams['pl_trainer_devices'], - num_nodes=hparams['pl_trainer_num_nodes'], - strategy=hparams['pl_trainer_strategy'], - backend=hparams['ddp_backend'] - ), + strategy=get_strategy(hparams['pl_trainer_strategy']), precision=hparams['pl_trainer_precision'], callbacks=[ DsModelCheckpoint( @@ -417,10 +427,10 @@ def start(cls): # LearningRateMonitor(logging_interval='step'), DsTQDMProgressBar(), ], - logger=TensorBoardLogger( + logger=DsTensorBoardLogger( save_dir=str(work_dir), name='lightning_logs', - version='lastest' + version='latest' ), gradient_clip_val=hparams['clip_grad_norm'], val_check_interval=hparams['val_check_interval'] * hparams['accumulate_grad_batches'], diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index dc3b14bc8..4b4d91897 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -103,6 +103,7 @@ lr_scheduler_args: gamma: 0.5 max_batch_frames: 80000 max_batch_size: 48 +dataset_size_key: 'lengths' val_with_vocoder: true val_check_interval: 2000 num_valid_plots: 10 diff --git a/configs/base.yaml b/configs/base.yaml index 6f16e22f6..536c4e875 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -88,8 +88,11 @@ pl_trainer_accelerator: 'auto' pl_trainer_devices: 'auto' pl_trainer_precision: '32-true' pl_trainer_num_nodes: 1 -pl_trainer_strategy: 'auto' -ddp_backend: 'nccl' # choose from 'gloo', 'nccl', 'nccl_no_p2p' +pl_trainer_strategy: + name: auto + process_group_backend: nccl + find_unused_parameters: false +nccl_p2p: true ########### # finetune diff --git a/configs/variance.yaml b/configs/variance.yaml index 44c534820..1951f685e 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -106,6 +106,7 @@ lr_scheduler_args: gamma: 0.75 max_batch_frames: 80000 max_batch_size: 48 +dataset_size_key: 'lengths' val_check_interval: 2000 num_valid_plots: 10 max_updates: 288000 diff --git a/docs/ConfigurationSchemas.md b/docs/ConfigurationSchemas.md index 43242efd1..6632d22a6 100644 --- a/docs/ConfigurationSchemas.md +++ b/docs/ConfigurationSchemas.md @@ -634,9 +634,9 @@ bool true -### ddp_backend +### dataset_size_key -The distributed training backend. +The key that indexes the binarized metadata to be used as the `sizes` when batching by size #### visibility @@ -648,7 +648,7 @@ training #### customizability -normal +not recommended #### type @@ -656,11 +656,7 @@ str #### default -nccl - -#### constraints - -Choose from 'gloo', 'nccl', 'nccl_no_p2p'. Windows platforms may use 'gloo'; Linux platforms may use 'nccl'; if Linux ddp gets stuck, use 'nccl_no_p2p'. +lengths ### dictionary @@ -1852,7 +1848,7 @@ training #### customizability -reserved +normal #### type @@ -1876,7 +1872,7 @@ training #### customizability -reserved +normal #### type @@ -1958,6 +1954,30 @@ float 0.06 +### nccl_p2p + +Whether to enable P2P when using NCCL as the backend. Turn it to `false` if the training process is stuck upon beginning. + +#### visibility + +all + +#### scope + +training + +#### customizability + +normal + +#### type + +bool + +#### default + +true + ### num_ckpt_keep Number of newest checkpoints kept during training. @@ -2488,7 +2508,15 @@ int ### pl_trainer_strategy -Strategies of the Lightning trainer behavior. +Arguments of Lightning Strategy. Values will be used as keyword arguments when constructing the Strategy object. + +#### type + +dict + +### pl_trainer_strategy.name + +Strategy name for the Lightning trainer. #### visibility diff --git a/modules/nsf_hifigan/env.py b/modules/nsf_hifigan/env.py index b576e130e..ebb9486d3 100644 --- a/modules/nsf_hifigan/env.py +++ b/modules/nsf_hifigan/env.py @@ -1,7 +1,30 @@ class AttrDict(dict): + """A dictionary with attribute-style access. It maps attribute access to + the real dictionary. """ def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self + dict.__init__(self, *args, **kwargs) - def __getattr__(self, item): - return self[item] + def __getstate__(self): + return self.__dict__.items() + + def __setstate__(self, items): + for key, val in items: + self.__dict__[key] = val + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, dict.__repr__(self)) + + def __setitem__(self, key, value): + return super(AttrDict, self).__setitem__(key, value) + + def __getitem__(self, name): + return super(AttrDict, self).__getitem__(name) + + def __delitem__(self, name): + return super(AttrDict, self).__delitem__(name) + + __getattr__ = __getitem__ + __setattr__ = __setitem__ + + def copy(self): + return AttrDict(self) diff --git a/preprocessing/acoustic_binarizer.py b/preprocessing/acoustic_binarizer.py index 227f56963..c6cf48b24 100644 --- a/preprocessing/acoustic_binarizer.py +++ b/preprocessing/acoustic_binarizer.py @@ -56,19 +56,19 @@ def __init__(self): def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): meta_data_dict = {} if (raw_data_dir / 'transcriptions.csv').exists(): - for utterance_label in csv.DictReader( - open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf-8') - ): - item_name = utterance_label['name'] - temp_dict = { - 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'), - 'ph_seq': utterance_label['ph_seq'].split(), - 'ph_dur': [float(x) for x in utterance_label['ph_dur'].split()], - 'spk_id': spk_id - } - assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \ - f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.' - meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict + with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf-8') as f: + for utterance_label in csv.DictReader(f): + item_name = utterance_label['name'] + temp_dict = { + 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'), + 'ph_seq': utterance_label['ph_seq'].split(), + 'ph_dur': [float(x) for x in utterance_label['ph_dur'].split()], + 'spk_id': spk_id, + 'spk_name': self.speakers[ds_id], + } + assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \ + f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.' + meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict else: raise FileNotFoundError( f'transcriptions.csv not found in {raw_data_dir}. ' @@ -90,6 +90,7 @@ def process_item(self, item_name, meta_data, binarization_args): 'name': item_name, 'wav_fn': meta_data['wav_fn'], 'spk_id': meta_data['spk_id'], + 'spk_name': meta_data['spk_name'], 'seconds': seconds, 'length': length, 'mel': mel, diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index 16a672cb2..ba6b831b1 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -102,51 +102,51 @@ def load_attr_from_ds(self, ds_id, name, attr, idx=0): def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk_id): meta_data_dict = {} - for utterance_label in csv.DictReader( - open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') - ): - utterance_label: dict - item_name = utterance_label['name'] - item_idx = int(item_name.rsplit(DS_INDEX_SEP, maxsplit=1)[-1]) if DS_INDEX_SEP in item_name else 0 - - def require(attr): - if self.prefer_ds: - value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx) - else: - value = None - if value is None: - value = utterance_label.get(attr) - if value is None: - raise ValueError(f'Missing required attribute {attr} of item \'{item_name}\'.') - return value - - temp_dict = { - 'ds_idx': item_idx, - 'spk_id': spk_id, - 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'), - 'ph_seq': require('ph_seq').split(), - 'ph_dur': [float(x) for x in require('ph_dur').split()] - } - - assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \ - f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.' - - if hparams['predict_dur']: - temp_dict['ph_num'] = [int(x) for x in require('ph_num').split()] - assert len(temp_dict['ph_seq']) == sum(temp_dict['ph_num']), \ - f'Sum of ph_num does not equal length of ph_seq in \'{item_name}\'.' - - if hparams['predict_pitch']: - temp_dict['note_seq'] = require('note_seq').split() - temp_dict['note_dur'] = [float(x) for x in require('note_dur').split()] - assert len(temp_dict['note_seq']) == len(temp_dict['note_dur']), \ - f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.' - assert any([note != 'rest' for note in temp_dict['note_seq']]), \ - f'All notes are rest in \'{item_name}\'.' - if hparams['use_glide_embed']: - temp_dict['note_glide'] = require('note_glide').split() - - meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict + with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') as f: + for utterance_label in csv.DictReader(f): + utterance_label: dict + item_name = utterance_label['name'] + item_idx = int(item_name.rsplit(DS_INDEX_SEP, maxsplit=1)[-1]) if DS_INDEX_SEP in item_name else 0 + + def require(attr): + if self.prefer_ds: + value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx) + else: + value = None + if value is None: + value = utterance_label.get(attr) + if value is None: + raise ValueError(f'Missing required attribute {attr} of item \'{item_name}\'.') + return value + + temp_dict = { + 'ds_idx': item_idx, + 'spk_id': spk_id, + 'spk_name': self.speakers[ds_id], + 'wav_fn': str(raw_data_dir / 'wavs' / f'{item_name}.wav'), + 'ph_seq': require('ph_seq').split(), + 'ph_dur': [float(x) for x in require('ph_dur').split()] + } + + assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \ + f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.' + + if hparams['predict_dur']: + temp_dict['ph_num'] = [int(x) for x in require('ph_num').split()] + assert len(temp_dict['ph_seq']) == sum(temp_dict['ph_num']), \ + f'Sum of ph_num does not equal length of ph_seq in \'{item_name}\'.' + + if hparams['predict_pitch']: + temp_dict['note_seq'] = require('note_seq').split() + temp_dict['note_dur'] = [float(x) for x in require('note_dur').split()] + assert len(temp_dict['note_seq']) == len(temp_dict['note_dur']), \ + f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.' + assert any([note != 'rest' for note in temp_dict['note_seq']]), \ + f'All notes are rest in \'{item_name}\'.' + if hparams['use_glide_embed']: + temp_dict['note_glide'] = require('note_glide').split() + + meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict self.items.update(meta_data_dict) @@ -233,6 +233,7 @@ def process_item(self, item_name, meta_data, binarization_args): 'name': item_name, 'wav_fn': meta_data['wav_fn'], 'spk_id': meta_data['spk_id'], + 'spk_name': meta_data['spk_name'], 'seconds': seconds, 'length': length, 'tokens': np.array(self.phone_encoder.encode(meta_data['ph_seq']), dtype=np.int64) diff --git a/scripts/train.py b/scripts/train.py index 1df7b6bc7..d4a223677 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -13,7 +13,7 @@ from utils.hparams import set_hparams, hparams set_hparams() -if hparams['ddp_backend'] == 'nccl_no_p2p': +if not hparams['nccl_p2p']: print("Disabling NCCL P2P") os.environ['NCCL_P2P_DISABLE'] = '1' diff --git a/training/acoustic_task.py b/training/acoustic_task.py index 04dedb65c..68cb5f9a4 100644 --- a/training/acoustic_task.py +++ b/training/acoustic_task.py @@ -14,14 +14,14 @@ from modules.toplevel import DiffSingerAcoustic, ShallowDiffusionOutput from modules.vocoders.registry import get_vocoder_cls from utils.hparams import hparams -from utils.plot import spec_to_figure, curve_to_figure +from utils.plot import spec_to_figure matplotlib.use('Agg') class AcousticDataset(BaseDataset): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, prefix, preload=False): + super(AcousticDataset, self).__init__(prefix, hparams['dataset_size_key'], preload) self.required_variances = {} # key: variance name, value: padding value if hparams.get('use_energy_embed', False): self.required_variances['energy'] = 0.0 @@ -34,6 +34,8 @@ def __init__(self, *args, **kwargs): def collater(self, samples): batch = super().collater(samples) + if batch['size'] == 0: + return batch tokens = utils.collate_nd([s['tokens'] for s in samples], 0) f0 = utils.collate_nd([s['f0'] for s in samples], 0.0) @@ -76,8 +78,9 @@ def __init__(self): self.required_variances.append('energy') if hparams.get('use_breathiness_embed', False): self.required_variances.append('breathiness') + super()._finish_init() - def build_model(self): + def _build_model(self): return DiffSingerAcoustic( vocab_size=len(self.phone_encoder), out_dims=hparams['audio_num_mel_bins'] @@ -88,7 +91,9 @@ def build_losses_and_metrics(self): if self.use_shallow_diffusion: self.aux_mel_loss = build_aux_loss(self.shallow_args['aux_decoder_arch']) self.lambda_aux_mel_loss = hparams['lambda_aux_mel_loss'] + self.register_validation_loss('aux_mel_loss') self.mel_loss = DiffusionNoiseLoss(loss_type=hparams['diff_loss_type']) + self.register_validation_loss('mel_loss') def run_model(self, sample, infer=False): txt_tokens = sample['tokens'] # [B, T_ph] @@ -140,57 +145,66 @@ def _on_validation_start(self): def _validation_step(self, sample, batch_idx): losses = self.run_model(sample, infer=False) - - if batch_idx < hparams['num_valid_plots'] \ - and (self.trainer.distributed_sampler_kwargs or {}).get('rank', 0) == 0: + if sample['size'] > 0 and min(sample['indices']) < hparams['num_valid_plots']: mel_out: ShallowDiffusionOutput = self.run_model(sample, infer=True) - - if self.use_vocoder: - self.plot_wav( - batch_idx, gt_mel=sample['mel'], - aux_mel=mel_out.aux_out, diff_mel=mel_out.diff_out, - f0=sample['f0'] - ) - if mel_out.aux_out is not None: - self.plot_mel(batch_idx, sample['mel'], mel_out.aux_out, name=f'auxmel_{batch_idx}') - if mel_out.diff_out is not None: - self.plot_mel(batch_idx, sample['mel'], mel_out.diff_out, name=f'diffmel_{batch_idx}') - + for i in range(len(sample['indices'])): + data_idx = sample['indices'][i] + if data_idx < hparams['num_valid_plots']: + if self.use_vocoder: + self.plot_wav( + data_idx, sample['mel'][i], + mel_out.aux_out[i] if mel_out.aux_out is not None else None, + mel_out.diff_out[i], + sample['f0'][i] + ) + if mel_out.aux_out is not None: + self.plot_mel(data_idx, sample['mel'][i], mel_out.aux_out[i], 'auxmel') + if mel_out.diff_out is not None: + self.plot_mel(data_idx, sample['mel'][i], mel_out.diff_out[i], 'diffmel') return losses, sample['size'] + ############ # validation plots ############ - def plot_wav(self, batch_idx, gt_mel, aux_mel=None, diff_mel=None, f0=None): - gt_mel = gt_mel[0].cpu().numpy() + def plot_wav(self, data_idx, gt_mel, aux_mel, diff_mel, f0): + f0_len = self.valid_dataset.metadata['f0'][data_idx] + mel_len = self.valid_dataset.metadata['mel'][data_idx] + gt_mel = gt_mel[:mel_len].unsqueeze(0) if aux_mel is not None: - aux_mel = aux_mel[0].cpu().numpy() + aux_mel = aux_mel[:mel_len].unsqueeze(0) if diff_mel is not None: - diff_mel = diff_mel[0].cpu().numpy() - f0 = f0[0].cpu().numpy() - if batch_idx not in self.logged_gt_wav: - gt_wav = self.vocoder.spec2wav(gt_mel, f0=f0) - self.logger.experiment.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], - global_step=self.global_step) - self.logged_gt_wav.add(batch_idx) + diff_mel = diff_mel[:mel_len].unsqueeze(0) + f0 = f0[:f0_len].unsqueeze(0) + if data_idx not in self.logged_gt_wav: + gt_wav = self.vocoder.spec2wav_torch(gt_mel, f0=f0) + self.logger.all_rank_experiment.add_audio( + f'gt_{data_idx}', gt_wav, + sample_rate=hparams['audio_sample_rate'], + global_step=self.global_step + ) + self.logged_gt_wav.add(data_idx) if aux_mel is not None: - aux_wav = self.vocoder.spec2wav(aux_mel, f0=f0) - self.logger.experiment.add_audio(f'aux_{batch_idx}', aux_wav, sample_rate=hparams['audio_sample_rate'], - global_step=self.global_step) + aux_wav = self.vocoder.spec2wav_torch(aux_mel, f0=f0) + self.logger.all_rank_experiment.add_audio( + f'aux_{data_idx}', aux_wav, + sample_rate=hparams['audio_sample_rate'], + global_step=self.global_step + ) if diff_mel is not None: - diff_wav = self.vocoder.spec2wav(diff_mel, f0=f0) - self.logger.experiment.add_audio(f'diff_{batch_idx}', diff_wav, sample_rate=hparams['audio_sample_rate'], - global_step=self.global_step) - - def plot_mel(self, batch_idx, spec, spec_out, name=None): - name = f'mel_{batch_idx}' if name is None else name + diff_wav = self.vocoder.spec2wav_torch(diff_mel, f0=f0) + self.logger.all_rank_experiment.add_audio( + f'diff_{data_idx}', diff_wav, + sample_rate=hparams['audio_sample_rate'], + global_step=self.global_step + ) + + def plot_mel(self, data_idx, gt_spec, out_spec, name_prefix='mel'): vmin = hparams['mel_vmin'] vmax = hparams['mel_vmax'] - spec_cat = torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1) - self.logger.experiment.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step) - - def plot_curve(self, batch_idx, gt_curve, pred_curve, curve_name='curve'): - name = f'{curve_name}_{batch_idx}' - gt_curve = gt_curve[0].cpu().numpy() - pred_curve = pred_curve[0].cpu().numpy() - self.logger.experiment.add_figure(name, curve_to_figure(gt_curve, pred_curve), self.global_step) + mel_len = self.valid_dataset.metadata['mel'][data_idx] + spec_cat = torch.cat([(out_spec - gt_spec).abs() + vmin, gt_spec, out_spec], -1) + title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}" + self.logger.all_rank_experiment.add_figure(f'{name_prefix}_{data_idx}', spec_to_figure( + spec_cat[:mel_len], vmin, vmax, title_text + ), global_step=self.global_step) diff --git a/training/variance_task.py b/training/variance_task.py index 730bfccd6..6ae7bed1d 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -20,14 +20,16 @@ class VarianceDataset(BaseDataset): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, prefix, preload=False): + super(VarianceDataset, self).__init__(prefix, hparams['dataset_size_key'], preload) need_energy = hparams['predict_energy'] need_breathiness = hparams['predict_breathiness'] self.predict_variances = need_energy or need_breathiness def collater(self, samples): batch = super().collater(samples) + if batch['size'] == 0: + return batch tokens = utils.collate_nd([s['tokens'] for s in samples], 0) ph_dur = utils.collate_nd([s['ph_dur'] for s in samples], 0) @@ -94,8 +96,9 @@ def __init__(self): self.variance_prediction_list.append('breathiness') self.predict_variances = len(self.variance_prediction_list) > 0 self.lambda_var_loss = hparams['lambda_var_loss'] + super()._finish_init() - def build_model(self): + def _build_model(self): return DiffSingerVariance( vocab_size=len(self.phone_encoder), ) @@ -111,17 +114,20 @@ def build_losses_and_metrics(self): lambda_wdur=dur_hparams['lambda_wdur_loss'], lambda_sdur=dur_hparams['lambda_sdur_loss'] ) - self.register_metric('rhythm_corr', RhythmCorrectness(tolerance=0.05)) - self.register_metric('ph_dur_acc', PhonemeDurationAccuracy(tolerance=0.2)) + self.register_validation_loss('dur_loss') + self.register_validation_metric('rhythm_corr', RhythmCorrectness(tolerance=0.05)) + self.register_validation_metric('ph_dur_acc', PhonemeDurationAccuracy(tolerance=0.2)) if self.predict_pitch: self.pitch_loss = DiffusionNoiseLoss( loss_type=hparams['diff_loss_type'], ) - self.register_metric('pitch_acc', RawCurveAccuracy(tolerance=0.5)) + self.register_validation_loss('pitch_loss') + self.register_validation_metric('pitch_acc', RawCurveAccuracy(tolerance=0.5)) if self.predict_variances: self.var_loss = DiffusionNoiseLoss( loss_type=hparams['diff_loss_type'], ) + self.register_validation_loss('var_loss') def run_model(self, sample, infer=False): spk_ids = sample['spk_ids'] if self.use_spk_id else None # [B,] @@ -191,74 +197,83 @@ def run_model(self, sample, infer=False): def _validation_step(self, sample, batch_idx): losses = self.run_model(sample, infer=False) - - if batch_idx < hparams['num_valid_plots'] \ - and (self.trainer.distributed_sampler_kwargs or {}).get('rank', 0) == 0: - dur_pred, pitch_pred, variances_pred = self.run_model(sample, infer=True) - if dur_pred is not None: - tokens = sample['tokens'] - dur_gt = sample['ph_dur'] - ph2word = sample['ph2word'] - mask = tokens != 0 - self.rhythm_corr.update( - pdur_pred=dur_pred, pdur_target=dur_gt, ph2word=ph2word, mask=mask - ) - self.ph_dur_acc.update( - pdur_pred=dur_pred, pdur_target=dur_gt, ph2word=ph2word, mask=mask - ) - self.plot_dur(batch_idx, dur_gt, dur_pred, txt=tokens) - if pitch_pred is not None: - pred_pitch = sample['base_pitch'] + pitch_pred - gt_pitch = sample['pitch'] - mask = (sample['mel2ph'] > 0) & ~sample['uv'] - self.pitch_acc.update(pred=pred_pitch, target=gt_pitch, mask=mask) - self.plot_pitch( - batch_idx, - gt_pitch=gt_pitch, - pred_pitch=pred_pitch, - note_midi=sample['note_midi'], - note_dur=sample['note_dur'], - note_rest=sample['note_rest'] - ) - for name in self.variance_prediction_list: - variance = sample[name] - variance_pred = variances_pred[name] - self.plot_curve( - batch_idx, - gt_curve=variance, - pred_curve=variance_pred, - curve_name=name - ) - + if min(sample['indices']) < hparams['num_valid_plots']: + def sample_get(key, idx, abs_idx): + return sample[key][idx][:self.valid_dataset.metadata[key][abs_idx]].unsqueeze(0) + dur_preds, pitch_preds, variances_preds = self.run_model(sample, infer=True) + for i in range(len(sample['indices'])): + data_idx = sample['indices'][i] + if data_idx < hparams['num_valid_plots']: + if dur_preds is not None: + dur_len = self.valid_dataset.metadata['ph_dur'][data_idx] + tokens = sample_get('tokens', i, data_idx) + gt_dur = sample_get('ph_dur', i, data_idx) + pred_dur = dur_preds[i][:dur_len].unsqueeze(0) + ph2word = sample_get('ph2word', i, data_idx) + mask = tokens != 0 + self.valid_metrics['rhythm_corr'].update( + pdur_pred=pred_dur, pdur_target=gt_dur, ph2word=ph2word, mask=mask + ) + self.valid_metrics['ph_dur_acc'].update( + pdur_pred=pred_dur, pdur_target=gt_dur, ph2word=ph2word, mask=mask + ) + self.plot_dur(data_idx, gt_dur, pred_dur, tokens) + if pitch_preds is not None: + pitch_len = self.valid_dataset.metadata['pitch'][data_idx] + pred_pitch = sample_get('base_pitch', i, data_idx) + pitch_preds[i][:pitch_len].unsqueeze(0) + gt_pitch = sample_get('pitch', i, data_idx) + mask = (sample_get('mel2ph', i, data_idx) > 0) & ~sample_get('uv', i, data_idx) + self.valid_metrics['pitch_acc'].update(pred=pred_pitch, target=gt_pitch, mask=mask) + self.plot_pitch( + data_idx, + gt_pitch=gt_pitch, + pred_pitch=pred_pitch, + note_midi=sample_get('note_midi', i, data_idx), + note_dur=sample_get('note_dur', i, data_idx), + note_rest=sample_get('note_rest', i, data_idx) + ) + for name in self.variance_prediction_list: + variance_len = self.valid_dataset.metadata[name][data_idx] + gt_variances = sample[name][i][:variance_len].unsqueeze(0) + pred_variances = variances_preds[name][i][:variance_len].unsqueeze(0) + self.plot_curve( + data_idx, + gt_curve=gt_variances, + pred_curve=pred_variances, + curve_name=name + ) return losses, sample['size'] + ############ # validation plots ############ - def plot_dur(self, batch_idx, gt_dur, pred_dur, txt=None): - name = f'dur_{batch_idx}' + def plot_dur(self, data_idx, gt_dur, pred_dur, txt=None): gt_dur = gt_dur[0].cpu().numpy() pred_dur = pred_dur[0].cpu().numpy() txt = self.phone_encoder.decode(txt[0].cpu().numpy()).split() - self.logger.experiment.add_figure(name, dur_to_figure(gt_dur, pred_dur, txt), self.global_step) + title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}" + self.logger.all_rank_experiment.add_figure(f'dur_{data_idx}', dur_to_figure( + gt_dur, pred_dur, txt, title_text + ), self.global_step) - def plot_pitch(self, batch_idx, gt_pitch, pred_pitch, note_midi, note_dur, note_rest): - name = f'pitch_{batch_idx}' + def plot_pitch(self, data_idx, gt_pitch, pred_pitch, note_midi, note_dur, note_rest): gt_pitch = gt_pitch[0].cpu().numpy() pred_pitch = pred_pitch[0].cpu().numpy() note_midi = note_midi[0].cpu().numpy() note_dur = note_dur[0].cpu().numpy() note_rest = note_rest[0].cpu().numpy() - self.logger.experiment.add_figure(name, pitch_note_to_figure( - gt_pitch, pred_pitch, note_midi, note_dur, note_rest + title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}" + self.logger.all_rank_experiment.add_figure(f'pitch_{data_idx}', pitch_note_to_figure( + gt_pitch, pred_pitch, note_midi, note_dur, note_rest, title_text ), self.global_step) - def plot_curve(self, batch_idx, gt_curve, pred_curve, base_curve=None, grid=None, curve_name='curve'): - name = f'{curve_name}_{batch_idx}' + def plot_curve(self, data_idx, gt_curve, pred_curve, base_curve=None, grid=None, curve_name='curve'): gt_curve = gt_curve[0].cpu().numpy() pred_curve = pred_curve[0].cpu().numpy() if base_curve is not None: base_curve = base_curve[0].cpu().numpy() - self.logger.experiment.add_figure(name, curve_to_figure( - gt_curve, pred_curve, base_curve, grid=grid + title_text = f"{self.valid_dataset.metadata['spk_names'][data_idx]} - {self.valid_dataset.metadata['names'][data_idx]}" + self.logger.all_rank_experiment.add_figure(f'{curve_name}_{data_idx}', curve_to_figure( + gt_curve, pred_curve, base_curve, grid=grid, title=title_text ), self.global_step) diff --git a/utils/indexed_datasets.py b/utils/indexed_datasets.py index 81dd52637..fea8c8124 100644 --- a/utils/indexed_datasets.py +++ b/utils/indexed_datasets.py @@ -45,36 +45,37 @@ def __len__(self): class IndexedDatasetBuilder: - def __init__(self, path, prefix, allowed_attr=None): + def __init__(self, path, prefix, allowed_attr=None, auto_increment=True): self.path = pathlib.Path(path) / f'{prefix}.data' self.prefix = prefix - self.dset = None + self.dset = h5py.File(self.path, 'w') self.counter = 0 - self.lock = multiprocessing.Lock() + self.auto_increment = auto_increment if allowed_attr is not None: self.allowed_attr = set(allowed_attr) else: self.allowed_attr = None - def add_item(self, item): - if self.dset is None: - self.dset = h5py.File(self.path, 'w') + def add_item(self, item, item_no=None): + if self.auto_increment and item_no is not None or not self.auto_increment and item_no is None: + raise ValueError('auto_increment and provided item_no are mutually exclusive') if self.allowed_attr is not None: item = { k: item[k] for k in self.allowed_attr if k in item } - item_no = self.counter - self.counter += 1 + if self.auto_increment: + item_no = self.counter + self.counter += 1 for k, v in item.items(): if v is None: continue self.dset.create_dataset(f'{item_no}/{k}', data=v) + return item_no def finalize(self): - if self.dset is not None: - self.dset.close() + self.dset.close() if __name__ == "__main__": diff --git a/utils/plot.py b/utils/plot.py index d02d68653..e109c8d9c 100644 --- a/utils/plot.py +++ b/utils/plot.py @@ -4,16 +4,17 @@ from matplotlib.ticker import MultipleLocator -def spec_to_figure(spec, vmin=None, vmax=None): +def spec_to_figure(spec, vmin=None, vmax=None, title=''): if isinstance(spec, torch.Tensor): spec = spec.cpu().numpy() fig = plt.figure(figsize=(12, 9)) plt.pcolor(spec.T, vmin=vmin, vmax=vmax) + plt.title(title, fontsize=22) plt.tight_layout() return fig -def dur_to_figure(dur_gt, dur_pred, txt): +def dur_to_figure(dur_gt, dur_pred, txt, title=''): if isinstance(dur_gt, torch.Tensor): dur_gt = dur_gt.cpu().numpy() if isinstance(dur_pred, torch.Tensor): @@ -35,12 +36,13 @@ def dur_to_figure(dur_gt, dur_pred, txt): plt.plot([dur_pred[i], dur_gt[i]], [12, 10], color='black', linewidth=2, linestyle=':') plt.yticks([]) plt.xlim(0, max(dur_pred[-1], dur_gt[-1])) - fig.legend() - fig.tight_layout() + plt.legend() + plt.title(title, fontsize=22) + plt.tight_layout() return fig -def pitch_note_to_figure(pitch_gt, pitch_pred=None, note_midi=None, note_dur=None, note_rest=None): +def pitch_note_to_figure(pitch_gt, pitch_pred=None, note_midi=None, note_dur=None, note_rest=None, title=''): if isinstance(pitch_gt, torch.Tensor): pitch_gt = pitch_gt.cpu().numpy() if isinstance(pitch_pred, torch.Tensor): @@ -51,7 +53,8 @@ def pitch_note_to_figure(pitch_gt, pitch_pred=None, note_midi=None, note_dur=Non note_dur = note_dur.cpu().numpy() if isinstance(note_rest, torch.Tensor): note_rest = note_rest.cpu().numpy() - fig = plt.figure() + width = max(12, min(24, len(pitch_gt) // 200)) + fig = plt.figure(figsize=(width, 8)) if note_midi is not None and note_dur is not None: note_dur_acc = np.cumsum(note_dur) if note_rest is None: @@ -73,18 +76,20 @@ def pitch_note_to_figure(pitch_gt, pitch_pred=None, note_midi=None, note_dur=Non plt.gca().yaxis.set_major_locator(MultipleLocator(1)) plt.grid(axis='y') plt.legend() + plt.title(title, fontsize=22) plt.tight_layout() return fig -def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None): +def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None, title=''): if isinstance(curve_gt, torch.Tensor): curve_gt = curve_gt.cpu().numpy() if isinstance(curve_pred, torch.Tensor): curve_pred = curve_pred.cpu().numpy() if isinstance(curve_base, torch.Tensor): curve_base = curve_base.cpu().numpy() - fig = plt.figure() + width = max(12, min(24, len(curve_gt) // 200)) + fig = plt.figure(figsize=(width, 8)) if curve_base is not None: plt.plot(curve_base, color='g', label='base') plt.plot(curve_gt, color='b', label='gt') @@ -94,6 +99,7 @@ def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None): plt.gca().yaxis.set_major_locator(MultipleLocator(grid)) plt.grid(axis='y') plt.legend() + plt.title(title, fontsize=22) plt.tight_layout() return fig diff --git a/utils/training_utils.py b/utils/training_utils.py index ad7f0473a..241b49d39 100644 --- a/utils/training_utils.py +++ b/utils/training_utils.py @@ -7,9 +7,10 @@ import lightning.pytorch as pl import numpy as np import torch +from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar -from lightning.pytorch.strategies import DDPStrategy -from lightning.pytorch.utilities.rank_zero import rank_zero_info +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only from torch.optim.lr_scheduler import LambdaLR from torch.utils.data.distributed import Sampler @@ -74,7 +75,11 @@ class DsBatchSampler(Sampler): def __init__(self, dataset, max_batch_frames, max_batch_size, sub_indices=None, num_replicas=None, rank=None, required_batch_count_multiple=1, batch_by_size=True, sort_by_similar_size=True, - shuffle_sample=False, shuffle_batch=False, seed=0, drop_last=False) -> None: + size_reversed=False, shuffle_sample=False, shuffle_batch=False, + disallow_empty_batch=True, pad_batch_assignment=True, seed=0, drop_last=False) -> None: + if rank >= num_replicas or rank < 0: + raise ValueError( + f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") self.dataset = dataset self.max_batch_frames = max_batch_frames self.max_batch_size = max_batch_size @@ -84,8 +89,11 @@ def __init__(self, dataset, max_batch_frames, max_batch_size, sub_indices=None, self.required_batch_count_multiple = required_batch_count_multiple self.batch_by_size = batch_by_size self.sort_by_similar_size = sort_by_similar_size + self.size_reversed = size_reversed self.shuffle_sample = shuffle_sample self.shuffle_batch = shuffle_batch + self.disallow_empty_batch = disallow_empty_batch + self.pad_batch_assignment = pad_batch_assignment self.seed = seed self.drop_last = drop_last self.epoch = 0 @@ -96,6 +104,7 @@ def __form_batches(self): if self.formed == self.epoch + self.seed: return rng = np.random.default_rng(self.seed + self.epoch) + # Create indices if self.shuffle_sample: if self.sub_indices is not None: rng.shuffle(self.sub_indices) @@ -104,16 +113,17 @@ def __form_batches(self): indices = rng.permutation(len(self.dataset)) if self.sort_by_similar_size: - grid = int(hparams.get('sampler_frame_count_grid', 200)) + grid = int(hparams.get('sampler_frame_count_grid', 6)) assert grid > 0 - sizes = (np.round(np.array(self.dataset._sizes)[indices] / grid) * grid).clip(grid, None).astype( - np.int64) + sizes = (np.round(np.array(self.dataset.sizes)[indices] / grid) * grid).clip(grid, None) + sizes *= (-1 if self.size_reversed else 1) indices = indices[np.argsort(sizes, kind='mergesort')] indices = indices.tolist() else: indices = self.sub_indices if self.sub_indices is not None else list(range(len(self.dataset))) + # Batching if self.batch_by_size: batches = utils.batch_by_size( indices, self.dataset.num_frames, @@ -122,36 +132,51 @@ def __form_batches(self): ) else: batches = [indices[i:i + self.max_batch_size] for i in range(0, len(indices), self.max_batch_size)] + if len(batches) < self.num_replicas and self.disallow_empty_batch: + raise RuntimeError("There is not enough batch to assign to each node.") + # Either drop_last or separate the leftovers. floored_total_batch_count = (len(batches) // self.num_replicas) * self.num_replicas if self.drop_last and len(batches) > floored_total_batch_count: batches = batches[:floored_total_batch_count] leftovers = [] - else: + if len(batches) == 0: + raise RuntimeError("There is no batch left after dropping the last batch.") + elif self.shuffle_batch: leftovers = (rng.permutation(len(batches) - floored_total_batch_count) + floored_total_batch_count).tolist() + else: + leftovers = list(range(floored_total_batch_count, len(batches))) - batch_assignment = rng.permuted( - np.arange(floored_total_batch_count).reshape(-1, self.num_replicas).transpose(), axis=0 - )[self.rank].tolist() + # Initial batch assignment to current rank. + batch_assignment = np.arange(floored_total_batch_count).reshape(-1, self.num_replicas).transpose() + if self.shuffle_batch: + batch_assignment = rng.permuted(batch_assignment, axis=0)[self.rank].tolist() + else: + batch_assignment = batch_assignment[self.rank].tolist() + + # Assign leftovers or pad the batch assignment. floored_batch_count = len(batch_assignment) - ceiled_batch_count = floored_batch_count + (1 if len(leftovers) > 0 else 0) if self.rank < len(leftovers): batch_assignment.append(leftovers[self.rank]) - elif len(leftovers) > 0: + floored_batch_count += 1 + elif len(leftovers) > 0 and self.pad_batch_assignment: + if not batch_assignment: + raise RuntimeError("Cannot pad empty batch assignment.") batch_assignment.append(batch_assignment[self.epoch % floored_batch_count]) - if self.required_batch_count_multiple > 1 and ceiled_batch_count % self.required_batch_count_multiple != 0: - # batch_assignment = batch_assignment[:((floored_batch_count \ - # // self.required_batch_count_multiple) * self.required_batch_count_multiple)] + # Ensure the batch count is multiple of required_batch_count_multiple. + if self.required_batch_count_multiple > 1 and len(batch_assignment) % self.required_batch_count_multiple != 0: ceiled_batch_count = math.ceil( - ceiled_batch_count / self.required_batch_count_multiple) * self.required_batch_count_multiple + len(batch_assignment) / self.required_batch_count_multiple + ) * self.required_batch_count_multiple for i in range(ceiled_batch_count - len(batch_assignment)): batch_assignment.append( batch_assignment[(i + self.epoch * self.required_batch_count_multiple) % floored_batch_count]) - self.batches = [deepcopy(batches[i]) for i in batch_assignment] - - if self.shuffle_batch: - rng.shuffle(self.batches) + if batch_assignment: + self.batches = [deepcopy(batches[i]) for i in batch_assignment] + else: + self.batches = [[]] + self.formed = self.epoch + self.seed del indices del batches @@ -169,40 +194,9 @@ def __len__(self): def set_epoch(self, epoch): self.epoch = epoch + self.__form_batches() -class DsEvalBatchSampler(Sampler): - def __init__(self, dataset, max_batch_frames, max_batch_size, rank=None, batch_by_size=True) -> None: - self.dataset = dataset - self.max_batch_frames = max_batch_frames - self.max_batch_size = max_batch_size - self.rank = rank - self.batch_by_size = batch_by_size - self.batches = None - self.batch_size = max_batch_size - self.drop_last = False - - if self.rank == 0: - indices = list(range(len(self.dataset))) - if self.batch_by_size: - self.batches = utils.batch_by_size( - indices, self.dataset.num_frames, - max_batch_frames=self.max_batch_frames, max_batch_size=self.max_batch_size - ) - else: - self.batches = [ - indices[i:i + self.max_batch_size] - for i in range(0, len(indices), self.max_batch_size) - ] - else: - self.batches = [[0]] - - def __iter__(self): - return iter(self.batches) - - def __len__(self): - return len(self.batches) - # ==========PL related========== @@ -331,73 +325,50 @@ def get_metrics(self, trainer, model): return items -def get_strategy(accelerator, devices, num_nodes, strategy, backend): - if accelerator != 'auto' and accelerator != 'gpu': - return strategy - - from lightning.fabric.utilities.imports import _IS_INTERACTIVE - from lightning.pytorch.accelerators import AcceleratorRegistry - from lightning.pytorch.accelerators.cuda import CUDAAccelerator - from lightning.pytorch.accelerators.hpu import HPUAccelerator - from lightning.pytorch.accelerators.ipu import IPUAccelerator - from lightning.pytorch.accelerators.mps import MPSAccelerator - from lightning.pytorch.accelerators.tpu import TPUAccelerator - from lightning.pytorch.utilities.exceptions import MisconfigurationException - - def _choose_auto_accelerator(): - if TPUAccelerator.is_available(): - return "tpu" - if IPUAccelerator.is_available(): - return "ipu" - if HPUAccelerator.is_available(): - return "hpu" - if MPSAccelerator.is_available(): - return "mps" - if CUDAAccelerator.is_available(): - return "cuda" - return "cpu" - - def _choose_gpu_accelerator_backend(): - if MPSAccelerator.is_available(): - return "mps" - if CUDAAccelerator.is_available(): - return "cuda" - raise MisconfigurationException("No supported gpu backend found!") - - if accelerator == "auto": - _accelerator_flag = _choose_auto_accelerator() - elif accelerator == "gpu": - _accelerator_flag = _choose_gpu_accelerator_backend() - else: - return strategy - - if _accelerator_flag != "mps" and _accelerator_flag != "cuda": - return strategy - - _num_nodes_flag = int(num_nodes) if num_nodes is not None else 1 - _devices_flag = devices - - accelerator = AcceleratorRegistry.get(_accelerator_flag) - accelerator_cls = accelerator.__class__ - - if _devices_flag == "auto": - _devices_flag = accelerator.auto_device_count() - - _devices_flag = accelerator_cls.parse_devices(_devices_flag) - _parallel_devices = accelerator_cls.get_parallel_devices(_devices_flag) - - def get_ddp_strategy(_backend): - if _backend == 'gloo': - return DDPStrategy(process_group_backend='gloo', find_unused_parameters=False) - elif _backend == 'nccl' or _backend == 'nccl_no_p2p': - return DDPStrategy(process_group_backend='nccl', find_unused_parameters=False) +class DsTensorBoardLogger(TensorBoardLogger): + @property + def all_rank_experiment(self): + if rank_zero_only.rank == 0: + return self.experiment + if hasattr(self, "_all_rank_experiment") and self._all_rank_experiment is not None: + return self._all_rank_experiment + + assert rank_zero_only.rank != 0 + if self.root_dir: + self._fs.makedirs(self.root_dir, exist_ok=True) + + if _TENSORBOARD_AVAILABLE: + from torch.utils.tensorboard import SummaryWriter else: - raise ValueError(f'backend {_backend} is not valid.') - - if _num_nodes_flag > 1: - return get_ddp_strategy(backend) - if len(_parallel_devices) <= 1: - return strategy - if len(_parallel_devices) > 1 and _IS_INTERACTIVE: - return strategy - return get_ddp_strategy(backend) + from tensorboardX import SummaryWriter # type: ignore[no-redef] + + self._all_rank_experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) + return self._all_rank_experiment + + def finalize(self, status: str) -> None: + if rank_zero_only.rank == 0: + super().finalize(status) + elif hasattr(self, "_all_rank_experiment") and self._all_rank_experiment is not None: + self.all_rank_experiment.flush() + self.all_rank_experiment.close() + + def __getstate__(self): + state = super().__getstate__() + if "_all_rank_experiment" in state: + del state["_all_rank_experiment"] + return state + + +def get_strategy(strategy): + if strategy['name'] == 'auto': + return 'auto' + + from lightning.pytorch.strategies import StrategyRegistry + if strategy['name'] not in StrategyRegistry: + available_names = ", ".join(sorted(StrategyRegistry.keys())) or "none" + raise ValueError(f"Invalid strategy name {strategy['name']}. Available names: {available_names}") + + data = StrategyRegistry[strategy['name']] + params = data['init_params'] + params.update({k: v for k, v in strategy.items() if k != 'name'}) + return data['strategy'](**utils.filter_kwargs(params, data['strategy']))