Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion deployment/exporters/variance_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
device: Union[str, torch.device] = 'cpu',
cache_dir: Path = None,
ckpt_steps: int = None,
expose_expr: bool = False,
export_spk: List[Tuple[str, Dict[str, float]]] = None,
freeze_spk: Tuple[str, Dict[str, float]] = None
):
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(
if self.model.predict_variances else None

# Attributes for exporting
self.expose_expr = expose_expr
self.freeze_spk: Tuple[str, Dict[str, float]] = freeze_spk \
if hparams['use_spk_id'] else None
self.export_spk: List[Tuple[str, Dict[str, float]]] = export_spk \
Expand Down Expand Up @@ -264,6 +266,10 @@ def _torch_export_model(self):
note_midi,
note_dur,
pitch,
# Dummy `expr` input, emitted only when expressiveness is exposed.
# The conditional must select the whole list (not an element inside it),
# so that no positional argument at all is unpacked when expose_expr is
# false and the arguments stay aligned with `input_names` below.
*([torch.ones_like(pitch)] if self.expose_expr else []),
retake,
*([torch.rand(
1, 15, hparams['hidden_size'],
Expand All @@ -274,7 +280,9 @@ def _torch_export_model(self):
input_names=[
'encoder_out', 'ph_dur',
'note_midi', 'note_dur',
'pitch', 'retake',
'pitch',
*(['expr'] if self.expose_expr else []),
'retake',
*(['spk_embed'] if input_spk_embed else [])
],
output_names=[
Expand All @@ -293,6 +301,7 @@ def _torch_export_model(self):
'note_dur': {
1: 'n_notes'
},
**({'expr': {1: 'n_frames'}} if self.expose_expr else {}),
'pitch': {
1: 'n_frames'
},
Expand Down
17 changes: 14 additions & 3 deletions deployment/modules/toplevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,25 @@ def forward_mel2x_gather(self, x_src, x_dur, x_dim=None):

def forward_pitch_preprocess(
self, encoder_out, ph_dur, note_midi, note_dur,
pitch=None, retake=None, spk_embed=None
pitch=None, expr=None, retake=None, spk_embed=None
):
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
condition += self.pitch_retake_embed(retake.long())
if expr is None:
retake_embed = self.pitch_retake_embed(retake.long())
else:
retake_true_embed = self.pitch_retake_embed(
torch.ones(1, 1, dtype=torch.long, device=encoder_out.device)
) # [B=1, T=1] => [B=1, T=1, H]
retake_false_embed = self.pitch_retake_embed(
torch.zeros(1, 1, dtype=torch.long, device=encoder_out.device)
) # [B=1, T=1] => [B=1, T=1, H]
expr = (expr * retake)[:, :, None] # [B, T, 1]
retake_embed = expr * retake_true_embed + (1. - expr) * retake_false_embed
pitch_cond = condition + retake_embed
frame_midi_pitch = self.forward_mel2x_gather(note_midi, note_dur, x_dim=None)
base_pitch = self.smooth(frame_midi_pitch)
base_pitch = base_pitch * retake + pitch * ~retake
pitch_cond = condition + self.base_pitch_embed(base_pitch[:, :, None])
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if hparams['use_spk_id'] and spk_embed is not None:
pitch_cond += spk_embed
return pitch_cond, base_pitch
Expand Down
19 changes: 18 additions & 1 deletion inference/ds_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,22 @@ def preprocess_input(
summary['pitch'] = 'manual'
elif self.auto_completion_mode or self.global_predict_pitch:
summary['pitch'] = 'auto'

# Load expressiveness
expr = param.get('expr', 1.)
if isinstance(expr, (int, float, bool)):
summary['expr'] = f'static({expr:.3f})'
batch['expr'] = torch.FloatTensor([expr]).to(self.device)[:, None] # [B=1, T=1]
else:
summary['expr'] = 'dynamic'
expr = resample_align_curve(
np.array(expr.split(), np.float32),
original_timestep=float(param['expr_timestep']),
target_timestep=self.timestep,
align_length=T_s
)
batch['expr'] = torch.from_numpy(expr.astype(np.float32)).to(self.device)[None]

else:
summary['pitch'] = 'ignored'

Expand All @@ -235,6 +251,7 @@ def forward_model(self, sample):
ph_dur = sample['ph_dur']
mel2ph = sample['mel2ph']
base_pitch = sample['base_pitch']
expr = sample.get('expr')
pitch = sample.get('pitch')

if hparams['use_spk_id']:
Expand All @@ -255,7 +272,7 @@ def forward_model(self, sample):

dur_pred, pitch_pred, variance_pred = self.model(
txt_tokens, midi=midi, ph2word=ph2word, word_dur=word_dur, ph_dur=ph_dur,
mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch,
mel2ph=mel2ph, base_pitch=base_pitch, pitch=pitch, pitch_expr=expr,
ph_spk_mix_embed=ph_spk_mix_embed, spk_mix_embed=spk_mix_embed,
infer=True
)
Expand Down
20 changes: 17 additions & 3 deletions modules/toplevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def __init__(self, vocab_size):

def forward(
self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None,
base_pitch=None, pitch=None, pitch_retake=None, variance_retake: Dict[str, Tensor] = None,
base_pitch=None, pitch=None, pitch_expr=None, pitch_retake=None,
variance_retake: Dict[str, Tensor] = None,
spk_id=None, infer=True, **kwargs
):
if self.use_spk_id:
Expand Down Expand Up @@ -151,10 +152,23 @@ def forward(

if self.predict_pitch:
if pitch_retake is None:
pitch_retake_embed = self.pitch_retake_embed(torch.ones_like(mel2ph))
pitch_retake = torch.ones_like(mel2ph, dtype=torch.bool)
else:
pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long())
# Keep base_pitch on frames where pitch_retake is set and use the supplied
# pitch on all other frames. (No debug printing here: this runs on every
# forward pass.)
base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake

if pitch_expr is None:
pitch_retake_embed = self.pitch_retake_embed(pitch_retake.long())
else:
retake_true_embed = self.pitch_retake_embed(
torch.ones(1, 1, dtype=torch.long, device=txt_tokens.device)
) # [B=1, T=1] => [B=1, T=1, H]
retake_false_embed = self.pitch_retake_embed(
torch.zeros(1, 1, dtype=torch.long, device=txt_tokens.device)
) # [B=1, T=1] => [B=1, T=1, H]
pitch_expr = (pitch_expr * pitch_retake)[:, :, None] # [B, T, 1]
pitch_retake_embed = pitch_expr * retake_true_embed + (1. - pitch_expr) * retake_false_embed

pitch_cond = condition + pitch_retake_embed
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if infer:
Expand Down
4 changes: 4 additions & 0 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def acoustic(
@click.option('--exp', type=str, required=True, metavar='<exp>', help='Choose an experiment to export.')
@click.option('--ckpt', type=int, required=False, metavar='<steps>', help='Checkpoint training steps.')
@click.option('--out', type=str, required=False, metavar='<dir>', help='Output directory for the artifacts.')
@click.option('--expose_expr', is_flag=True, show_default=True,
help='Expose pitch expressiveness control functionality.')
@click.option('--export_spk', type=str, required=False, multiple=True, metavar='<mix>',
help='(for multi-speaker models) Export one or more speaker or speaker mix keys.')
@click.option('--freeze_spk', type=str, required=False, metavar='<mix>',
Expand All @@ -148,6 +150,7 @@ def variance(
exp: str,
ckpt: int = None,
out: str = None,
expose_expr: bool = False,
export_spk: List[str] = None,
freeze_spk: str = None
):
Expand Down Expand Up @@ -177,6 +180,7 @@ def variance(
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
cache_dir=root_dir / 'deployment' / 'cache',
ckpt_steps=ckpt,
expose_expr=expose_expr,
export_spk=export_spk_mix,
freeze_spk=freeze_spk_mix
)
Expand Down
8 changes: 8 additions & 0 deletions scripts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def acoustic(
@click.option('--title', type=str, required=False, help='Title of output file')
@click.option('--num', type=int, required=False, default=1, help='Number of runs')
@click.option('--key', type=int, required=False, default=0, help='Key transition of pitch')
@click.option('--expr', type=float, required=False, help='Static expressiveness control')
@click.option('--seed', type=int, required=False, default=-1, help='Random seed of the inference')
@click.option('--speedup', type=int, required=False, default=0, help='Diffusion acceleration ratio')
def variance(
Expand All @@ -154,6 +155,7 @@ def variance(
title: str,
num: int,
key: int,
expr: float,
seed: int,
speedup: int
):
Expand All @@ -167,6 +169,9 @@ def variance(
if (not out or out.resolve() == proj.parent.resolve()) and not title:
name += '_variance'

# Validate the optional --expr override. Raise a click usage error instead of
# asserting: assertions are stripped under `python -O`, which would let
# out-of-range values through silently.
if expr is not None and not 0 <= expr <= 1:
    raise click.BadParameter('Expressiveness must be in [0, 1].', param_hint='--expr')

with open(proj, 'r', encoding='utf-8') as f:
params = json.load(f)

Expand Down Expand Up @@ -202,6 +207,9 @@ def variance(

spk_mix = parse_commandline_spk_mix(spk) if hparams['use_spk_id'] and spk is not None else None
for param in params:
if expr is not None:
param['expr'] = expr

if spk_mix is not None:
param['ph_spk_mix_backup'] = param.get('ph_spk_mix')
param['spk_mix_backup'] = param.get('spk_mix')
Expand Down