diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index a54ca94a7..5c8322dcd 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -44,6 +44,14 @@ dur_prediction_args: lambda_wdur_loss: 1.0 lambda_sdur_loss: 3.0 +use_melody_encoder: false +melody_encoder_args: + hidden_size: 128 + enc_layers: 4 +use_glide_embed: false +glide_types: [up, down] +glide_embed_scale: 11.313708498984760 # sqrt(128) + pitch_prediction_args: pitd_norm_min: -8.0 pitd_norm_max: 8.0 diff --git a/configs/variance.yaml b/configs/variance.yaml index d437729d6..44c534820 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -53,6 +53,14 @@ dur_prediction_args: lambda_wdur_loss: 1.0 lambda_sdur_loss: 3.0 +use_melody_encoder: false +melody_encoder_args: + hidden_size: 128 + enc_layers: 4 +use_glide_embed: false +glide_types: [up, down] +glide_embed_scale: 11.313708498984760 # sqrt(128) + pitch_prediction_args: pitd_norm_min: -8.0 pitd_norm_max: 8.0 diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 0f1dc6fc5..f1e7ac65d 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -253,6 +253,8 @@ def _torch_export_model(self): ) if self.model.predict_pitch: + use_melody_encoder = hparams.get('use_melody_encoder', False) + use_glide_embed = use_melody_encoder and hparams['use_glide_embed'] # Prepare inputs for preprocessor of PitchDiffusion note_midi = torch.FloatTensor([[60.] * 4]).to(self.device) note_dur = torch.LongTensor([[2, 6, 3, 4]]).to(self.device) @@ -261,10 +263,12 @@ def _torch_export_model(self): pitch_input_args = ( encoder_out, ph_dur, - note_midi, - note_dur, - pitch, { + 'note_midi': note_midi, + **({'note_rest': note_midi >= 0} if use_melody_encoder else {}), + 'note_dur': note_dur, + **({'note_glide': torch.zeros_like(note_midi, dtype=torch.long)} if use_glide_embed else {}), + 'pitch': pitch, **({'expr': torch.ones_like(pitch)} if self.expose_expr else {}), 'retake': retake, **({'spk_embed': torch.rand( @@ -277,8 +281,10 @@ def _torch_export_model(self): pitch_input_args, self.pitch_preprocess_cache_path, input_names=[ - 'encoder_out', 'ph_dur', - 'note_midi', 'note_dur', + 'encoder_out', 'ph_dur', 'note_midi', + *(['note_rest'] if use_melody_encoder else []), + 'note_dur', + *(['note_glide'] if use_glide_embed else []), 'pitch', *(['expr'] if self.expose_expr else []), 'retake', @@ -297,13 +303,15 @@ def _torch_export_model(self): 'note_midi': { 1: 'n_notes' }, + **({'note_rest': {1: 'n_notes'}} if use_melody_encoder else {}), 'note_dur': { 1: 'n_notes' }, - **({'expr': {1: 'n_frames'}} if self.expose_expr else {}), + **({'note_glide': {1: 'n_notes'}} if use_glide_embed else {}), 'pitch': { 1: 'n_frames' }, + **({'expr': {1: 'n_frames'}} if self.expose_expr else {}), 'retake': { 1: 'n_frames' }, diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 47f734fc0..deb7aa84e 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -160,10 +160,18 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None): return x_cond def forward_pitch_preprocess( - self, encoder_out, ph_dur, note_midi, note_dur, + self, encoder_out, ph_dur, + note_midi=None, note_rest=None, note_dur=None, note_glide=None, pitch=None, expr=None, retake=None, spk_embed=None ): condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) + if self.use_melody_encoder: + melody_encoder_out = self.melody_encoder( + note_midi, note_rest, note_dur, + glide=note_glide + ) + melody_encoder_out = self.forward_mel2x_gather(melody_encoder_out, note_dur, x_dim=self.hidden_size) + condition += melody_encoder_out if expr is None: retake_embed = self.pitch_retake_embed(retake.long()) else: @@ -178,8 +186,12 @@ def forward_pitch_preprocess( pitch_cond = condition + retake_embed frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None) base_pitch = self.smooth(frame_midi_pitch) - base_pitch = base_pitch * retake + pitch * ~retake - pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) + if self.use_melody_encoder: + delta_pitch = (pitch - base_pitch) * ~retake + pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None]) + else: + base_pitch = base_pitch * retake + pitch * ~retake + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) if hparams['use_spk_id'] and spk_embed is not None: pitch_cond += spk_embed return pitch_cond, base_pitch @@ -229,6 +241,8 @@ def view_as_linguistic_encoder(self): model = copy.deepcopy(self) if self.predict_pitch: del model.pitch_predictor + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor model.fs2 = model.fs2.view_as_encoder() @@ -239,12 +253,14 @@ def view_as_linguistic_encoder(self): return model def view_as_dur_predictor(self): + assert self.predict_dur model = copy.deepcopy(self) if self.predict_pitch: del model.pitch_predictor + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor - assert self.predict_dur model.fs2 = model.fs2.view_as_dur_predictor() model.forward = model.forward_dur_predictor return model @@ -260,18 +276,22 @@ def view_as_pitch_preprocess(self): return model def view_as_pitch_diffusion(self): + assert self.predict_pitch model = copy.deepcopy(self) del model.fs2 del model.lr + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor - assert self.predict_pitch model.forward = model.forward_pitch_diffusion return model def view_as_pitch_postprocess(self): model = copy.deepcopy(self) del model.fs2 + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor model.forward = model.forward_pitch_postprocess @@ -282,18 +302,22 @@ def view_as_variance_preprocess(self): del model.fs2 if self.predict_pitch: del model.pitch_predictor + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor model.forward = model.forward_variance_preprocess return model def view_as_variance_diffusion(self): + assert self.predict_variances model = copy.deepcopy(self) del model.fs2 del model.lr if self.predict_pitch: del model.pitch_predictor - assert self.predict_variances + if self.use_melody_encoder: + del model.melody_encoder model.forward = model.forward_variance_diffusion return model @@ -302,5 +326,7 @@ def view_as_variance_postprocess(self): del model.fs2 if self.predict_pitch: del model.pitch_predictor + if self.use_melody_encoder: + del model.melody_encoder model.forward = model.forward_variance_postprocess return model diff --git a/inference/ds_variance.py b/inference/ds_variance.py index fdc999295..29c8a9bb2 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -58,6 +58,16 @@ def __init__( smooth_kernel /= smooth_kernel.sum() self.smooth.weight.data = smooth_kernel[None, None] + glide_types = hparams.get('glide_types', []) + assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.' + self.glide_map = { + 'none': 0, + **{ + typename: idx + 1 + for idx, typename in enumerate(glide_types) + } + } + self.auto_completion_mode = len(predictions) == 0 self.global_predict_dur = 'dur' in predictions and hparams['predict_dur'] self.global_predict_pitch = 'pitch' in predictions and hparams['predict_pitch'] @@ -98,6 +108,7 @@ def preprocess_input( note_seq = torch.FloatTensor( [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in param['note_seq'].split()] ).to(self.device)[None] # [B=1, T_n] + T_n = note_seq.shape[1] note_dur_sec = torch.from_numpy(np.array([param['note_dur'].split()], np.float32)).to(self.device) # [B=1, T_n] note_acc = torch.round(torch.cumsum(note_dur_sec, dim=1) / self.timestep + 0.5).long() note_dur = torch.diff(note_acc, dim=1, prepend=note_acc.new_zeros(1, 1)) @@ -105,7 +116,7 @@ def preprocess_input( T_s = mel2note.shape[1] summary['words'] = T_w - summary['notes'] = note_seq.shape[1] + summary['notes'] = T_n summary['tokens'] = T_ph summary['frames'] = T_s summary['seconds'] = '%.2f' % (T_s * self.timestep) @@ -156,6 +167,17 @@ def preprocess_input( word_dur = mel2ph_to_dur(mel2word, T_w) batch['word_dur'] = word_dur + batch['note_midi'] = note_seq + batch['note_dur'] = note_dur + batch['note_rest'] = note_seq < 0 + if hparams.get('use_glide_embed', False) and param.get('note_glide') is not None: + batch['note_glide'] = torch.LongTensor( + [[self.glide_map.get(x, 0) for x in param['note_glide'].split()]] + ).to(self.device) + else: + batch['note_glide'] = torch.zeros(1, T_n, dtype=torch.long, device=self.device) + batch['mel2note'] = mel2note + # Calculate frame-level MIDI pitch, which is a step function curve frame_midi_pitch = torch.gather( F.pad(note_seq, [1, 0]), 1, mel2note @@ -250,6 +272,11 @@ def forward_model(self, sample): word_dur = sample['word_dur'] ph_dur = sample['ph_dur'] mel2ph = sample['mel2ph'] + note_midi = sample['note_midi'] + note_rest = sample['note_rest'] + note_dur = sample['note_dur'] + note_glide = sample['note_glide'] + mel2note = sample['mel2note'] base_pitch = sample['base_pitch'] expr = sample.get('expr') pitch = sample.get('pitch') @@ -271,8 +298,9 @@ def forward_model(self, sample): ph_spk_mix_embed = spk_mix_embed = None dur_pred, pitch_pred, variance_pred = self.model( - txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, - mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, pitch_expr=expr, + txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph, + note_midi=note_midi, note_rest=note_rest, note_dur=note_dur, note_glide=note_glide, mel2note=mel2note, + base_pitch=base_pitch, pitch=pitch, pitch_expr=expr, ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed, infer=True ) diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index e979aaf54..8e2117f6b 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -86,3 +86,52 @@ def forward(self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, spk_emb return encoder_out, ph_dur_pred else: return encoder_out, None + + +class MelodyEncoder(nn.Module): + def __init__(self, enc_hparams: dict): + super().__init__() + + def get_hparam(key): + return enc_hparams.get(key, hparams.get(key)) + + # MIDI inputs + hidden_size = get_hparam('hidden_size') + self.note_midi_embed = Linear(1, hidden_size) + self.note_dur_embed = Linear(1, hidden_size) + + # ornament inputs + self.use_glide_embed = hparams['use_glide_embed'] + self.glide_embed_scale = hparams['glide_embed_scale'] + if self.use_glide_embed: + # 0: none, 1: up, 2: down + self.note_glide_embed = Embedding(len(hparams['glide_types']) + 1, hidden_size, padding_idx=0) + + self.encoder = FastSpeech2Encoder( + None, hidden_size, num_layers=get_hparam('enc_layers'), + ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), + ffn_padding=get_hparam('ffn_padding'), ffn_act=get_hparam('ffn_act'), + dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'), + use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos') + ) + self.out_proj = Linear(hidden_size, hparams['hidden_size']) + + def forward(self, note_midi, note_rest, note_dur, glide=None): + """ + :param note_midi: float32 [B, T_n], -1: padding + :param note_rest: bool [B, T_n] + :param note_dur: int64 [B, T_n] + :param glide: int64 [B, T_n] + :return: [B, T_n, H] + """ + midi_embed = self.note_midi_embed(note_midi[:, :, None]) * ~note_rest[:, :, None] + dur_embed = self.note_dur_embed(note_dur.float()[:, :, None]) + ornament_embed = 0 + if self.use_glide_embed: + ornament_embed += self.note_glide_embed(glide) * self.glide_embed_scale + encoder_out = self.encoder( + midi_embed, dur_embed + ornament_embed, + padding_mask=note_midi < 0 + ) + encoder_out = self.out_proj(encoder_out) + return encoder_out diff --git a/modules/toplevel.py b/modules/toplevel.py index a1a7647a2..97abeec38 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -17,7 +17,7 @@ from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic from modules.fastspeech.param_adaptor import ParameterAdaptorModule from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator -from modules.fastspeech.variance_encoder import FastSpeech2Variance +from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder from utils.hparams import hparams @@ -130,9 +130,15 @@ def __init__(self, vocab_size): self.lr = LengthRegulator() if self.predict_pitch: + self.use_melody_encoder = hparams.get('use_melody_encoder', False) + if self.use_melody_encoder: + self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args']) + self.delta_pitch_embed = Linear(1, hparams['hidden_size']) + else: + self.base_pitch_embed = Linear(1, hparams['hidden_size']) + self.pitch_retake_embed = Embedding(2, hparams['hidden_size']) pitch_hparams = hparams['pitch_prediction_args'] - self.base_pitch_embed = Linear(1, hparams['hidden_size']) self.pitch_predictor = PitchDiffusion( vmin=pitch_hparams['pitd_norm_min'], vmax=pitch_hparams['pitd_norm_max'], @@ -159,6 +165,7 @@ def __init__(self, vocab_size): def forward( self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None, + note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None, base_pitch=None, pitch=None, pitch_expr=None, pitch_retake=None, variance_retake: Dict[str, Tensor] = None, spk_id=None, infer=True, **kwargs @@ -196,10 +203,21 @@ def forward( condition += spk_embed if self.predict_pitch: - if pitch_retake is None: - pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool) + if self.use_melody_encoder: + melody_encoder_out = self.melody_encoder( + note_midi, note_rest, note_dur, + glide=note_glide + ) + melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0]) + mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']]) + melody_condition = torch.gather(melody_encoder_out, 1, mel2note_) + pitch_cond = condition + melody_condition else: - base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake + pitch_cond = condition + + retake_unset = pitch_retake is None + if retake_unset: + pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool) if pitch_expr is None: pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long()) @@ -213,8 +231,17 @@ def forward( pitch_expr = (pitch_expr * pitch_retake)[:, :, None] # [B, T, 1] pitch_retake_embed = pitch_expr * retake_true_embed + (1. - pitch_expr) * retake_false_embed - pitch_cond = condition + pitch_retake_embed - pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) + pitch_cond += pitch_retake_embed + if self.use_melody_encoder: + if retake_unset: # generate from scratch + delta_pitch_in = torch.zeros_like(base_pitch) + else: + delta_pitch_in = (pitch - base_pitch) * ~pitch_retake + pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None]) + else: + base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) + if infer: pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True) else: diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index f868ff362..16a672cb2 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from scipy import interpolate -from basics.base_binarizer import BaseBinarizer +from basics.base_binarizer import BaseBinarizer, BinarizationError from basics.base_pe import BasePE from modules.fastspeech.tts_modules import LengthRegulator from modules.pe import initialize_pe @@ -32,6 +32,11 @@ 'midi', # phoneme-level mean MIDI pitch, int64[T_ph,] 'ph2word', # similar to mel2ph format, representing number of phones within each note, int64[T_ph,] 'mel2ph', # mel2ph format representing number of frames within each phone, int64[T_s,] + 'note_midi', # note-level MIDI pitch, float32[T_n,] + 'note_rest', # flags for rest notes, bool[T_n,] + 'note_dur', # durations of notes, in number of frames, int64[T_n,] + 'note_glide', # flags for glides, 0 = none, 1 = up, 2 = down, int64[T_n,] + 'mel2note', # mel2ph format representing number of frames within each note, int64[T_s,] 'base_pitch', # interpolated and smoothed frame-level MIDI pitch, float32[T_s,] 'pitch', # actual pitch in semitones, float32[T_s,] 'uv', # unvoiced masks (only for objective evaluation metrics), bool[T_s,] @@ -52,6 +57,17 @@ class VarianceBinarizer(BaseBinarizer): def __init__(self): super().__init__(data_attrs=VARIANCE_ITEM_ATTRIBUTES) + self.use_glide_embed = hparams['use_glide_embed'] + glide_types = hparams['glide_types'] + assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.' + self.glide_map = { + 'none': 0, + **{ + typename: idx + 1 + for idx, typename in enumerate(glide_types) + } + } + predict_energy = hparams['predict_energy'] predict_breathiness = hparams['predict_breathiness'] self.predict_variances = predict_energy or predict_breathiness @@ -127,6 +143,8 @@ def require(attr): f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.' assert any([note != 'rest' for note in temp_dict['note_seq']]), \ f'All notes are rest in \'{item_name}\'.' + if hparams['use_glide_embed']: + temp_dict['note_glide'] = require('note_glide').split() meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict @@ -173,6 +191,35 @@ def check_coverage(self): pad_inches=0.25) print(f'| save summary to \'{filename}\'') + if self.use_glide_embed: + # Glide type distribution summary + glide_count = { + g: 0 + for g in self.glide_map + } + for item_name in self.items: + for glide in self.items[item_name]['note_glide']: + if glide == 'none' or glide not in self.glide_map: + glide_count['none'] += 1 + else: + glide_count[glide] += 1 + + print('===== Glide Type Distribution Summary =====') + for i, key in enumerate(sorted(glide_count.keys(), key=lambda k: self.glide_map[k])): + if i == len(glide_count) - 1: + end = '\n' + elif i % 10 == 9: + end = ',\n' + else: + end = ', ' + print(f'\'{key}\': {glide_count[key]}', end=end) + + if any(n == 0 for _, n in glide_count.items()): + raise BinarizationError( + f'Missing glide types in dataset: ' + f'{sorted([g for g, n in glide_count.items() if n == 0], key=lambda k: self.glide_map[k])}' + ) + @torch.no_grad() def process_item(self, item_name, meta_data, binarization_args): ds_id, name = item_name.split(':', maxsplit=1) @@ -244,25 +291,41 @@ def process_item(self, item_name, meta_data, binarization_args): processed_input['midi'] = ph_midi.round().long().clamp(min=0, max=127).cpu().numpy() if hparams['predict_pitch']: - # Below: calculate and interpolate frame-level MIDI pitch, which is a step function curve - note_dur = torch.FloatTensor(meta_data['note_dur']).to(self.device) - mel2note = get_mel2ph_torch( - self.lr, note_dur, mel2ph.shape[0], self.timestep, device=self.device + # Below: get note sequence and interpolate rest notes + note_midi = np.array( + [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in meta_data['note_seq']], + dtype=np.float32 ) - note_pitch = torch.FloatTensor( - [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in meta_data['note_seq']] - ).to(self.device) - frame_midi_pitch = torch.gather(F.pad(note_pitch, [1, 0], value=0), 0, mel2note) - rest = (frame_midi_pitch < 0).cpu().numpy() - frame_midi_pitch = frame_midi_pitch.cpu().numpy() + note_rest = note_midi < 0 interp_func = interpolate.interp1d( - np.where(~rest)[0], frame_midi_pitch[~rest], + np.where(~note_rest)[0], note_midi[~note_rest], kind='nearest', fill_value='extrapolate' ) - frame_midi_pitch[rest] = interp_func(np.where(rest)[0]) - frame_midi_pitch = torch.from_numpy(frame_midi_pitch).to(self.device) + note_midi[note_rest] = interp_func(np.where(note_rest)[0]) + processed_input['note_midi'] = note_midi + processed_input['note_rest'] = note_rest + note_midi = torch.from_numpy(note_midi).to(self.device) - # Below: smoothen the pitch step curve as the base pitch curve + note_dur_sec = torch.FloatTensor(meta_data['note_dur']).to(self.device) + note_acc = torch.round(torch.cumsum(note_dur_sec, dim=0) / self.timestep + 0.5).long() + note_dur = torch.diff(note_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device)) + processed_input['note_dur'] = note_dur.cpu().numpy() + + mel2note = get_mel2ph_torch( + self.lr, note_dur_sec, mel2ph.shape[0], self.timestep, device=self.device + ) + processed_input['mel2note'] = mel2note.cpu().numpy() + + # Below: get ornament attributes + if hparams['use_glide_embed']: + processed_input['note_glide'] = np.array([ + self.glide_map.get(x, 0) for x in meta_data['note_glide'] + ], dtype=np.int64) + + # Below: + # 1. Get the frame-level MIDI pitch, which is a step function curve + # 2. smoothen the pitch step curve as the base pitch curve + frame_midi_pitch = torch.gather(F.pad(note_midi, [1, 0], value=0), 0, mel2note) global midi_smooth if midi_smooth is None: midi_smooth = SinusoidalSmoothingConv1d( diff --git a/training/variance_task.py b/training/variance_task.py index e3d820078..d80c116fa 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -14,7 +14,7 @@ from modules.metrics.duration import RhythmCorrectness, PhonemeDurationAccuracy from modules.toplevel import DiffSingerVariance from utils.hparams import hparams -from utils.plot import dur_to_figure, curve_to_figure +from utils.plot import dur_to_figure, pitch_note_to_figure, curve_to_figure matplotlib.use('Agg') @@ -42,6 +42,14 @@ def collater(self, samples): batch['ph2word'] = utils.collate_nd([s['ph2word'] for s in samples], 0) batch['midi'] = utils.collate_nd([s['midi'] for s in samples], 0) if hparams['predict_pitch']: + if hparams['use_melody_encoder']: + batch['note_midi'] = utils.collate_nd([s['note_midi'] for s in samples], -1) + batch['note_rest'] = utils.collate_nd([s['note_rest'] for s in samples], True) + batch['note_dur'] = utils.collate_nd([s['note_dur'] for s in samples], 0) + if hparams['use_glide_embed']: + batch['note_glide'] = utils.collate_nd([s['note_glide'] for s in samples], 0) + batch['mel2note'] = utils.collate_nd([s['mel2note'] for s in samples], 0) + batch['base_pitch'] = utils.collate_nd([s['base_pitch'] for s in samples], 0) if hparams['predict_pitch'] or self.predict_variances: batch['mel2ph'] = utils.collate_nd([s['mel2ph'] for s in samples], 0) @@ -124,6 +132,13 @@ def run_model(self, sample, infer=False): ph2word = sample.get('ph2word') # [B, T_ph] midi = sample.get('midi') # [B, T_ph] mel2ph = sample.get('mel2ph') # [B, T_s] + + note_midi = sample.get('note_midi') # [B, T_n] + note_rest = sample.get('note_rest') # [B, T_n] + note_dur = sample.get('note_dur') # [B, T_n] + note_glide = sample.get('note_glide') # [B, T_n] + mel2note = sample.get('mel2note') # [B, T_s] + base_pitch = sample.get('base_pitch') # [B, T_s] pitch = sample.get('pitch') # [B, T_s] energy = sample.get('energy') # [B, T_s] @@ -146,6 +161,8 @@ def run_model(self, sample, infer=False): output = self.model( txt_tokens, midi=midi, ph2word=ph2word, ph_dur=ph_dur, mel2ph=mel2ph, + note_midi=note_midi, note_rest=note_rest, + note_dur=note_dur, note_glide=note_glide, mel2note=mel2note, base_pitch=base_pitch, pitch=pitch, energy=energy, breathiness=breathiness, pitch_retake=pitch_retake, variance_retake=variance_retake, @@ -193,18 +210,17 @@ def _validation_step(self, sample, batch_idx): ) self.plot_dur(batch_idx, dur_gt, dur_pred, txt=tokens) if pitch_pred is not None: - base_pitch = sample['base_pitch'] - pred_pitch = base_pitch + pitch_pred + pred_pitch = sample['base_pitch'] + pitch_pred gt_pitch = sample['pitch'] mask = (sample['mel2ph'] > 0) & ~sample['uv'] self.pitch_acc.update(pred=pred_pitch, target=gt_pitch, mask=mask) - self.plot_curve( + self.plot_pitch( batch_idx, - gt_curve=gt_pitch, - pred_curve=pred_pitch, - base_curve=base_pitch, - curve_name='pitch', - grid=1 + gt_pitch=gt_pitch, + pred_pitch=pred_pitch, + note_midi=sample['note_midi'], + note_dur=sample['note_dur'], + note_rest=sample['note_rest'] ) for name in self.variance_prediction_list: variance = sample[name] @@ -228,6 +244,17 @@ def plot_dur(self, batch_idx, gt_dur, pred_dur, txt=None): txt = self.phone_encoder.decode(txt[0].cpu().numpy()).split() self.logger.experiment.add_figure(name, dur_to_figure(gt_dur, pred_dur, txt), self.global_step) + def plot_pitch(self, batch_idx, gt_pitch, pred_pitch, note_midi, note_dur, note_rest): + name = f'pitch_{batch_idx}' + gt_pitch = gt_pitch[0].cpu().numpy() + pred_pitch = pred_pitch[0].cpu().numpy() + note_midi = note_midi[0].cpu().numpy() + note_dur = note_dur[0].cpu().numpy() + note_rest = note_rest[0].cpu().numpy() + self.logger.experiment.add_figure(name, pitch_note_to_figure( + gt_pitch, pred_pitch, note_midi, note_dur, note_rest + ), self.global_step) + def plot_curve(self, batch_idx, gt_curve, pred_curve, base_curve=None, grid=None, curve_name='curve'): name = f'{curve_name}_{batch_idx}' gt_curve = gt_curve[0].cpu().numpy() diff --git a/utils/plot.py b/utils/plot.py index 9a1971e76..d02d68653 100644 --- a/utils/plot.py +++ b/utils/plot.py @@ -3,8 +3,6 @@ import torch from matplotlib.ticker import MultipleLocator -LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime'] - def spec_to_figure(spec, vmin=None, vmax=None): if isinstance(spec, torch.Tensor): @@ -42,6 +40,43 @@ def dur_to_figure(dur_gt, dur_pred, txt): return fig +def pitch_note_to_figure(pitch_gt, pitch_pred=None, note_midi=None, note_dur=None, note_rest=None): + if isinstance(pitch_gt, torch.Tensor): + pitch_gt = pitch_gt.cpu().numpy() + if isinstance(pitch_pred, torch.Tensor): + pitch_pred = pitch_pred.cpu().numpy() + if isinstance(note_midi, torch.Tensor): + note_midi = note_midi.cpu().numpy() + if isinstance(note_dur, torch.Tensor): + note_dur = note_dur.cpu().numpy() + if isinstance(note_rest, torch.Tensor): + note_rest = note_rest.cpu().numpy() + fig = plt.figure() + if note_midi is not None and note_dur is not None: + note_dur_acc = np.cumsum(note_dur) + if note_rest is None: + note_rest = np.zeros_like(note_midi, dtype=np.bool_) + for i in range(len(note_midi)): + # if note_rest[i]: + # continue + plt.gca().add_patch( + plt.Rectangle( + xy=(note_dur_acc[i-1] if i > 0 else 0, note_midi[i] - 0.5), + width=note_dur[i], height=1, + edgecolor='grey', fill=False, + linewidth=1.5, linestyle='--' if note_rest[i] else '-' + ) + ) + plt.plot(pitch_gt, color='b', label='gt') + if pitch_pred is not None: + plt.plot(pitch_pred, color='r', label='pred') + plt.gca().yaxis.set_major_locator(MultipleLocator(1)) + plt.grid(axis='y') + plt.legend() + plt.tight_layout() + return fig + + def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None): if isinstance(curve_gt, torch.Tensor): curve_gt = curve_gt.cpu().numpy()