From f2b9f50a2e5a6ed98b4463b30469818381942a11 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Mon, 12 Jun 2023 17:25:16 +0800 Subject: [PATCH 1/8] Add expressiveness in model `forward()` --- modules/toplevel.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/modules/toplevel.py b/modules/toplevel.py index 4a8fb8570..997a0c5c4 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -114,7 +114,8 @@ def __init__(self, vocab_size): def forward( self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None, - base_pitch=None, pitch=None, retake=None, spk_id=None, infer=True, **kwargs + base_pitch=None, pitch=None, retake=None, expressiveness=None, + spk_id=None, infer=True, **kwargs ): if self.use_spk_id: ph_spk_mix_embed = kwargs.get('ph_spk_mix_embed') @@ -147,10 +148,16 @@ def forward( if self.use_spk_id: condition += spk_embed + if retake is None: - retake_embed = self.retake_embed(torch.ones_like(mel2ph)) - else: + retake = torch.ones_like(mel2ph, dtype=torch.bool) + if expressiveness is None: retake_embed = self.retake_embed(retake.long()) + else: + retake_true_embed = self.retake_embed(torch.ones_like(mel2ph)) # [B, T, H] + retake_false_embed = self.retake_embed(torch.zeros_like(mel2ph)) # [B, T, H] + expressiveness = (expressiveness * retake)[:, :, None] # [B, T, 1] + retake_embed = expressiveness * retake_true_embed + (1. - expressiveness) * retake_false_embed condition += retake_embed if self.predict_pitch: From 4bbedaf815bbd7b94733d30d72094f0bf360945b Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jun 2023 00:45:35 +0800 Subject: [PATCH 2/8] Support inference with static or dynamic expressiveness --- inference/ds_variance.py | 24 +++++++++++++++++++++--- modules/toplevel.py | 38 ++++++++++++++++++++++++++------------ scripts/infer.py | 8 ++++++++ 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/inference/ds_variance.py b/inference/ds_variance.py index 536112b45..d7ac9ef97 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -211,6 +211,22 @@ def preprocess_input( summary['pitch'] = 'manual' elif self.auto_completion_mode or self.global_predict_pitch: summary['pitch'] = 'auto' + + # Load expressiveness + expressiveness = param.get('expressiveness', 1.) + if isinstance(expressiveness, (int, float, bool)): + summary['expressiveness'] = f'static({expressiveness:.3f})' + batch['expressiveness'] = torch.FloatTensor([expressiveness]).to(self.device)[:, None] # [B=1, T=1] + else: + summary['expressiveness'] = 'dynamic' + expressiveness = resample_align_curve( + np.array(expressiveness.split(), np.float32), + original_timestep=float(param['expressiveness_timestep']), + target_timestep=self.timestep, + align_length=T_s + ) + batch['expressiveness'] = torch.from_numpy(expressiveness.astype(np.float32)).to(self.device)[None] + else: summary['pitch'] = 'ignored' @@ -234,6 +250,7 @@ def forward_model(self, sample): ph_dur = sample['ph_dur'] mel2ph = sample['mel2ph'] base_pitch = sample['base_pitch'] + expressiveness = sample.get('expressiveness') pitch = sample.get('pitch') if hparams['use_spk_id']: @@ -256,7 +273,7 @@ def forward_model(self, sample): txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed, - retake=None, infer=True + retake=None, expressiveness=expressiveness, infer=True ) if dur_pred is not None: dur_pred = self.rr(dur_pred, ph2word, word_dur) @@ -303,10 +320,11 @@ def run_inference( else: predict_variances = self.model.predict_variances and self.global_predict_variances predict_pitch = self.model.predict_pitch and ( - self.global_predict_pitch or (param.get('f0_seq') is None and predict_variances) + self.global_predict_pitch or (param.get('f0_seq') is None and predict_variances) ) predict_dur = self.model.predict_dur and ( - self.global_predict_dur or (param.get('ph_dur') is None and (predict_pitch or predict_variances)) + self.global_predict_dur or ( + param.get('ph_dur') is None and (predict_pitch or predict_variances)) ) flag = (predict_dur, predict_pitch, predict_variances) predictor_flags.append(flag) diff --git a/modules/toplevel.py b/modules/toplevel.py index 997a0c5c4..8028a2e25 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -149,21 +149,35 @@ def forward( if self.use_spk_id: condition += spk_embed - if retake is None: - retake = torch.ones_like(mel2ph, dtype=torch.bool) + retake_ = torch.ones(1, 1, dtype=torch.bool, device=txt_tokens.device) # [B=1, T=1] + if expressiveness is None: - retake_embed = self.retake_embed(retake.long()) + retake_embed = self.retake_embed(retake_.long()) + pitch_cond = var_cond = condition + retake_embed else: - retake_true_embed = self.retake_embed(torch.ones_like(mel2ph)) # [B, T, H] - retake_false_embed = self.retake_embed(torch.zeros_like(mel2ph)) # [B, T, H] - expressiveness = (expressiveness * retake)[:, :, None] # [B, T, 1] - retake_embed = expressiveness * retake_true_embed + (1. - expressiveness) * retake_false_embed - condition += retake_embed + if self.predict_pitch: + retake_true_embed = self.retake_embed( + torch.ones(1, 1, dtype=torch.long, device=txt_tokens.device) + ) # [B=1, T=1] => [B=1, T=1, H] + retake_false_embed = self.retake_embed( + torch.zeros(1, 1, dtype=torch.long, device=txt_tokens.device) + ) # [B=1, T=1] => [B=1, T=1, H] + expressiveness = (expressiveness * retake_)[:, :, None] # [B, T, 1] + pitch_retake_embed = expressiveness * retake_true_embed + (1. - expressiveness) * retake_false_embed + pitch_cond = condition + pitch_retake_embed + else: + pitch_cond = None + + if self.predict_variances: + var_retake_embed = self.retake_embed(retake_.long()) + var_cond = condition + var_retake_embed + else: + var_cond = None if self.predict_pitch: if retake is not None: base_pitch = base_pitch * retake + pitch * ~retake - pitch_cond = condition + self.base_pitch_embed(base_pitch[:, :, None]) + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) if infer: pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True) else: @@ -176,7 +190,7 @@ def forward( if pitch is None: pitch = base_pitch + pitch_pred_out - condition += self.pitch_embed(pitch[:, :, None]) + var_cond += self.pitch_embed(pitch[:, :, None]) variance_inputs = self.collect_variance_inputs(**kwargs) if retake is None: @@ -189,9 +203,9 @@ def forward( self.variance_embeds[v_name]((v_input * ~retake)[:, :, None]) for v_name, v_input in zip(self.variance_prediction_list, variance_inputs) ] - condition += torch.stack(variance_embeds, dim=-1).sum(-1) + var_cond += torch.stack(variance_embeds, dim=-1).sum(-1) - variance_outputs = self.variance_predictor(condition, variance_inputs, infer) + variance_outputs = self.variance_predictor(var_cond, variance_inputs, infer) if infer: variances_pred_out = self.collect_variance_outputs(variance_outputs) diff --git a/scripts/infer.py b/scripts/infer.py index 239ac5f7c..af809fadc 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -140,6 +140,7 @@ def acoustic( @click.option('--title', type=str, required=False, help='Title of output file') @click.option('--num', type=int, required=False, default=1, help='Number of runs') @click.option('--key', type=int, required=False, default=0, help='Key transition of pitch') +@click.option('--expressiveness', type=float, required=False, help='Static expressiveness control') @click.option('--seed', type=int, required=False, default=-1, help='Random seed of the inference') @click.option('--speedup', type=int, required=False, default=0, help='Diffusion acceleration ratio') def variance( @@ -152,6 +153,7 @@ def variance( title: str, num: int, key: int, + expressiveness: float, seed: int, speedup: int ): @@ -165,6 +167,9 @@ def variance( if (not out or out.resolve() == proj.parent.resolve()) and not title: name += '_variance' + if expressiveness is not None: + assert 0 <= expressiveness <= 1, 'Expressiveness must be in [-1, 1].' + with open(proj, 'r', encoding='utf-8') as f: params = json.load(f) @@ -200,6 +205,9 @@ def variance( spk_mix = parse_commandline_spk_mix(spk) if hparams['use_spk_id'] and spk is not None else None for param in params: + if expressiveness is not None: + param['expressiveness'] = expressiveness + if spk_mix is not None: param['ph_spk_mix_backup'] = param.get('ph_spk_mix') param['spk_mix_backup'] = param.get('spk_mix') From 4d906d2fc72afcd15b8c14d344d382b20c7e0326 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jun 2023 00:49:11 +0800 Subject: [PATCH 3/8] Fix assignment of `retake_` --- modules/toplevel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/toplevel.py b/modules/toplevel.py index 8028a2e25..ffae35243 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -149,7 +149,8 @@ def forward( if self.use_spk_id: condition += spk_embed - retake_ = torch.ones(1, 1, dtype=torch.bool, device=txt_tokens.device) # [B=1, T=1] + retake_ = torch.ones(1, 1, dtype=torch.bool, device=txt_tokens.device) \ + if retake is None else retake # [B=1, T=1] if expressiveness is None: retake_embed = self.retake_embed(retake_.long()) From 41e54c90d8b8db3450b6695dabc050da72d5702f Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jun 2023 12:33:44 +0800 Subject: [PATCH 4/8] Format code --- inference/ds_variance.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/inference/ds_variance.py b/inference/ds_variance.py index d7ac9ef97..eac4eeb19 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -320,11 +320,10 @@ def run_inference( else: predict_variances = self.model.predict_variances and self.global_predict_variances predict_pitch = self.model.predict_pitch and ( - self.global_predict_pitch or (param.get('f0_seq') is None and predict_variances) + self.global_predict_pitch or (param.get('f0_seq') is None and predict_variances) ) predict_dur = self.model.predict_dur and ( - self.global_predict_dur or ( - param.get('ph_dur') is None and (predict_pitch or predict_variances)) + self.global_predict_dur or (param.get('ph_dur') is None and (predict_pitch or predict_variances)) ) flag = (predict_dur, predict_pitch, predict_variances) predictor_flags.append(flag) From 2b9e011ea3cb0b7f7a26ec7b6815591aa886108f Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jun 2023 13:20:55 +0800 Subject: [PATCH 5/8] Add `expressiveness` in ONNX model --- deployment/exporters/variance_exporter.py | 9 +++++++-- deployment/modules/toplevel.py | 14 +++++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 7146087f7..90939f78d 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -221,6 +221,7 @@ def _torch_export_model(self): note_midi = torch.FloatTensor([[60.] * 4]).to(self.device) note_dur = torch.LongTensor([[2, 6, 3, 4]]).to(self.device) pitch = torch.FloatTensor([[60.] * 15]).to(self.device) + expressiveness = torch.ones_like(pitch) retake = torch.ones_like(pitch, dtype=torch.bool) torch.onnx.export( self.model.view_as_pitch_preprocess(), @@ -230,13 +231,14 @@ def _torch_export_model(self): note_midi, note_dur, pitch, + expressiveness, retake ), self.pitch_preprocess_cache_path, input_names=[ 'encoder_out', 'ph_dur', 'note_midi', 'note_dur', - 'pitch', 'retake' + 'pitch', 'expressiveness', 'retake' ], output_names=[ 'pitch_cond', 'base_pitch' @@ -257,6 +259,9 @@ def _torch_export_model(self): 'pitch': { 1: 'n_frames' }, + 'expressiveness': { + 1: 'n_frames' + }, 'retake': { 1: 'n_frames' }, @@ -271,7 +276,7 @@ def _torch_export_model(self): ) # Prepare inputs for denoiser tracing and PitchDiffusion scripting - shape = (1, 1, hparams['pitch_prediction_args']['num_pitch_bins'], 15) + shape = (1, 1, hparams['pitch_prediction_args']['repeat_bins'], 15) noise = torch.randn(shape, device=self.device) condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device) step = (torch.rand((1,), device=self.device) * hparams['K_step']).long() diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index df8290e3a..962b8411c 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -153,14 +153,22 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None): def forward_pitch_preprocess( self, encoder_out, ph_dur, note_midi, note_dur, - pitch=None, retake=None + pitch=None, expressiveness=None, retake=None ): condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) - condition += self.retake_embed(retake.long()) + retake_true_embed = self.retake_embed( + torch.ones(1, 1, dtype=torch.long, device=encoder_out.device) + ) # [B=1, T=1] => [B=1, T=1, H] + retake_false_embed = self.retake_embed( + torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device) + ) # [B=1, T=1] => [B=1, T=1, H] + expressiveness = (expressiveness * retake)[:, :, None] # [B, T, 1] + retake_embed = expressiveness * retake_true_embed + (1. - expressiveness) * retake_false_embed + pitch_cond = condition + retake_embed frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None) base_pitch = self.smooth(frame_midi_pitch) base_pitch += (pitch - base_pitch) * ~retake - pitch_cond = condition + self.base_pitch_embed(base_pitch[:, :, None]) + pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) return pitch_cond, base_pitch def forward_pitch_diffusion( From 21be051c5a6a6039da0215adfe86027546a99dd6 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 13 Jun 2023 13:32:53 +0800 Subject: [PATCH 6/8] Swap input order --- deployment/exporters/variance_exporter.py | 8 ++++---- deployment/modules/toplevel.py | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/deployment/exporters/variance_exporter.py b/deployment/exporters/variance_exporter.py index 90939f78d..e784f4bc4 100644 --- a/deployment/exporters/variance_exporter.py +++ b/deployment/exporters/variance_exporter.py @@ -230,15 +230,15 @@ def _torch_export_model(self): ph_dur, note_midi, note_dur, - pitch, expressiveness, + pitch, retake ), self.pitch_preprocess_cache_path, input_names=[ 'encoder_out', 'ph_dur', 'note_midi', 'note_dur', - 'pitch', 'expressiveness', 'retake' + 'expressiveness', 'pitch', 'retake' ], output_names=[ 'pitch_cond', 'base_pitch' @@ -256,10 +256,10 @@ def _torch_export_model(self): 'note_dur': { 1: 'n_notes' }, - 'pitch': { + 'expressiveness': { 1: 'n_frames' }, - 'expressiveness': { + 'pitch': { 1: 'n_frames' }, 'retake': { diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 962b8411c..cd6abccc0 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -153,7 +153,7 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None): def forward_pitch_preprocess( self, encoder_out, ph_dur, note_midi, note_dur, - pitch=None, expressiveness=None, retake=None + expressiveness=None, pitch=None, retake=None ): condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size) retake_true_embed = self.retake_embed( @@ -166,8 +166,7 @@ def forward_pitch_preprocess( retake_embed = expressiveness * retake_true_embed + (1. - expressiveness) * retake_false_embed pitch_cond = condition + retake_embed frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None) - base_pitch = self.smooth(frame_midi_pitch) - base_pitch += (pitch - base_pitch) * ~retake + base_pitch = self.smooth(frame_midi_pitch) * retake + pitch * ~retake pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) return pitch_cond, base_pitch From 6e51c85fa1e1ca8c7c41605c49d52f48f2287417 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Fri, 16 Jun 2023 22:58:11 +0800 Subject: [PATCH 7/8] Fix typo --- scripts/infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/infer.py b/scripts/infer.py index af809fadc..4db09bd65 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -168,7 +168,7 @@ def variance( name += '_variance' if expressiveness is not None: - assert 0 <= expressiveness <= 1, 'Expressiveness must be in [-1, 1].' + assert 0 <= expressiveness <= 1, 'Expressiveness must be in [0, 1].' with open(proj, 'r', encoding='utf-8') as f: params = json.load(f) From 38a08ab7d003074e907d58778acc89e99f081dd0 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Mon, 31 Jul 2023 19:09:23 +0800 Subject: [PATCH 8/8] Adapt latest updates from main branch --- inference/ds_variance.py | 4 ++-- modules/toplevel.py | 8 +++++--- scripts/export.py | 4 ++++ scripts/infer.py | 12 ++++++------ 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/inference/ds_variance.py b/inference/ds_variance.py index 462fc04b6..fdc999295 100644 --- a/inference/ds_variance.py +++ b/inference/ds_variance.py @@ -272,9 +272,9 @@ def forward_model(self, sample): dur_pred, pitch_pred, variance_pred = self.model( txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur, - mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, + mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, pitch_expr=expr, ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed, - retake=None, expr=expr, infer=True + infer=True ) if dur_pred is not None: dur_pred = self.rr(dur_pred, ph2word, word_dur) diff --git a/modules/toplevel.py b/modules/toplevel.py index 310da269a..01c52c09d 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -153,21 +153,23 @@ def forward( if self.predict_pitch: if pitch_retake is None: pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool) + else: + print(base_pitch, pitch, pitch_retake) + base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake if pitch_expr is None: pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long()) else: - retake_true_embed = self.retake_embed( + retake_true_embed = self.pitch_retake_embed( torch.ones(1, 1, dtype=torch.long, device=txt_tokens.device) ) # [B=1, T=1] => [B=1, T=1, H] - retake_false_embed = self.retake_embed( + retake_false_embed = self.pitch_retake_embed( torch.zeros(1, 1, dtype=torch.long, device=txt_tokens.device) ) # [B=1, T=1] => [B=1, T=1, H] pitch_expr = (pitch_expr * pitch_retake)[:, :, None] # [B, T, 1] pitch_retake_embed = pitch_expr * retake_true_embed + (1. - pitch_expr) * retake_false_embed pitch_cond = condition + pitch_retake_embed - base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake pitch_cond += self.base_pitch_embed(base_pitch[:, :, None]) if infer: pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True) diff --git a/scripts/export.py b/scripts/export.py index 29b4f3f5a..f7e6de5a8 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -140,6 +140,8 @@ def acoustic( @click.option('--exp', type=str, required=True, metavar='', help='Choose an experiment to export.') @click.option('--ckpt', type=int, required=False, metavar='', help='Checkpoint training steps.') @click.option('--out', type=str, required=False, metavar='', help='Output directory for the artifacts.') +@click.option('--expose_expr', is_flag=True, show_default=True, + help='Expose pitch expressiveness control functionality.') @click.option('--export_spk', type=str, required=False, multiple=True, metavar='', help='(for multi-speaker models) Export one or more speaker or speaker mix keys.') @click.option('--freeze_spk', type=str, required=False, metavar='', @@ -148,6 +150,7 @@ def variance( exp: str, ckpt: int = None, out: str = None, + expose_expr: bool = False, export_spk: List[str] = None, freeze_spk: str = None ): @@ -177,6 +180,7 @@ def variance( device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), cache_dir=root_dir / 'deployment' / 'cache', ckpt_steps=ckpt, + expose_expr=expose_expr, export_spk=export_spk_mix, freeze_spk=freeze_spk_mix ) diff --git a/scripts/infer.py b/scripts/infer.py index d1601409b..dbb28ac85 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -142,7 +142,7 @@ def acoustic( @click.option('--title', type=str, required=False, help='Title of output file') @click.option('--num', type=int, required=False, default=1, help='Number of runs') @click.option('--key', type=int, required=False, default=0, help='Key transition of pitch') -@click.option('--expressiveness', type=float, required=False, help='Static expressiveness control') +@click.option('--expr', type=float, required=False, help='Static expressiveness control') @click.option('--seed', type=int, required=False, default=-1, help='Random seed of the inference') @click.option('--speedup', type=int, required=False, default=0, help='Diffusion acceleration ratio') def variance( @@ -155,7 +155,7 @@ def variance( title: str, num: int, key: int, - expressiveness: float, + expr: float, seed: int, speedup: int ): @@ -169,8 +169,8 @@ def variance( if (not out or out.resolve() == proj.parent.resolve()) and not title: name += '_variance' - if expressiveness is not None: - assert 0 <= expressiveness <= 1, 'Expressiveness must be in [0, 1].' + if expr is not None: + assert 0 <= expr <= 1, 'Expressiveness must be in [0, 1].' with open(proj, 'r', encoding='utf-8') as f: params = json.load(f) @@ -207,8 +207,8 @@ def variance( spk_mix = parse_commandline_spk_mix(spk) if hparams['use_spk_id'] and spk is not None else None for param in params: - if expressiveness is not None: - param['expressiveness'] = expressiveness + if expr is not None: + param['expr'] = expr if spk_mix is not None: param['ph_spk_mix_backup'] = param.get('ph_spk_mix')