From f953b7bb2f70e5a07f8cb4eaef8f11689fab0492 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 28 Jul 2023 22:05:58 +0800 Subject: [PATCH 1/3] Support objective evaluation metrics --- basics/base_task.py | 48 +++++++++++++++++++++++---------------- requirements.txt | 1 + training/acoustic_task.py | 2 +- training/variance_task.py | 2 +- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/basics/base_task.py b/basics/base_task.py index 059440bdb..f82e1442c 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -77,9 +77,10 @@ def __init__(self, *args, **kwargs): self.skip_immediate_validation = False self.skip_immediate_ckpt_save = False - self.valid_metrics: Dict[str, Metric] = { + self.valid_losses: Dict[str, Metric] = { 'total_loss': MeanMetric() } + self.valid_metric_names = set() ########### # Training, validation and testing @@ -91,7 +92,7 @@ def setup(self, stage): if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None: self.load_finetune_ckpt( self.load_pre_train_model()) self.print_arch() - self.build_losses() + self.build_losses_and_metrics() self.train_dataset = self.dataset_cls(hparams['train_set_name']) self.valid_dataset = self.dataset_cls(hparams['valid_set_name']) @@ -163,9 +164,14 @@ def build_model(self): def print_arch(self): utils.print_arch(self.model) - def build_losses(self): + def build_losses_and_metrics(self): raise NotImplementedError() + def register_metric(self, name: str, metric: Metric): + assert isinstance(metric, Metric) + setattr(self, name, metric) + self.valid_metric_names.add(name) + def run_model(self, sample, infer=False): """ steps: @@ -194,8 +200,8 @@ def training_step(self, sample, batch_idx, optimizer_idx=-1): self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False, on_step=True, on_epoch=False) # logs to tensorboard if self.global_step % hparams['log_interval'] == 0: - tb_log = {f'tr/{k}': v for k, v in log_outputs.items()} - tb_log['tr/lr'] = self.lr_schedulers().get_last_lr()[0] + tb_log = {f'training/{k}': v for k, v in log_outputs.items()} + tb_log['training/lr'] = self.lr_schedulers().get_last_lr()[0] self.logger.log_metrics(tb_log, step=self.global_step) return total_loss @@ -208,7 +214,7 @@ def _on_validation_start(self): def on_validation_start(self): self._on_validation_start() - for metric in self.valid_metrics.values(): + for metric in self.valid_losses.values(): metric.to(self.device) metric.reset() @@ -231,27 +237,31 @@ def validation_step(self, sample, batch_idx): rank_zero_debug(f"Skip validation {batch_idx}") return {} with torch.autocast(self.device.type, enabled=False): - outputs, weight = self._validation_step(sample, batch_idx) - outputs = { - 'total_loss': sum(outputs.values()), - **outputs + losses, weight = self._validation_step(sample, batch_idx) + losses = { + 'total_loss': sum(losses.values()), + **losses } - for k, v in outputs.items(): - if k not in self.valid_metrics: - self.valid_metrics[k] = MeanMetric().to(self.device) - self.valid_metrics[k].update(v, weight=weight) - return outputs + for k, v in losses.items(): + if k not in self.valid_losses: + self.valid_losses[k] = MeanMetric().to(self.device) + self.valid_losses[k].update(v, weight=weight) + return losses def on_validation_epoch_end(self): if self.skip_immediate_validation: self.skip_immediate_validation = False self.skip_immediate_ckpt_save = True return - metric_vals = {k: v.compute() for k, v in self.valid_metrics.items()} - self.log('val_loss', metric_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True) - self.logger.log_metrics({f'val/{k}': v for k, v in metric_vals.items()}, step=self.global_step) - for metric in self.valid_metrics.values(): + loss_vals = {k: v.compute() for k, v in self.valid_losses.items()} + self.log('val_loss', loss_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True) + self.logger.log_metrics({f'validation/{k}': v for k, v in loss_vals.items()}, step=self.global_step) + for metric in self.valid_losses.values(): metric.reset() + metric_vals = {k: getattr(self, k).compute() for k in self.valid_metric_names} + self.logger.log_metrics({f'metrics/{k}': v for k, v in metric_vals.items()}, step=self.global_step) + for metric_name in self.valid_metric_names: + getattr(self, metric_name).reset() # noinspection PyMethodMayBeStatic def build_scheduler(self, optimizer): diff --git a/requirements.txt b/requirements.txt index 297c44405..e9012c128 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,5 @@ resampy scipy tensorboard tensorboardX +torchmetrics tqdm diff --git a/training/acoustic_task.py b/training/acoustic_task.py index deb603776..b0723912b 100644 --- a/training/acoustic_task.py +++ b/training/acoustic_task.py @@ -77,7 +77,7 @@ def build_model(self): ) # noinspection PyAttributeOutsideInit - def build_losses(self): + def build_losses_and_metrics(self): self.mel_loss = DiffusionNoiseLoss(loss_type=hparams['diff_loss_type']) def run_model(self, sample, infer=False): diff --git a/training/variance_task.py b/training/variance_task.py index ed91ffdf0..72d9a445d 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -92,7 +92,7 @@ def build_model(self): ) # noinspection PyAttributeOutsideInit - def build_losses(self): + def build_losses_and_metrics(self): if self.predict_dur: dur_hparams = hparams['dur_prediction_args'] self.dur_loss = DurationLoss( From ce5ccbbf2983b9ca4ccaf52d8ebdf5b73f0c67cf Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 28 Jul 2023 22:09:48 +0800 Subject: [PATCH 2/3] Add raw pitch accuracy (RPA) logging --- modules/metrics/rca.py | 27 +++++++++++++++++++++++++++ training/variance_task.py | 9 +++++++-- 2 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 modules/metrics/rca.py diff --git a/modules/metrics/rca.py b/modules/metrics/rca.py new file mode 100644 index 000000000..8f88dc1be --- /dev/null +++ b/modules/metrics/rca.py @@ -0,0 +1,27 @@ +import torch +import torchmetrics +from torch import Tensor + + +class RawCurveAccuracy(torchmetrics.Metric): + def __init__(self, *, delta, **kwargs): + super().__init__(**kwargs) + self.delta = delta + self.add_state('close', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx='sum') + + def update(self, pred: Tensor, target: Tensor, mask=None) -> None: + if mask is None: + assert pred.shape == target.shape, f'shapes of pred and target mismatch: {pred.shape}, {target.shape}' + else: + assert pred.shape == target.shape == mask.shape, \ + f'shapes of pred, target and mask mismatch: {pred.shape}, {target.shape}, {mask.shape}' + close = torch.abs(pred - target) < self.delta + if mask is not None: + close &= mask + + self.close += close.sum() + self.total += pred.numel() if mask is None else mask.sum() + + def compute(self) -> Tensor: + return self.close / self.total diff --git a/training/variance_task.py b/training/variance_task.py index 72d9a445d..e132ef705 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -10,6 +10,7 @@ from basics.base_task import BaseTask from modules.losses.diff_loss import DiffusionNoiseLoss from modules.losses.dur_loss import DurationLoss +from modules.metrics.rca import RawCurveAccuracy from modules.toplevel import DiffSingerVariance from utils.hparams import hparams from utils.plot import dur_to_figure, curve_to_figure @@ -106,6 +107,7 @@ def build_losses_and_metrics(self): self.pitch_loss = DiffusionNoiseLoss( loss_type=hparams['diff_loss_type'], ) + self.register_metric('pitch_acc', RawCurveAccuracy(delta=0.5)) if self.predict_variances: self.var_loss = DiffusionNoiseLoss( loss_type=hparams['diff_loss_type'], @@ -178,10 +180,13 @@ def _validation_step(self, sample, batch_idx): self.plot_dur(batch_idx, sample['ph_dur'], dur_pred, txt=sample['tokens']) if pitch_pred is not None: base_pitch = sample['base_pitch'] + pred_pitch = base_pitch + pitch_pred + gt_pitch = sample['pitch'] + self.pitch_acc.update(pred=pred_pitch, target=gt_pitch) self.plot_curve( batch_idx, - gt_curve=sample['pitch'], - pred_curve=base_pitch + pitch_pred, + gt_curve=gt_pitch, + pred_curve=pred_pitch, base_curve=base_pitch, curve_name='pitch', grid=1 From c6d403d26f8aaaddce884243c43d28b8bf0b90cf Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Sat, 29 Jul 2023 23:25:00 +0800 Subject: [PATCH 3/3] Exclude unvoiced and padding frames --- modules/metrics/rca.py | 6 ++++++ preprocessing/variance_binarizer.py | 2 ++ training/variance_task.py | 4 +++- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/modules/metrics/rca.py b/modules/metrics/rca.py index 8f88dc1be..d565fd382 100644 --- a/modules/metrics/rca.py +++ b/modules/metrics/rca.py @@ -11,6 +11,12 @@ def __init__(self, *, delta, **kwargs): self.add_state('total', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx='sum') def update(self, pred: Tensor, target: Tensor, mask=None) -> None: + """ + + :param pred: predicted curve + :param target: reference curve + :param mask: valid or non-padding mask + """ if mask is None: assert pred.shape == target.shape, f'shapes of pred and target mismatch: {pred.shape}, {target.shape}' else: diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index c193f18e0..c9009ac2d 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -31,6 +31,7 @@ 'mel2ph', # mel2ph format representing number of frames within each phone, int64[T_s,] 'base_pitch', # interpolated and smoothed frame-level MIDI pitch, float32[T_s,] 'pitch', # actual pitch in semitones, float32[T_s,] + 'uv', # unvoiced masks (only for objective evaluation metrics), bool[T_s,] 'energy', # frame-level RMS (dB), float32[T_s,] 'breathiness', # frame-level RMS of aperiodic parts (dB), float32[T_s,] ] @@ -203,6 +204,7 @@ def process_item(self, item_name, meta_data, binarization_args): if hparams['predict_pitch'] or self.predict_variances: processed_input['pitch'] = pitch.cpu().numpy() + processed_input['uv'] = uv # Below: extract energy if hparams['predict_energy']: diff --git a/training/variance_task.py b/training/variance_task.py index e132ef705..5a6673704 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -45,6 +45,7 @@ def collater(self, samples): if hparams['predict_pitch'] or self.predict_variances: batch['mel2ph'] = utils.collate_nd([s['mel2ph'] for s in samples], 0) batch['pitch'] = utils.collate_nd([s['pitch'] for s in samples], 0) + batch['uv'] = utils.collate_nd([s['uv'] for s in samples], True) if hparams['predict_energy']: batch['energy'] = utils.collate_nd([s['energy'] for s in samples], 0) if hparams['predict_breathiness']: @@ -182,7 +183,8 @@ def _validation_step(self, sample, batch_idx): base_pitch = sample['base_pitch'] pred_pitch = base_pitch + pitch_pred gt_pitch = sample['pitch'] - self.pitch_acc.update(pred=pred_pitch, target=gt_pitch) + mask = (sample['mel2ph'] > 0) & ~sample['uv'] + self.pitch_acc.update(pred=pred_pitch, target=gt_pitch, mask=mask) self.plot_curve( batch_idx, gt_curve=gt_pitch,