From 746c3ae72cc5427782653d41b8fc27fd65b3f851 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Sun, 3 Sep 2023 21:57:38 +0800 Subject: [PATCH 01/14] Add melody encoder and support glide notes --- configs/variance.yaml | 6 ++++ modules/fastspeech/variance_encoder.py | 48 +++++++++++++++++++++++++ modules/toplevel.py | 21 +++++++++-- preprocessing/variance_binarizer.py | 50 ++++++++++++++++++++------ training/variance_task.py | 17 +++++++++ 5 files changed, 130 insertions(+), 12 deletions(-) diff --git a/configs/variance.yaml b/configs/variance.yaml index d437729d6..8ae6d4586 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -53,6 +53,12 @@ dur_prediction_args: lambda_wdur_loss: 1.0 lambda_sdur_loss: 3.0 +use_melody_encoder: false +melody_encoder_args: + hidden_size: 128 + enc_layers: 4 +use_glide_embed: false + pitch_prediction_args: pitd_norm_min: -8.0 pitd_norm_max: 8.0 diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index e979aaf54..0293841d6 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -86,3 +86,51 @@ def forward(self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, spk_emb return encoder_out, ph_dur_pred else: return encoder_out, None + + +class MelodyEncoder(nn.Module): + def __init__(self, enc_hparams: dict): + super().__init__() + + def get_hparam(key): + return enc_hparams.get(key, hparams.get(key)) + + # MIDI inputs + hidden_size = get_hparam('hidden_size') + self.note_midi_embed = Linear(1, hidden_size) + self.note_dur_embed = Linear(1, hidden_size) + + # ornament inputs + self.use_glide_embed = hparams['use_glide_embed'] + if self.use_glide_embed: + # 0: none, 1: up, 2: down + self.note_glide_embed = Embedding(3, hidden_size, padding_idx=0) + + self.encoder = FastSpeech2Encoder( + None, hidden_size, num_layers=get_hparam('enc_layers'), + ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), + ffn_padding=get_hparam('ffn_padding'), ffn_act=get_hparam('ffn_act'), + dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'), + use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos') + ) + self.out_proj = Linear(hidden_size, hparams['hidden_size']) + + def forward(self, note_midi, note_rest, note_dur, glide=None): + """ + :param note_midi: float32 [B, T_n], -1: padding + :param note_rest: bool [B, T_n] + :param note_dur: int64 [B, T_n] + :param glide: int64 [B, T_n] + :return: [B, T_n, H] + """ + midi_embed = self.note_midi_embed(note_midi[:, :, None]) * ~note_rest[:, :, None] + dur_embed = self.note_dur_embed(note_dur.float()[:, :, None]) + ornament_embed = 0 + if self.use_glide_embed: + ornament_embed += self.note_glide_embed(glide) + encoder_out = self.encoder( + midi_embed, dur_embed + ornament_embed, + padding_mask=note_midi < 0 + ) + encoder_out = self.out_proj(encoder_out) + return encoder_out diff --git a/modules/toplevel.py b/modules/toplevel.py index cb63e65e4..8a420dd88 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -16,7 +16,7 @@ from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic from modules.fastspeech.param_adaptor import ParameterAdaptorModule from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator -from modules.fastspeech.variance_encoder import FastSpeech2Variance +from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder from utils.hparams import hparams @@ -85,6 +85,10 @@ def __init__(self, vocab_size): self.lr = LengthRegulator() if 
self.predict_pitch: + self.use_melody_encoder = hparams['use_melody_encoder'] + if self.use_melody_encoder: + self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args']) + self.pitch_retake_embed = Embedding(2, hparams['hidden_size']) pitch_hparams = hparams['pitch_prediction_args'] self.base_pitch_embed = Linear(1, hparams['hidden_size']) @@ -114,6 +118,7 @@ def __init__(self, vocab_size): def forward( self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None, + note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None, base_pitch=None, pitch=None, pitch_expr=None, pitch_retake=None, variance_retake: Dict[str, Tensor] = None, spk_id=None, infer=True, **kwargs @@ -151,6 +156,18 @@ def forward( condition += spk_embed if self.predict_pitch: + if self.use_melody_encoder: + melody_encoder_out = self.melody_encoder( + note_midi, note_rest, note_dur, + glide=note_glide + ) + melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0]) + mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']]) + melody_condition = torch.gather(melody_encoder_out, 1, mel2note_) + pitch_cond = condition + melody_condition + else: + pitch_cond = condition + if pitch_retake is None: pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool) else: @@ -168,7 +185,7 @@ def forward( pitch_expr = (pitch_expr * pitch_retake)[:, :, None] # [B, T, 1] pitch_retake_embed = pitch_expr * retake_true_embed + (1. - pitch_expr) * retake_false_embed - pitch_cond = condition + pitch_retake_embed + pitch_cond += pitch_retake_embed pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) if infer: pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True) diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index f868ff362..0c265a599 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -32,6 +32,11 @@ 'midi', # phoneme-level mean MIDI pitch, int64[T_ph,] 'ph2word', # similar to mel2ph format, representing number of phones within each note, int64[T_ph,] 'mel2ph', # mel2ph format representing number of frames within each phone, int64[T_s,] + 'note_midi', # note-level MIDI pitch, float32[T_n,] + 'note_rest', # flags for rest notes, bool[T_n,] + 'note_dur', # durations of notes, in number of frames, int64[T_n,] + 'note_glide', # flags for glides, 0 = none, 1 = up, 2 = down, int64[T_n,] + 'mel2note', # mel2ph format representing number of frames within each note, int64[T_s,] 'base_pitch', # interpolated and smoothed frame-level MIDI pitch, float32[T_s,] 'pitch', # actual pitch in semitones, float32[T_s,] 'uv', # unvoiced masks (only for objective evaluation metrics), bool[T_s,] @@ -127,6 +132,9 @@ def require(attr): f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.' assert any([note != 'rest' for note in temp_dict['note_seq']]), \ f'All notes are rest in \'{item_name}\'.' 
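            # A worked example of the parsing added just below, assuming a .ds file
            # whose note_glide attribute reads 'none up none down' for four notes:
            # the hard-coded map turns it into [0, 1, 0, 2], and id 0 doubles as
            # the padding_idx of the glide Embedding in MelodyEncoder.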
+ if hparams['use_glide_embed']: + glide_map = {'none': 0, 'up': 1, 'down': 2} + temp_dict['note_glide'] = [glide_map[x] for x in require('note_glide').split()] meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict @@ -244,22 +252,44 @@ def process_item(self, item_name, meta_data, binarization_args): processed_input['midi'] = ph_midi.round().long().clamp(min=0, max=127).cpu().numpy() if hparams['predict_pitch']: - # Below: calculate and interpolate frame-level MIDI pitch, which is a step function curve - note_dur = torch.FloatTensor(meta_data['note_dur']).to(self.device) + # Below: get note sequence and interpolate rest notes + note_midi = np.array( + [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in meta_data['note_seq']], + dtype=np.float32 + ) + note_rest = note_midi < 0 + interp_func = interpolate.interp1d( + np.where(~note_rest)[0], note_midi[~note_rest], + kind='nearest', fill_value='extrapolate' + ) + note_midi[note_rest] = interp_func(np.where(note_rest)[0]) + processed_input['note_midi'] = note_midi + processed_input['note_rest'] = note_rest + note_midi = torch.from_numpy(note_midi).to(self.device) + + note_dur_sec = torch.FloatTensor(meta_data['note_dur']).to(self.device) + note_acc = torch.round(torch.cumsum(note_dur_sec, dim=0) / self.timestep + 0.5).long() + note_dur = torch.diff(note_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device)) + processed_input['note_dur'] = note_dur.cpu().numpy() + mel2note = get_mel2ph_torch( - self.lr, note_dur, mel2ph.shape[0], self.timestep, device=self.device + self.lr, note_dur_sec, mel2ph.shape[0], self.timestep, device=self.device ) - note_pitch = torch.FloatTensor( - [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in meta_data['note_seq']] - ).to(self.device) - frame_midi_pitch = torch.gather(F.pad(note_pitch, [1, 0], value=0), 0, mel2note) - rest = (frame_midi_pitch < 0).cpu().numpy() + processed_input['mel2note'] = mel2note.cpu().numpy() + + # Below: get ornament attributes + if hparams['use_glide_embed']: + processed_input['note_glide'] = np.array(meta_data['note_glide'], dtype=np.int64) + + # Below: calculate and interpolate frame-level MIDI pitch, which is a step function curve + frame_midi_pitch = torch.gather(F.pad(note_midi, [1, 0], value=0), 0, mel2note) + frame_rest = (frame_midi_pitch < 0).cpu().numpy() frame_midi_pitch = frame_midi_pitch.cpu().numpy() interp_func = interpolate.interp1d( - np.where(~rest)[0], frame_midi_pitch[~rest], + np.where(~frame_rest)[0], frame_midi_pitch[~frame_rest], kind='nearest', fill_value='extrapolate' ) - frame_midi_pitch[rest] = interp_func(np.where(rest)[0]) + frame_midi_pitch[frame_rest] = interp_func(np.where(frame_rest)[0]) frame_midi_pitch = torch.from_numpy(frame_midi_pitch).to(self.device) # Below: smoothen the pitch step curve as the base pitch curve diff --git a/training/variance_task.py b/training/variance_task.py index e3d820078..93221d934 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -42,6 +42,14 @@ def collater(self, samples): batch['ph2word'] = utils.collate_nd([s['ph2word'] for s in samples], 0) batch['midi'] = utils.collate_nd([s['midi'] for s in samples], 0) if hparams['predict_pitch']: + if hparams['use_melody_encoder']: + batch['note_midi'] = utils.collate_nd([s['note_midi'] for s in samples], -1) + batch['note_rest'] = utils.collate_nd([s['note_rest'] for s in samples], True) + batch['note_dur'] = utils.collate_nd([s['note_dur'] for s in samples], 0) + if hparams['use_glide_embed']: + 
batch['note_glide'] = utils.collate_nd([s['note_glide'] for s in samples], 0) + batch['mel2note'] = utils.collate_nd([s['mel2note'] for s in samples], 0) + batch['base_pitch'] = utils.collate_nd([s['base_pitch'] for s in samples], 0) if hparams['predict_pitch'] or self.predict_variances: batch['mel2ph'] = utils.collate_nd([s['mel2ph'] for s in samples], 0) @@ -124,6 +132,13 @@ def run_model(self, sample, infer=False): ph2word = sample.get('ph2word') # [B, T_ph] midi = sample.get('midi') # [B, T_ph] mel2ph = sample.get('mel2ph') # [B, T_s] + + note_midi = sample.get('note_midi') # [B, T_n] + note_rest = sample.get('note_rest') # [B, T_n] + note_dur = sample.get('note_dur') # [B, T_n] + note_glide = sample.get('note_glide') # [B, T_n] + mel2note = sample.get('mel2note') # [B, T_s] + base_pitch = sample.get('base_pitch') # [B, T_s] pitch = sample.get('pitch') # [B, T_s] energy = sample.get('energy') # [B, T_s] @@ -146,6 +161,8 @@ def run_model(self, sample, infer=False): output = self.model( txt_tokens, midi=midi, ph2word=ph2word, ph_dur=ph_dur, mel2ph=mel2ph, + note_midi=note_midi, note_rest=note_rest, + note_dur=note_dur, note_glide=note_glide, mel2note=mel2note, base_pitch=base_pitch, pitch=pitch, energy=energy, breathiness=breathiness, pitch_retake=pitch_retake, variance_retake=variance_retake, From 69fd654ad597d607f8f4ad37c9007420e7e21d10 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Mon, 4 Sep 2023 14:34:02 +0800 Subject: [PATCH 02/14] Support melody encoder inference --- inference/ds_variance.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/inference/ds_variance.py b/inference/ds_variance.py index fdc999295..5bc41c3d0 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -98,6 +98,7 @@ def preprocess_input( note_seq = torch.FloatTensor( [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in param['note_seq'].split()] ).to(self.device)[None] # [B=1, T_n] + T_n = note_seq.shape[1] note_dur_sec = torch.from_numpy(np.array([param['note_dur'].split()], np.float32)).to(self.device) # [B=1, T_n] note_acc = torch.round(torch.cumsum(note_dur_sec, dim=1) / self.timestep + 0.5).long() note_dur = torch.diff(note_acc, dim=1, prepend=note_acc.new_zeros(1, 1)) @@ -105,7 +106,7 @@ def preprocess_input( T_s = mel2note.shape[1] summary['words'] = T_w - summary['notes'] = note_seq.shape[1] + summary['notes'] = T_n summary['tokens'] = T_ph summary['frames'] = T_s summary['seconds'] = '%.2f' % (T_s * self.timestep) @@ -156,6 +157,18 @@ def preprocess_input( word_dur = mel2ph_to_dur(mel2word, T_w) batch['word_dur'] = word_dur + batch['note_midi'] = note_seq + batch['note_dur'] = note_dur + batch['note_rest'] = note_seq < 0 + if hparams['use_glide_embed'] and param.get('note_glide') is not None: + glide_map = {'none': 0, 'up': 1, 'down': 2} + batch['note_glide'] = torch.LongTensor( + [[glide_map[x] for x in param['note_glide'].split()]] + ).to(self.device) + else: + batch['note_glide'] = torch.zeros(1, T_n, dtype=torch.long, device=self.device) + batch['mel2note'] = mel2note + # Calculate frame-level MIDI pitch, which is a step function curve frame_midi_pitch = torch.gather( F.pad(note_seq, [1, 0]), 1, mel2note @@ -250,6 +263,11 @@ def forward_model(self, sample): word_dur = sample['word_dur'] ph_dur = sample['ph_dur'] mel2ph = sample['mel2ph'] + note_midi = sample['note_midi'] + note_rest = sample['note_rest'] + note_dur = sample['note_dur'] + note_glide = sample['note_glide'] + mel2note = sample['mel2note'] base_pitch = 
sample['base_pitch'] expr = sample.get('expr') pitch = sample.get('pitch') @@ -271,8 +289,9 @@ def forward_model(self, sample): ph_spk_mix_embed = spk_mix_embed = None dur_pred, pitch_pred, variance_pred = self.model( - txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, - mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, pitch_expr=expr, + txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph, + note_midi=note_midi, note_rest=note_rest, note_dur=note_dur, note_glide=note_glide, mel2note=mel2note, + base_pitch=base_pitch, pitch=pitch, pitch_expr=expr, ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed, infer=True ) From a922d750208155167d4667006e5953d2ae0def0d Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Mon, 11 Sep 2023 16:41:48 +0800 Subject: [PATCH 03/14] Add note visualization on TensorBoard --- training/variance_task.py | 28 +++++++++++++++++++--------- utils/plot.py | 39 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/training/variance_task.py b/training/variance_task.py index 93221d934..d80c116fa 100644 --- a/training/variance_task.py +++ b/training/variance_task.py @@ -14,7 +14,7 @@ from modules.metrics.duration import RhythmCorrectness, PhonemeDurationAccuracy from modules.toplevel import DiffSingerVariance from utils.hparams import hparams -from utils.plot import dur_to_figure, curve_to_figure +from utils.plot import dur_to_figure, pitch_note_to_figure, curve_to_figure matplotlib.use('Agg') @@ -210,18 +210,17 @@ def _validation_step(self, sample, batch_idx): ) self.plot_dur(batch_idx, dur_gt, dur_pred, txt=tokens) if pitch_pred is not None: - base_pitch = sample['base_pitch'] - pred_pitch = base_pitch + pitch_pred + pred_pitch = sample['base_pitch'] + pitch_pred gt_pitch = sample['pitch'] mask = (sample['mel2ph'] > 0) & ~sample['uv'] self.pitch_acc.update(pred=pred_pitch, target=gt_pitch, mask=mask) - self.plot_curve( + self.plot_pitch( batch_idx, - gt_curve=gt_pitch, - pred_curve=pred_pitch, - base_curve=base_pitch, - curve_name='pitch', - grid=1 + gt_pitch=gt_pitch, + pred_pitch=pred_pitch, + note_midi=sample['note_midi'], + note_dur=sample['note_dur'], + note_rest=sample['note_rest'] ) for name in self.variance_prediction_list: variance = sample[name] @@ -245,6 +244,17 @@ def plot_dur(self, batch_idx, gt_dur, pred_dur, txt=None): txt = self.phone_encoder.decode(txt[0].cpu().numpy()).split() self.logger.experiment.add_figure(name, dur_to_figure(gt_dur, pred_dur, txt), self.global_step) + def plot_pitch(self, batch_idx, gt_pitch, pred_pitch, note_midi, note_dur, note_rest): + name = f'pitch_{batch_idx}' + gt_pitch = gt_pitch[0].cpu().numpy() + pred_pitch = pred_pitch[0].cpu().numpy() + note_midi = note_midi[0].cpu().numpy() + note_dur = note_dur[0].cpu().numpy() + note_rest = note_rest[0].cpu().numpy() + self.logger.experiment.add_figure(name, pitch_note_to_figure( + gt_pitch, pred_pitch, note_midi, note_dur, note_rest + ), self.global_step) + def plot_curve(self, batch_idx, gt_curve, pred_curve, base_curve=None, grid=None, curve_name='curve'): name = f'{curve_name}_{batch_idx}' gt_curve = gt_curve[0].cpu().numpy() diff --git a/utils/plot.py b/utils/plot.py index 9a1971e76..d02d68653 100644 --- a/utils/plot.py +++ b/utils/plot.py @@ -3,8 +3,6 @@ import torch from matplotlib.ticker import MultipleLocator -LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime'] - def spec_to_figure(spec, vmin=None, vmax=None): if isinstance(spec, torch.Tensor): @@ -42,6 
+40,43 @@ def dur_to_figure(dur_gt, dur_pred, txt): return fig +def pitch_note_to_figure(pitch_gt, pitch_pred=None, note_midi=None, note_dur=None, note_rest=None): + if isinstance(pitch_gt, torch.Tensor): + pitch_gt = pitch_gt.cpu().numpy() + if isinstance(pitch_pred, torch.Tensor): + pitch_pred = pitch_pred.cpu().numpy() + if isinstance(note_midi, torch.Tensor): + note_midi = note_midi.cpu().numpy() + if isinstance(note_dur, torch.Tensor): + note_dur = note_dur.cpu().numpy() + if isinstance(note_rest, torch.Tensor): + note_rest = note_rest.cpu().numpy() + fig = plt.figure() + if note_midi is not None and note_dur is not None: + note_dur_acc = np.cumsum(note_dur) + if note_rest is None: + note_rest = np.zeros_like(note_midi, dtype=np.bool_) + for i in range(len(note_midi)): + # if note_rest[i]: + # continue + plt.gca().add_patch( + plt.Rectangle( + xy=(note_dur_acc[i-1] if i > 0 else 0, note_midi[i] - 0.5), + width=note_dur[i], height=1, + edgecolor='grey', fill=False, + linewidth=1.5, linestyle='--' if note_rest[i] else '-' + ) + ) + plt.plot(pitch_gt, color='b', label='gt') + if pitch_pred is not None: + plt.plot(pitch_pred, color='r', label='pred') + plt.gca().yaxis.set_major_locator(MultipleLocator(1)) + plt.grid(axis='y') + plt.legend() + plt.tight_layout() + return fig + + def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None): if isinstance(curve_gt, torch.Tensor): curve_gt = curve_gt.cpu().numpy() From 9506b7aafc29e8cb5495a8f100ff701ff6e58712 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Thu, 14 Sep 2023 21:34:31 +0800 Subject: [PATCH 04/14] Remove base pitch embedding when using melody encoder --- modules/toplevel.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/modules/toplevel.py b/modules/toplevel.py index 8a420dd88..74173b285 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -88,10 +88,12 @@ def __init__(self, vocab_size): self.use_melody_encoder = hparams['use_melody_encoder'] if self.use_melody_encoder: self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args']) + self.delta_pitch_embed = Linear(1, hparams['hidden_size']) + else: + self.base_pitch_embed = Linear(1, hparams['hidden_size']) self.pitch_retake_embed = Embedding(2, hparams['hidden_size']) pitch_hparams = hparams['pitch_prediction_args'] - self.base_pitch_embed = Linear(1, hparams['hidden_size']) self.pitch_predictor = PitchDiffusion( vmin=pitch_hparams['pitd_norm_min'], vmax=pitch_hparams['pitd_norm_max'], @@ -170,8 +172,6 @@ def forward( if pitch_retake is None: pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool) - else: - base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake if pitch_expr is None: pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long()) @@ -186,7 +186,13 @@ def forward( pitch_retake_embed = pitch_expr * retake_true_embed + (1. 
- pitch_expr) * retake_false_embed pitch_cond += pitch_retake_embed - pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) + if self.use_melody_encoder: + delta_pitch_in = (pitch - base_pitch) * ~pitch_retake + pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None]) + else: + base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) + if infer: pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True) else: From fc1ca74ebb08ac48ff5273696d3a6b5e5a1c4213 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Thu, 14 Sep 2023 22:38:48 +0800 Subject: [PATCH 05/14] Fix KeyError `use_glide_embed` --- inference/ds_variance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference/ds_variance.py b/inference/ds_variance.py index 5bc41c3d0..421ada0d6 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -160,7 +160,7 @@ def preprocess_input( batch['note_midi'] = note_seq batch['note_dur'] = note_dur batch['note_rest'] = note_seq < 0 - if hparams['use_glide_embed'] and param.get('note_glide') is not None: + if hparams.get('use_glide_embed', False) and param.get('note_glide') is not None: glide_map = {'none': 0, 'up': 1, 'down': 2} batch['note_glide'] = torch.LongTensor( [[glide_map[x] for x in param['note_glide'].split()]] From 015bb830f538c6bdc0a643df0fcb6e92460e88e2 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 15 Sep 2023 00:57:12 +0800 Subject: [PATCH 06/14] Fix invalid access to NoneType --- modules/toplevel.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/toplevel.py b/modules/toplevel.py index 74173b285..35591c35e 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -170,7 +170,8 @@ def forward( else: pitch_cond = condition - if pitch_retake is None: + retake_unset = pitch_retake is None + if retake_unset: pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool) if pitch_expr is None: @@ -187,7 +188,10 @@ def forward( pitch_cond += pitch_retake_embed if self.use_melody_encoder: - delta_pitch_in = (pitch - base_pitch) * ~pitch_retake + if retake_unset: # generate from scratch + delta_pitch_in = torch.zeros_like(base_pitch) + else: + delta_pitch_in = (pitch - base_pitch) * ~pitch_retake pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None]) else: base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake From 36253258b23136802f994562ebc0a76dcbb4048d Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 15 Sep 2023 12:45:10 +0800 Subject: [PATCH 07/14] Support melody encoder ONNX export --- deployment/exporters/variance_exporter.py | 16 ++++++---- deployment/modules/toplevel.py | 36 +++++++++++++++++++---- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 0f1dc6fc5..f73156063 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -253,18 +253,21 @@ def _torch_export_model(self): ) if self.model.predict_pitch: + use_melody_encoder = hparams.get('use_melody_encoder', False) # Prepare inputs for preprocessor of PitchDiffusion note_midi = torch.FloatTensor([[60.] * 4]).to(self.device) + note_rest = note_midi >= 0 note_dur = torch.LongTensor([[2, 6, 3, 4]]).to(self.device) pitch = torch.FloatTensor([[60.] 
* 15]).to(self.device) retake = torch.ones_like(pitch, dtype=torch.bool) pitch_input_args = ( encoder_out, ph_dur, - note_midi, - note_dur, - pitch, { + 'note_midi': note_midi, + **({'note_rest': note_midi >= 0} if use_melody_encoder else {}), + 'note_dur': note_dur, + 'pitch': pitch, **({'expr': torch.ones_like(pitch)} if self.expose_expr else {}), 'retake': retake, **({'spk_embed': torch.rand( @@ -277,9 +280,9 @@ def _torch_export_model(self): pitch_input_args, self.pitch_preprocess_cache_path, input_names=[ - 'encoder_out', 'ph_dur', - 'note_midi', 'note_dur', - 'pitch', + 'encoder_out', 'ph_dur', 'note_midi', + *(['note_rest'] if use_melody_encoder else []), + 'note_dur', 'pitch', *(['expr'] if self.expose_expr else []), 'retake', *(['spk_embed'] if input_spk_embed else []) @@ -297,6 +300,7 @@ def _torch_export_model(self): 'note_midi': { 1: 'n_notes' }, + **({'note_rest': {1: 'n_notes'}} if use_melody_encoder else {}), 'note_dur': { 1: 'n_notes' }, diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 2cbbda8fb..9b63a39c7 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -161,10 +161,16 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None): return x_cond def forward_pitch_preprocess( - self, encoder_out, ph_dur, note_midi, note_dur, + self, encoder_out, ph_dur, note_midi=None, note_rest=None, note_dur=None, pitch=None, expr=None, retake=None, spk_embed=None ): condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) + if self.use_melody_encoder: + melody_encoder_out = self.melody_encoder( + note_midi, note_rest, note_dur + ) + melody_encoder_out = self.forward_mel2x_gather(melody_encoder_out, note_dur, x_dim=self.hidden_size) + condition += melody_encoder_out if expr is None: retake_embed = self.pitch_retake_embed(retake.long()) else: @@ -179,8 +185,12 @@ def forward_pitch_preprocess( pitch_cond = condition + retake_embed frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None) base_pitch = self.smooth(frame_midi_pitch) - base_pitch = base_pitch * retake + pitch * ~retake - pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) + if self.use_melody_encoder: + delta_pitch = (pitch - base_pitch) * ~retake + pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None]) + else: + base_pitch = base_pitch * retake + pitch * ~retake + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) if hparams['use_spk_id'] and spk_embed is not None: pitch_cond += spk_embed return pitch_cond, base_pitch @@ -230,6 +240,8 @@ def view_as_linguistic_encoder(self): model = copy.deepcopy(self) if self.predict_pitch: del model.pitch_predictor + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor model.fs2 = model.fs2.view_as_encoder() @@ -240,12 +252,14 @@ def view_as_linguistic_encoder(self): return model def view_as_dur_predictor(self): + assert self.predict_dur model = copy.deepcopy(self) if self.predict_pitch: del model.pitch_predictor + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor - assert self.predict_dur model.fs2 = model.fs2.view_as_dur_predictor() model.forward = model.forward_dur_predictor return model @@ -261,18 +275,22 @@ def view_as_pitch_preprocess(self): return model def view_as_pitch_diffusion(self): + assert self.predict_pitch model = copy.deepcopy(self) del model.fs2 del model.lr + if self.use_melody_encoder: + del model.melody_encoder if 
self.predict_variances: del model.variance_predictor - assert self.predict_pitch model.forward = model.forward_pitch_diffusion return model def view_as_pitch_postprocess(self): model = copy.deepcopy(self) del model.fs2 + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor model.forward = model.forward_pitch_postprocess @@ -283,18 +301,22 @@ def view_as_variance_preprocess(self): del model.fs2 if self.predict_pitch: del model.pitch_predictor + if self.use_melody_encoder: + del model.melody_encoder if self.predict_variances: del model.variance_predictor model.forward = model.forward_variance_preprocess return model def view_as_variance_diffusion(self): + assert self.predict_variances model = copy.deepcopy(self) del model.fs2 del model.lr if self.predict_pitch: del model.pitch_predictor - assert self.predict_variances + if self.use_melody_encoder: + del model.melody_encoder model.forward = model.forward_variance_diffusion return model @@ -303,5 +325,7 @@ def view_as_variance_postprocess(self): del model.fs2 if self.predict_pitch: del model.pitch_predictor + if self.use_melody_encoder: + del model.melody_encoder model.forward = model.forward_variance_postprocess return model From ed2c8077dfd13ed6f5eaac13dafc52eda21305e7 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 15 Sep 2023 12:55:19 +0800 Subject: [PATCH 08/14] Fix KeyError `use_melody_encoder` --- modules/toplevel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/toplevel.py b/modules/toplevel.py index 35591c35e..4b472c49e 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -85,7 +85,7 @@ def __init__(self, vocab_size): self.lr = LengthRegulator() if self.predict_pitch: - self.use_melody_encoder = hparams['use_melody_encoder'] + self.use_melody_encoder = hparams.get('use_melody_encoder', False) if self.use_melody_encoder: self.melody_encoder = MelodyEncoder(enc_hparams=hparams['melody_encoder_args']) self.delta_pitch_embed = Linear(1, hparams['hidden_size']) From 3e2e908c0cff609f8b79295deab1f41647a48a06 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Thu, 5 Oct 2023 02:13:50 +0800 Subject: [PATCH 09/14] Support note glide ONNX export --- deployment/exporters/variance_exporter.py | 10 +++++++--- deployment/modules/toplevel.py | 6 ++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index f73156063..f1e7ac65d 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -254,9 +254,9 @@ def _torch_export_model(self): if self.model.predict_pitch: use_melody_encoder = hparams.get('use_melody_encoder', False) + use_glide_embed = use_melody_encoder and hparams['use_glide_embed'] # Prepare inputs for preprocessor of PitchDiffusion note_midi = torch.FloatTensor([[60.] * 4]).to(self.device) - note_rest = note_midi >= 0 note_dur = torch.LongTensor([[2, 6, 3, 4]]).to(self.device) pitch = torch.FloatTensor([[60.] 
* 15]).to(self.device) retake = torch.ones_like(pitch, dtype=torch.bool) @@ -267,6 +267,7 @@ def _torch_export_model(self): 'note_midi': note_midi, **({'note_rest': note_midi >= 0} if use_melody_encoder else {}), 'note_dur': note_dur, + **({'note_glide': torch.zeros_like(note_midi, dtype=torch.long)} if use_glide_embed else {}), 'pitch': pitch, **({'expr': torch.ones_like(pitch)} if self.expose_expr else {}), 'retake': retake, @@ -282,7 +283,9 @@ def _torch_export_model(self): input_names=[ 'encoder_out', 'ph_dur', 'note_midi', *(['note_rest'] if use_melody_encoder else []), - 'note_dur', 'pitch', + 'note_dur', + *(['note_glide'] if use_glide_embed else []), + 'pitch', *(['expr'] if self.expose_expr else []), 'retake', *(['spk_embed'] if input_spk_embed else []) @@ -304,10 +307,11 @@ def _torch_export_model(self): 'note_dur': { 1: 'n_notes' }, - **({'expr': {1: 'n_frames'}} if self.expose_expr else {}), + **({'note_glide': {1: 'n_notes'}} if use_glide_embed else {}), 'pitch': { 1: 'n_frames' }, + **({'expr': {1: 'n_frames'}} if self.expose_expr else {}), 'retake': { 1: 'n_frames' }, diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 06db05b63..deb7aa84e 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -160,13 +160,15 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None): return x_cond def forward_pitch_preprocess( - self, encoder_out, ph_dur, note_midi=None, note_rest=None, note_dur=None, + self, encoder_out, ph_dur, + note_midi=None, note_rest=None, note_dur=None, note_glide=None, pitch=None, expr=None, retake=None, spk_embed=None ): condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) if self.use_melody_encoder: melody_encoder_out = self.melody_encoder( - note_midi, note_rest, note_dur + note_midi, note_rest, note_dur, + glide=note_glide ) melody_encoder_out = self.forward_mel2x_gather(melody_encoder_out, note_dur, x_dim=self.hidden_size) condition += melody_encoder_out From beb33a50a537a9aa0170e5bec482941f04ba5d31 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 6 Oct 2023 01:37:30 +0800 Subject: [PATCH 10/14] Support custom glide types --- configs/variance.yaml | 1 + inference/ds_variance.py | 9 +++++++-- modules/fastspeech/variance_encoder.py | 2 +- preprocessing/variance_binarizer.py | 10 ++++++++-- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/configs/variance.yaml b/configs/variance.yaml index 8ae6d4586..0da23d23c 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -58,6 +58,7 @@ melody_encoder_args: hidden_size: 128 enc_layers: 4 use_glide_embed: false +glide_types: [none, up, down] pitch_prediction_args: pitd_norm_min: -8.0 diff --git a/inference/ds_variance.py b/inference/ds_variance.py index 421ada0d6..d56926555 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -58,6 +58,12 @@ def __init__( smooth_kernel /= smooth_kernel.sum() self.smooth.weight.data = smooth_kernel[None, None] + glide_types = hparams.get('glide_types', ['none']) + self.glide_map = { + typename: idx + for idx, typename in enumerate(glide_types) + } + self.auto_completion_mode = len(predictions) == 0 self.global_predict_dur = 'dur' in predictions and hparams['predict_dur'] self.global_predict_pitch = 'pitch' in predictions and hparams['predict_pitch'] @@ -161,9 +167,8 @@ def preprocess_input( batch['note_dur'] = note_dur batch['note_rest'] = note_seq < 0 if hparams.get('use_glide_embed', False) and param.get('note_glide') is not None: - glide_map = 
{'none': 0, 'up': 1, 'down': 2} batch['note_glide'] = torch.LongTensor( - [[glide_map[x] for x in param['note_glide'].split()]] + [[self.glide_map.get(x, 0) for x in param['note_glide'].split()]] ).to(self.device) else: batch['note_glide'] = torch.zeros(1, T_n, dtype=torch.long, device=self.device) diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index 0293841d6..1d504e828 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -104,7 +104,7 @@ def get_hparam(key): self.use_glide_embed = hparams['use_glide_embed'] if self.use_glide_embed: # 0: none, 1: up, 2: down - self.note_glide_embed = Embedding(3, hidden_size, padding_idx=0) + self.note_glide_embed = Embedding(len(hparams['glide_types']), hidden_size, padding_idx=0) self.encoder = FastSpeech2Encoder( None, hidden_size, num_layers=get_hparam('enc_layers'), diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index 0c265a599..89a54d5ad 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -57,6 +57,13 @@ class VarianceBinarizer(BaseBinarizer): def __init__(self): super().__init__(data_attrs=VARIANCE_ITEM_ATTRIBUTES) + glide_types = hparams['glide_types'] + assert glide_types[0] == 'none', 'The first glide type must be \'none\'.' + self.glide_map = { + typename: idx + for idx, typename in enumerate(glide_types) + } + predict_energy = hparams['predict_energy'] predict_breathiness = hparams['predict_breathiness'] self.predict_variances = predict_energy or predict_breathiness @@ -133,8 +140,7 @@ def require(attr): assert any([note != 'rest' for note in temp_dict['note_seq']]), \ f'All notes are rest in \'{item_name}\'.' if hparams['use_glide_embed']: - glide_map = {'none': 0, 'up': 1, 'down': 2} - temp_dict['note_glide'] = [glide_map[x] for x in require('note_glide').split()] + temp_dict['note_glide'] = [self.glide_map.get(x, 0) for x in require('note_glide').split()] meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict From 16005ddbbe0df3ab6cbe0c28cebe1ffa72032b70 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 6 Oct 2023 01:50:48 +0800 Subject: [PATCH 11/14] Support glide embedding scale --- configs/variance.yaml | 1 + modules/fastspeech/variance_encoder.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/configs/variance.yaml b/configs/variance.yaml index 0da23d23c..43a873255 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -59,6 +59,7 @@ melody_encoder_args: enc_layers: 4 use_glide_embed: false glide_types: [none, up, down] +glide_embed_scale: 11.313708498984760 # sqrt(128) pitch_prediction_args: pitd_norm_min: -8.0 diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index 1d504e828..75c2306b0 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -102,6 +102,7 @@ def get_hparam(key): # ornament inputs self.use_glide_embed = hparams['use_glide_embed'] + self.glide_embed_scale = hparams['glide_embed_scale'] if self.use_glide_embed: # 0: none, 1: up, 2: down self.note_glide_embed = Embedding(len(hparams['glide_types']), hidden_size, padding_idx=0) @@ -127,7 +128,7 @@ def forward(self, note_midi, note_rest, note_dur, glide=None): dur_embed = self.note_dur_embed(note_dur.float()[:, :, None]) ornament_embed = 0 if self.use_glide_embed: - ornament_embed += self.note_glide_embed(glide) + ornament_embed += self.note_glide_embed(glide) * self.glide_embed_scale 
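        # The default scale is sqrt(hidden_size): with hidden_size 128 this is the
        # 11.3137... value in the config. This matches the usual Transformer practice
        # of multiplying embedding lookups by sqrt(d_model), presumably so the glide
        # embedding is not drowned out by the linear-projected MIDI and duration inputs.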
encoder_out = self.encoder( midi_embed, dur_embed + ornament_embed, padding_mask=note_midi < 0 From 6ecaedab3fd37839a99c1fc234a620177ef8a596 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Sat, 7 Oct 2023 23:20:37 +0800 Subject: [PATCH 12/14] Adjust glide type format --- configs/templates/config_variance.yaml | 8 ++++++++ configs/variance.yaml | 2 +- inference/ds_variance.py | 10 +++++++--- modules/fastspeech/variance_encoder.py | 2 +- preprocessing/variance_binarizer.py | 11 +++++++---- 5 files changed, 24 insertions(+), 9 deletions(-) diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index a54ca94a7..5c8322dcd 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -44,6 +44,14 @@ dur_prediction_args: lambda_wdur_loss: 1.0 lambda_sdur_loss: 3.0 +use_melody_encoder: false +melody_encoder_args: + hidden_size: 128 + enc_layers: 4 +use_glide_embed: false +glide_types: [up, down] +glide_embed_scale: 11.313708498984760 # sqrt(128) + pitch_prediction_args: pitd_norm_min: -8.0 pitd_norm_max: 8.0 diff --git a/configs/variance.yaml b/configs/variance.yaml index 43a873255..44c534820 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -58,7 +58,7 @@ melody_encoder_args: hidden_size: 128 enc_layers: 4 use_glide_embed: false -glide_types: [none, up, down] +glide_types: [up, down] glide_embed_scale: 11.313708498984760 # sqrt(128) pitch_prediction_args: diff --git a/inference/ds_variance.py b/inference/ds_variance.py index d56926555..29c8a9bb2 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -58,10 +58,14 @@ def __init__( smooth_kernel /= smooth_kernel.sum() self.smooth.weight.data = smooth_kernel[None, None] - glide_types = hparams.get('glide_types', ['none']) + glide_types = hparams.get('glide_types', []) + assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.' self.glide_map = { - typename: idx - for idx, typename in enumerate(glide_types) + 'none': 0, + **{ + typename: idx + 1 + for idx, typename in enumerate(glide_types) + } } self.auto_completion_mode = len(predictions) == 0 diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py index 75c2306b0..8e2117f6b 100644 --- a/modules/fastspeech/variance_encoder.py +++ b/modules/fastspeech/variance_encoder.py @@ -105,7 +105,7 @@ def get_hparam(key): self.glide_embed_scale = hparams['glide_embed_scale'] if self.use_glide_embed: # 0: none, 1: up, 2: down - self.note_glide_embed = Embedding(len(hparams['glide_types']), hidden_size, padding_idx=0) + self.note_glide_embed = Embedding(len(hparams['glide_types']) + 1, hidden_size, padding_idx=0) self.encoder = FastSpeech2Encoder( None, hidden_size, num_layers=get_hparam('enc_layers'), diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index 89a54d5ad..d4ddc64fa 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -57,11 +57,14 @@ class VarianceBinarizer(BaseBinarizer): def __init__(self): super().__init__(data_attrs=VARIANCE_ITEM_ATTRIBUTES) - glide_types = hparams['glide_types'] - assert glide_types[0] == 'none', 'The first glide type must be \'none\'.' + glide_types = hparams.get('glide_types', []) + assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.' 
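        # Under the adjusted format, glide_types lists only real glide names and
        # 'none' is prepended implicitly: glide_types = [up, down] produces
        # {'none': 0, 'up': 1, 'down': 2}, so the glide Embedding is sized
        # len(glide_types) + 1 with padding_idx=0.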
self.glide_map = { - typename: idx - for idx, typename in enumerate(glide_types) + 'none': 0, + **{ + typename: idx + 1 + for idx, typename in enumerate(glide_types) + } } predict_energy = hparams['predict_energy'] From f01af3845141b1c095033af7dab34f175e415b16 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Sun, 8 Oct 2023 00:53:20 +0800 Subject: [PATCH 13/14] Remove unnecessary interpolation on frame-level MIDI pitch --- preprocessing/variance_binarizer.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index d4ddc64fa..670731ed3 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -290,18 +290,10 @@ def process_item(self, item_name, meta_data, binarization_args): if hparams['use_glide_embed']: processed_input['note_glide'] = np.array(meta_data['note_glide'], dtype=np.int64) - # Below: calculate and interpolate frame-level MIDI pitch, which is a step function curve + # Below: + # 1. Get the frame-level MIDI pitch, which is a step function curve + # 2. smoothen the pitch step curve as the base pitch curve frame_midi_pitch = torch.gather(F.pad(note_midi, [1, 0], value=0), 0, mel2note) - frame_rest = (frame_midi_pitch < 0).cpu().numpy() - frame_midi_pitch = frame_midi_pitch.cpu().numpy() - interp_func = interpolate.interp1d( - np.where(~frame_rest)[0], frame_midi_pitch[~frame_rest], - kind='nearest', fill_value='extrapolate' - ) - frame_midi_pitch[frame_rest] = interp_func(np.where(frame_rest)[0]) - frame_midi_pitch = torch.from_numpy(frame_midi_pitch).to(self.device) - - # Below: smoothen the pitch step curve as the base pitch curve global midi_smooth if midi_smooth is None: midi_smooth = SinusoidalSmoothingConv1d( From c8b439a9c661a1c472fbd5649de9f8e36df12e2b Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Sun, 8 Oct 2023 01:11:52 +0800 Subject: [PATCH 14/14] Add glide type coverage checks --- preprocessing/variance_binarizer.py | 40 ++++++++++++++++++++++++++--- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/preprocessing/variance_binarizer.py b/preprocessing/variance_binarizer.py index 670731ed3..16a672cb2 100644 --- a/preprocessing/variance_binarizer.py +++ b/preprocessing/variance_binarizer.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from scipy import interpolate -from basics.base_binarizer import BaseBinarizer +from basics.base_binarizer import BaseBinarizer, BinarizationError from basics.base_pe import BasePE from modules.fastspeech.tts_modules import LengthRegulator from modules.pe import initialize_pe @@ -57,7 +57,8 @@ class VarianceBinarizer(BaseBinarizer): def __init__(self): super().__init__(data_attrs=VARIANCE_ITEM_ATTRIBUTES) - glide_types = hparams.get('glide_types', []) + self.use_glide_embed = hparams['use_glide_embed'] + glide_types = hparams['glide_types'] assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.' self.glide_map = { 'none': 0, @@ -143,7 +144,7 @@ def require(attr): assert any([note != 'rest' for note in temp_dict['note_seq']]), \ f'All notes are rest in \'{item_name}\'.' 
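            # The change below keeps raw glide type names at metadata-loading time;
            # mapping to integer ids is deferred to process_item(), which lets
            # check_coverage() count occurrences per type name and raise a
            # BinarizationError for glide types that never appear in the dataset.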
if hparams['use_glide_embed']: - temp_dict['note_glide'] = [self.glide_map.get(x, 0) for x in require('note_glide').split()] + temp_dict['note_glide'] = require('note_glide').split() meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict @@ -190,6 +191,35 @@ def check_coverage(self): pad_inches=0.25) print(f'| save summary to \'{filename}\'') + if self.use_glide_embed: + # Glide type distribution summary + glide_count = { + g: 0 + for g in self.glide_map + } + for item_name in self.items: + for glide in self.items[item_name]['note_glide']: + if glide == 'none' or glide not in self.glide_map: + glide_count['none'] += 1 + else: + glide_count[glide] += 1 + + print('===== Glide Type Distribution Summary =====') + for i, key in enumerate(sorted(glide_count.keys(), key=lambda k: self.glide_map[k])): + if i == len(glide_count) - 1: + end = '\n' + elif i % 10 == 9: + end = ',\n' + else: + end = ', ' + print(f'\'{key}\': {glide_count[key]}', end=end) + + if any(n == 0 for _, n in glide_count.items()): + raise BinarizationError( + f'Missing glide types in dataset: ' + f'{sorted([g for g, n in glide_count.items() if n == 0], key=lambda k: self.glide_map[k])}' + ) + @torch.no_grad() def process_item(self, item_name, meta_data, binarization_args): ds_id, name = item_name.split(':', maxsplit=1) @@ -288,7 +318,9 @@ def process_item(self, item_name, meta_data, binarization_args): # Below: get ornament attributes if hparams['use_glide_embed']: - processed_input['note_glide'] = np.array(meta_data['note_glide'], dtype=np.int64) + processed_input['note_glide'] = np.array([ + self.glide_map.get(x, 0) for x in meta_data['note_glide'] + ], dtype=np.int64) # Below: # 1. Get the frame-level MIDI pitch, which is a step function curve
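
Taken together, the recurring pattern in this series is note-to-frame expansion: note-level attributes are gathered to frame level through mel2note, with one extra leading slot so that index 0 maps to padding, and the same gather, with the index repeated across the hidden dimension, expands melody-encoder output into a frame-level condition. A minimal, self-contained sketch of the pattern (toy shapes and values; illustrative only, not code from the patches):

import torch
import torch.nn.functional as F

note_midi = torch.tensor([[60.0, 62.0, 64.0]])  # [B=1, T_n=3], MIDI pitch per note
note_dur = torch.tensor([2, 3, 2])              # frames per note
# mel2note holds the 1-based note index of every frame; 0 is reserved for padding
mel2note = torch.repeat_interleave(torch.arange(1, 4), note_dur)[None]  # [B=1, T_s=7]

# Pad one leading slot so index 0 gathers the padding value, then expand per frame
frame_midi = torch.gather(F.pad(note_midi, [1, 0], value=0), 1, mel2note)
print(frame_midi)  # tensor([[60., 60., 62., 62., 62., 64., 64.]])

# For hidden-size conditions, the same index is repeated across the channel dim:
hidden = 4
encoder_out = F.pad(torch.randn(1, 3, hidden), [0, 0, 1, 0])  # prepend a zero padding row
index = mel2note[..., None].repeat([1, 1, hidden])
frame_cond = torch.gather(encoder_out, 1, index)  # [B, T_s, H]

The smoothed version of frame_midi (via SinusoidalSmoothingConv1d in the binarizer) is what the patches call base_pitch.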