diff --git a/basics/base_task.py b/basics/base_task.py index 059440bdb..f82e1442c 100644 --- a/basics/base_task.py +++ b/basics/base_task.py @@ -77,9 +77,10 @@ def __init__(self, *args, **kwargs): self.skip_immediate_validation = False self.skip_immediate_ckpt_save = False - self.valid_metrics: Dict[str, Metric] = { + self.valid_losses: Dict[str, Metric] = { 'total_loss': MeanMetric() } + self.valid_metric_names = set() ########### # Training, validation and testing @@ -91,7 +92,7 @@ def setup(self, stage): if hparams['finetune_enabled'] and get_latest_checkpoint_path(pathlib.Path(hparams['work_dir'])) is None: self.load_finetune_ckpt( self.load_pre_train_model()) self.print_arch() - self.build_losses() + self.build_losses_and_metrics() self.train_dataset = self.dataset_cls(hparams['train_set_name']) self.valid_dataset = self.dataset_cls(hparams['valid_set_name']) @@ -163,9 +164,14 @@ def build_model(self): def print_arch(self): utils.print_arch(self.model) - def build_losses(self): + def build_losses_and_metrics(self): raise NotImplementedError() + def register_metric(self, name: str, metric: Metric): + assert isinstance(metric, Metric) + setattr(self, name, metric) + self.valid_metric_names.add(name) + def run_model(self, sample, infer=False): """ steps: @@ -194,8 +200,8 @@ def training_step(self, sample, batch_idx, optimizer_idx=-1): self.log('lr', self.lr_schedulers().get_last_lr()[0], prog_bar=True, logger=False, on_step=True, on_epoch=False) # logs to tensorboard if self.global_step % hparams['log_interval'] == 0: - tb_log = {f'tr/{k}': v for k, v in log_outputs.items()} - tb_log['tr/lr'] = self.lr_schedulers().get_last_lr()[0] + tb_log = {f'training/{k}': v for k, v in log_outputs.items()} + tb_log['training/lr'] = self.lr_schedulers().get_last_lr()[0] self.logger.log_metrics(tb_log, step=self.global_step) return total_loss @@ -208,7 +214,7 @@ def _on_validation_start(self): def on_validation_start(self): self._on_validation_start() - for metric in self.valid_metrics.values(): + for metric in self.valid_losses.values(): metric.to(self.device) metric.reset() @@ -231,27 +237,31 @@ def validation_step(self, sample, batch_idx): rank_zero_debug(f"Skip validation {batch_idx}") return {} with torch.autocast(self.device.type, enabled=False): - outputs, weight = self._validation_step(sample, batch_idx) - outputs = { - 'total_loss': sum(outputs.values()), - **outputs + losses, weight = self._validation_step(sample, batch_idx) + losses = { + 'total_loss': sum(losses.values()), + **losses } - for k, v in outputs.items(): - if k not in self.valid_metrics: - self.valid_metrics[k] = MeanMetric().to(self.device) - self.valid_metrics[k].update(v, weight=weight) - return outputs + for k, v in losses.items(): + if k not in self.valid_losses: + self.valid_losses[k] = MeanMetric().to(self.device) + self.valid_losses[k].update(v, weight=weight) + return losses def on_validation_epoch_end(self): if self.skip_immediate_validation: self.skip_immediate_validation = False self.skip_immediate_ckpt_save = True return - metric_vals = {k: v.compute() for k, v in self.valid_metrics.items()} - self.log('val_loss', metric_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True) - self.logger.log_metrics({f'val/{k}': v for k, v in metric_vals.items()}, step=self.global_step) - for metric in self.valid_metrics.values(): + loss_vals = {k: v.compute() for k, v in self.valid_losses.items()} + self.log('val_loss', loss_vals['total_loss'], on_epoch=True, prog_bar=True, logger=False, sync_dist=True) + self.logger.log_metrics({f'validation/{k}': v for k, v in loss_vals.items()}, step=self.global_step) + for metric in self.valid_losses.values(): metric.reset() + metric_vals = {k: getattr(self, k).compute() for k in self.valid_metric_names} + self.logger.log_metrics({f'metrics/{k}': v for k, v in metric_vals.items()}, step=self.global_step) + for metric_name in self.valid_metric_names: + getattr(self, metric_name).reset() # noinspection PyMethodMayBeStatic def build_scheduler(self, optimizer): diff --git a/modules/metrics/rca.py b/modules/metrics/rca.py new file mode 100644 index 000000000..d565fd382 --- /dev/null +++ b/modules/metrics/rca.py @@ -0,0 +1,33 @@ +import torch +import torchmetrics +from torch import Tensor + + +class RawCurveAccuracy(torchmetrics.Metric): + def __init__(self, *, delta, **kwargs): + super().__init__(**kwargs) + self.delta = delta + self.add_state('close', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx='sum') + self.add_state('total', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx='sum') + + def update(self, pred: Tensor, target: Tensor, mask=None) -> None: + """ + + :param pred: predicted curve + :param target: reference curve + :param mask: valid or non-padding mask + """ + if mask is None: + assert pred.shape == target.shape, f'shapes of pred and target mismatch: {pred.shape}, {target.shape}' + else: + assert pred.shape == target.shape == mask.shape, \ + f'shapes of pred, target and mask mismatch: {pred.shape}, {target.shape}, {mask.shape}' + close = torch.abs(pred - target) < self.delta + if mask is not None: + close &= mask + + self.close += close.sum() + self.total += pred.numel() if mask is None else mask.sum() + + def compute(self) -> Tensor: + return self.close / self.total diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index c193f18e0..c9009ac2d 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -31,6 +31,7 @@ 'mel2ph', # mel2ph format representing number of frames within each phone, int64[T_s,] 'base_pitch', # interpolated and smoothed frame-level MIDI pitch, float32[T_s,] 'pitch', # actual pitch in semitones, float32[T_s,] + 'uv', # unvoiced masks (only for objective evaluation metrics), bool[T_s,] 'energy', # frame-level RMS (dB), float32[T_s,] 'breathiness', # frame-level RMS of aperiodic parts (dB), float32[T_s,] ] @@ -203,6 +204,7 @@ def process_item(self, item_name, meta_data, binarization_args): if hparams['predict_pitch'] or self.predict_variances: processed_input['pitch'] = pitch.cpu().numpy() + processed_input['uv'] = uv # Below: extract energy if hparams['predict_energy']: diff --git a/requirements.txt b/requirements.txt index 297c44405..e9012c128 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,5 @@ resampy scipy tensorboard tensorboardX +torchmetrics tqdm diff --git a/training/acoustic_task.py b/training/acoustic_task.py index deb603776..b0723912b 100644 --- a/training/acoustic_task.py +++ b/training/acoustic_task.py @@ -77,7 +77,7 @@ def build_model(self): ) # noinspection PyAttributeOutsideInit - def build_losses(self): + def build_losses_and_metrics(self): self.mel_loss = DiffusionNoiseLoss(loss_type=hparams['diff_loss_type']) def run_model(self, sample, infer=False): diff --git a/training/variance_task.py b/training/variance_task.py index ed91ffdf0..5a6673704 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -10,6 +10,7 @@ from basics.base_task import BaseTask from modules.losses.diff_loss import DiffusionNoiseLoss from modules.losses.dur_loss import DurationLoss +from modules.metrics.rca import RawCurveAccuracy from modules.toplevel import DiffSingerVariance from utils.hparams import hparams from utils.plot import dur_to_figure, curve_to_figure @@ -44,6 +45,7 @@ def collater(self, samples): if hparams['predict_pitch'] or self.predict_variances: batch['mel2ph'] = utils.collate_nd([s['mel2ph'] for s in samples], 0) batch['pitch'] = utils.collate_nd([s['pitch'] for s in samples], 0) + batch['uv'] = utils.collate_nd([s['uv'] for s in samples], True) if hparams['predict_energy']: batch['energy'] = utils.collate_nd([s['energy'] for s in samples], 0) if hparams['predict_breathiness']: @@ -92,7 +94,7 @@ def build_model(self): ) # noinspection PyAttributeOutsideInit - def build_losses(self): + def build_losses_and_metrics(self): if self.predict_dur: dur_hparams = hparams['dur_prediction_args'] self.dur_loss = DurationLoss( @@ -106,6 +108,7 @@ def build_losses(self): self.pitch_loss = DiffusionNoiseLoss( loss_type=hparams['diff_loss_type'], ) + self.register_metric('pitch_acc', RawCurveAccuracy(delta=0.5)) if self.predict_variances: self.var_loss = DiffusionNoiseLoss( loss_type=hparams['diff_loss_type'], @@ -178,10 +181,14 @@ def _validation_step(self, sample, batch_idx): self.plot_dur(batch_idx, sample['ph_dur'], dur_pred, txt=sample['tokens']) if pitch_pred is not None: base_pitch = sample['base_pitch'] + pred_pitch = base_pitch + pitch_pred + gt_pitch = sample['pitch'] + mask = (sample['mel2ph'] > 0) & ~sample['uv'] + self.pitch_acc.update(pred=pred_pitch, target=gt_pitch, mask=mask) self.plot_curve( batch_idx, - gt_curve=sample['pitch'], - pred_curve=base_pitch + pitch_pred, + gt_curve=gt_pitch, + pred_curve=pred_pitch, base_curve=base_pitch, curve_name='pitch', grid=1