From 36c0f6655a48c494d0edfd5ea79d0e6d880be5d0 Mon Sep 17 00:00:00 2001 From: yxlllc <33565655+yxlllc@users.noreply.github.com> Date: Sat, 3 Aug 2024 22:11:40 +0800 Subject: [PATCH 1/5] Faster NSF inference (#202) * faster nsf * faster nsf * Slightly optimize * Pad negative length --------- Co-authored-by: yqzhishen --- modules/nsf_hifigan/models.py | 48 ++++++++++------------------------- 1 file changed, 13 insertions(+), 35 deletions(-) diff --git a/modules/nsf_hifigan/models.py b/modules/nsf_hifigan/models.py index a77eb0a38..cc21039f7 100644 --- a/modules/nsf_hifigan/models.py +++ b/modules/nsf_hifigan/models.py @@ -130,41 +130,20 @@ def _f02uv(self, f0): uv = uv * (f0 > self.voiced_threshold) return uv - def _f02sine(self, f0_values, upp): - """ f0_values: (batchsize, length, dim) + def _f02sine(self, f0, upp): + """ f0: (batchsize, length, dim) where dim indicates fundamental tone and overtones """ - rad_values = (f0_values / self.sampling_rate).fmod(1.) # %1意味着n_har的乘积无法后处理优化 - rand_ini = torch.rand(1, self.dim, device=f0_values.device) - rand_ini[:, 0] = 0 - rad_values[:, 0, :] += rand_ini - is_half = rad_values.dtype is not torch.float32 - tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 - if is_half: - tmp_over_one = tmp_over_one.half() - else: - tmp_over_one = tmp_over_one.float() - tmp_over_one *= upp - tmp_over_one = F.interpolate( - tmp_over_one.transpose(2, 1), scale_factor=upp, - mode='linear', align_corners=True - ).transpose(2, 1) - rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) - tmp_over_one = tmp_over_one.fmod(1.) 
- diff = F.conv2d( - tmp_over_one.unsqueeze(1), torch.FloatTensor([[[[-1.], [1.]]]]).to(tmp_over_one.device), - stride=(1, 1), padding=0, dilation=(1, 1) - ).squeeze(1) # Equivalent to torch.diff, but able to export ONNX - cumsum_shift = (diff < 0).double() - cumsum_shift = torch.cat(( - torch.zeros((1, 1, self.dim), dtype=torch.double).to(f0_values.device), - cumsum_shift - ), dim=1) - sines = torch.sin(torch.cumsum(rad_values.double() + cumsum_shift, dim=1) * 2 * np.pi) - if is_half: - sines = sines.half() - else: - sines = sines.float() + rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, device=f0.device) + rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5 + rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0) + rad += F.pad(rad_acc, (0, 0, 1, -1)) + rad = rad.reshape(f0.shape[0], -1, 1) + rad = torch.multiply(rad, torch.arange(1, self.dim + 1, device=f0.device).reshape(1, 1, -1)) + rand_ini = torch.rand(1, 1, self.dim, device=f0.device) + rand_ini[..., 0] = 0 + rad += rand_ini + sines = torch.sin(2 * np.pi * rad) return sines @torch.no_grad() @@ -176,8 +155,7 @@ def forward(self, f0, upp): output uv: tensor(batchsize=1, length, 1) """ f0 = f0.unsqueeze(-1) - fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1))) - sine_waves = self._f02sine(fn, upp) * self.sine_amp + sine_waves = self._f02sine(f0, upp) * self.sine_amp uv = (f0 > self.voiced_threshold).float() uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 From 96e8ac2831b1b78e20f0f5aaa6e59880494e5f0a Mon Sep 17 00:00:00 2001 From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com> Date: Wed, 28 Aug 2024 20:18:21 +0800 Subject: [PATCH 2/5] [DONE]New AUX_Decoder/Backbone Network : LYNXNet (#200) * Update __init__.py * Update LYNXNet * add dropout --- modules/aux_decoder/LYNXNetDecoder.py | 70 +++++++++++ modules/aux_decoder/__init__.py | 7 +- 
modules/backbones/LYNXNet.py | 171 ++++++++++++++++++++++++++ modules/backbones/__init__.py | 4 +- 4 files changed, 249 insertions(+), 3 deletions(-) create mode 100644 modules/aux_decoder/LYNXNetDecoder.py create mode 100644 modules/backbones/LYNXNet.py diff --git a/modules/aux_decoder/LYNXNetDecoder.py b/modules/aux_decoder/LYNXNetDecoder.py new file mode 100644 index 000000000..42b552fbf --- /dev/null +++ b/modules/aux_decoder/LYNXNetDecoder.py @@ -0,0 +1,70 @@ +# refer to: +# https://github.com/CNChTu/Diffusion-SVC/blob/v2.0_dev/diffusion/naive_v2/model_conformer_naive.py +# https://github.com/CNChTu/Diffusion-SVC/blob/v2.0_dev/diffusion/naive_v2/naive_v2_diff.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modules.backbones.LYNXNet import LYNXConvModule + + +class LYNXNetDecoderLayer(nn.Module): + """ + LYNXNet Decoder Layer + + Args: + dim (int): Dimension of model + expansion_factor (int): Expansion factor of conv module, default 2 + kernel_size (int): Kernel size of conv module, default 31 + in_norm (bool): Whether to use norm + activation (str): Activation Function for conv module + """ + + def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activation='SiLU', dropout=0.): + super().__init__() + self.convmodule = LYNXConvModule(dim=dim, expansion_factor=expansion_factor, kernel_size=kernel_size, in_norm=in_norm, activation=activation, dropout=dropout) + + def forward(self, x) -> torch.Tensor: + residual = x + x = self.convmodule(x) + x = residual + x + + return x + + +class LYNXNetDecoder(nn.Module): + def __init__( + self, in_dims, out_dims, /, *, + num_channels=512, num_layers=6, kernel_size=31, dropout_rate=0. 
+ ): + super().__init__() + self.input_projection = nn.Conv1d(in_dims, num_channels, 1) + self.encoder_layers = nn.ModuleList( + LYNXNetDecoderLayer( + dim=num_channels, + expansion_factor=2, + kernel_size=kernel_size, + in_norm=False, + activation='SiLU', + dropout=dropout_rate) for _ in range(num_layers) + ) + self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1) + + def forward(self, x, infer=False): + """ + Args: + x (torch.Tensor): Input tensor (#batch, length, in_dims) + return: + torch.Tensor: Output tensor (#batch, length, out_dims) + """ + x = x.transpose(1, 2) + x = self.input_projection(x) + x = x.transpose(1, 2) + for layer in self.encoder_layers: + x = layer(x) + x = x.transpose(1, 2) + x = self.output_projection(x) + x = x.transpose(1, 2) + + return x \ No newline at end of file diff --git a/modules/aux_decoder/__init__.py b/modules/aux_decoder/__init__.py index 54ceb2113..4801b1156 100644 --- a/modules/aux_decoder/__init__.py +++ b/modules/aux_decoder/__init__.py @@ -2,13 +2,16 @@ from torch import nn from .convnext import ConvNeXtDecoder +from .LYNXNetDecoder import LYNXNetDecoder from utils import filter_kwargs AUX_DECODERS = { - 'convnext': ConvNeXtDecoder + 'convnext': ConvNeXtDecoder, + 'lynxnet': LYNXNetDecoder } AUX_LOSSES = { - 'convnext': nn.L1Loss + 'convnext': nn.L1Loss, + 'lynxnet': nn.L1Loss } diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py new file mode 100644 index 000000000..ffc5a94e5 --- /dev/null +++ b/modules/backbones/LYNXNet.py @@ -0,0 +1,171 @@ +# refer to: +# https://github.com/CNChTu/Diffusion-SVC/blob/v2.0_dev/diffusion/naive_v2/model_conformer_naive.py +# https://github.com/CNChTu/Diffusion-SVC/blob/v2.0_dev/diffusion/naive_v2/naive_v2_diff.py + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils.hparams import hparams + + +class SwiGLU(nn.Module): + ## Swish-Applies the gated linear unit function. 
+ def __init__(self, dim=-1): + super().__init__() + self.dim = dim + def forward(self, x): + # out, gate = x.chunk(2, dim=self.dim) + # Using torch.split instead of chunk for ONNX export compatibility. + out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim) + return out * F.silu(gate) + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Transpose(nn.Module): + def __init__(self, dims): + super().__init__() + assert len(dims) == 2, 'dims must be a tuple of two dimensions' + self.dims = dims + + def forward(self, x): + return x.transpose(*self.dims) + + +class LYNXConvModule(nn.Module): + @staticmethod + def calc_same_padding(kernel_size): + pad = kernel_size // 2 + return (pad, pad - (kernel_size + 1) % 2) + + def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activation='PReLU', dropout=0.): + super().__init__() + inner_dim = dim * expansion_factor + _normalize = nn.LayerNorm(dim) if in_norm or dim > 512 else nn.Identity() + activation_classes = { + 'SiLU': nn.SiLU, + 'ReLU': nn.ReLU, + 'PReLU': lambda: nn.PReLU(inner_dim) + } + activation = activation if activation is not None else 'PReLU' + if activation not in activation_classes: + raise ValueError(f'{activation} is not a valid activation') + _activation = activation_classes[activation]() + padding = self.calc_same_padding(kernel_size) + if float(dropout) > 0.: + _dropout = nn.Dropout(dropout) + else: + _dropout = nn.Identity() + self.net = nn.Sequential( + _normalize, + Transpose((1, 2)), + nn.Conv1d(dim, inner_dim * 2, 1), + SwiGLU(dim=1), + nn.Conv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding[0], groups=inner_dim), + 
_activation, + nn.Conv1d(inner_dim, dim, 1), + Transpose((1, 2)), + _dropout + ) + + def forward(self, x): + return self.net(x) + + +class LYNXNetResidualLayer(nn.Module): + def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, in_norm=False, activation='PReLU', dropout=0.): + super().__init__() + self.diffusion_projection = nn.Conv1d(dim, dim, 1) + self.conditioner_projection = nn.Conv1d(dim_cond, dim, 1) + self.convmodule = LYNXConvModule(dim=dim, expansion_factor=expansion_factor, kernel_size=kernel_size, in_norm=in_norm, activation=activation, dropout=dropout) + + def forward(self, x, conditioner, diffusion_step): + res_x = x.transpose(1, 2) + x = x + self.diffusion_projection(diffusion_step) + self.conditioner_projection(conditioner) + x = x.transpose(1, 2) + x = self.convmodule(x) # (#batch, dim, length) + x = x + res_x + x = x.transpose(1, 2) + + return x # (#batch, length, dim) + + +class LYNXNet(nn.Module): + def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in_norm=False, activation='PReLU', dropout=0.): + """ + LYNXNet(Linear Gated Depthwise Separable Convolution Network) + TIPS:You can control the style of the generated results by modifying the 'activation', + - 'PReLU'(default) : Similar to WaveNet + - 'SiLU' : Voice will be more pronounced, not recommended for use under DDPM + - 'ReLU' : Contrary to 'SiLU', Voice will be weakened + """ + super().__init__() + self.input_projection = nn.Conv1d(in_dims * n_feats, n_chans, 1) + self.diffusion_embedding = nn.Sequential( + SinusoidalPosEmb(n_chans), + nn.Linear(n_chans, n_chans * 4), + nn.GELU(), + nn.Linear(n_chans * 4, n_chans), + ) + self.residual_layers = nn.ModuleList( + [ + LYNXNetResidualLayer( + dim_cond=hparams['hidden_size'], + dim=n_chans, + expansion_factor=n_dilates, + kernel_size=31, + in_norm=in_norm, + activation=activation, + dropout=dropout + ) + for i in range(n_layers) + ] + ) + self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, 
kernel_size=1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + """ + :param spec: [B, F, M, T] + :param diffusion_step: [B, 1] + :param cond: [B, H, T] + :return: + """ + + # To keep compatibility with DiffSVC, [B, 1, M, T] + x = spec + use_4_dim = False + if x.dim() == 4: + x = x[:, 0] + use_4_dim = True + + assert x.dim() == 3, f"mel must be 3 dim tensor, but got {x.dim()}" + + x = self.input_projection(x) # x [B, residual_channel, T] + x = F.gelu(x) + + diffusion_step = self.diffusion_embedding(diffusion_step).unsqueeze(-1) + + for layer in self.residual_layers: + x = layer(x, cond, diffusion_step) + + # MLP and GLU + x = self.output_projection(x) # [B, 128, T] + + return x[:, None] if use_4_dim else x diff --git a/modules/backbones/__init__.py b/modules/backbones/__init__.py index 1061b8779..a91578567 100644 --- a/modules/backbones/__init__.py +++ b/modules/backbones/__init__.py @@ -1,5 +1,7 @@ from modules.backbones.wavenet import WaveNet +from modules.backbones.LYNXNet import LYNXNet BACKBONES = { - 'wavenet': WaveNet + 'wavenet': WaveNet, + 'lynxnet': LYNXNet } From 0a6a802860469521ca154e712532897bd146efd9 Mon Sep 17 00:00:00 2001 From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com> Date: Fri, 6 Sep 2024 22:29:12 +0800 Subject: [PATCH 3/5] Lynxnet outnorm (#206) * post-norm * fix * add norm+mlp * Update LYNXNet.py * Update LYNXNetDecoder.py * do not need mlp * do not need mlp * Add out norm for LYNXNET * Add out norm for LYNXNETDecoder --- modules/aux_decoder/LYNXNetDecoder.py | 5 ++++- modules/backbones/LYNXNet.py | 9 ++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/modules/aux_decoder/LYNXNetDecoder.py b/modules/aux_decoder/LYNXNetDecoder.py index 42b552fbf..4ac5923ee 100644 --- a/modules/aux_decoder/LYNXNetDecoder.py +++ b/modules/aux_decoder/LYNXNetDecoder.py @@ -49,6 +49,7 @@ def __init__( activation='SiLU', dropout=dropout_rate) for _ in range(num_layers) ) + 
self.norm = nn.LayerNorm(num_channels) self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1) def forward(self, x, infer=False): @@ -63,8 +64,10 @@ def forward(self, x, infer=False): x = x.transpose(1, 2) for layer in self.encoder_layers: x = layer(x) + x = self.norm(x) x = x.transpose(1, 2) + x = self.output_projection(x) x = x.transpose(1, 2) - return x \ No newline at end of file + return x diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py index ffc5a94e5..b4527f98b 100644 --- a/modules/backbones/LYNXNet.py +++ b/modules/backbones/LYNXNet.py @@ -57,7 +57,6 @@ def calc_same_padding(kernel_size): def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activation='PReLU', dropout=0.): super().__init__() inner_dim = dim * expansion_factor - _normalize = nn.LayerNorm(dim) if in_norm or dim > 512 else nn.Identity() activation_classes = { 'SiLU': nn.SiLU, 'ReLU': nn.ReLU, @@ -73,7 +72,7 @@ def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activat else: _dropout = nn.Identity() self.net = nn.Sequential( - _normalize, + nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), SwiGLU(dim=1), @@ -137,9 +136,10 @@ def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in for i in range(n_layers) ] ) + self.norm = nn.LayerNorm(n_chans) self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1) nn.init.zeros_(self.output_projection.weight) - + def forward(self, spec, diffusion_step, cond): """ :param spec: [B, F, M, T] @@ -164,6 +164,9 @@ def forward(self, spec, diffusion_step, cond): for layer in self.residual_layers: x = layer(x, cond, diffusion_step) + + # post-norm + x = self.norm(x.transpose(1, 2)).transpose(1, 2) # MLP and GLU x = self.output_projection(x) # [B, 128, T] From 2ac0158b51f69c26c3b4fc4f55ccaa0f45916fb8 Mon Sep 17 00:00:00 2001 From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com> Date: Sun, 20 Oct 2024 
21:41:52 +0800 Subject: [PATCH 4/5] delete lynxnet aux_decoder (#212) --- modules/aux_decoder/LYNXNetDecoder.py | 73 --------------------------- modules/aux_decoder/__init__.py | 7 +-- 2 files changed, 2 insertions(+), 78 deletions(-) delete mode 100644 modules/aux_decoder/LYNXNetDecoder.py diff --git a/modules/aux_decoder/LYNXNetDecoder.py b/modules/aux_decoder/LYNXNetDecoder.py deleted file mode 100644 index 4ac5923ee..000000000 --- a/modules/aux_decoder/LYNXNetDecoder.py +++ /dev/null @@ -1,73 +0,0 @@ -# refer to: -# https://github.com/CNChTu/Diffusion-SVC/blob/v2.0_dev/diffusion/naive_v2/model_conformer_naive.py -# https://github.com/CNChTu/Diffusion-SVC/blob/v2.0_dev/diffusion/naive_v2/naive_v2_diff.py - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from modules.backbones.LYNXNet import LYNXConvModule - - -class LYNXNetDecoderLayer(nn.Module): - """ - LYNXNet Decoder Layer - - Args: - dim (int): Dimension of model - expansion_factor (int): Expansion factor of conv module, default 2 - kernel_size (int): Kernel size of conv module, default 31 - in_norm (bool): Whether to use norm - activation (str): Activation Function for conv module - """ - - def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activation='SiLU', dropout=0.): - super().__init__() - self.convmodule = LYNXConvModule(dim=dim, expansion_factor=expansion_factor, kernel_size=kernel_size, in_norm=in_norm, activation=activation, dropout=dropout) - - def forward(self, x) -> torch.Tensor: - residual = x - x = self.convmodule(x) - x = residual + x - - return x - - -class LYNXNetDecoder(nn.Module): - def __init__( - self, in_dims, out_dims, /, *, - num_channels=512, num_layers=6, kernel_size=31, dropout_rate=0. 
- ): - super().__init__() - self.input_projection = nn.Conv1d(in_dims, num_channels, 1) - self.encoder_layers = nn.ModuleList( - LYNXNetDecoderLayer( - dim=num_channels, - expansion_factor=2, - kernel_size=kernel_size, - in_norm=False, - activation='SiLU', - dropout=dropout_rate) for _ in range(num_layers) - ) - self.norm = nn.LayerNorm(num_channels) - self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1) - - def forward(self, x, infer=False): - """ - Args: - x (torch.Tensor): Input tensor (#batch, length, in_dims) - return: - torch.Tensor: Output tensor (#batch, length, out_dims) - """ - x = x.transpose(1, 2) - x = self.input_projection(x) - x = x.transpose(1, 2) - for layer in self.encoder_layers: - x = layer(x) - x = self.norm(x) - x = x.transpose(1, 2) - - x = self.output_projection(x) - x = x.transpose(1, 2) - - return x diff --git a/modules/aux_decoder/__init__.py b/modules/aux_decoder/__init__.py index 4801b1156..54ceb2113 100644 --- a/modules/aux_decoder/__init__.py +++ b/modules/aux_decoder/__init__.py @@ -2,16 +2,13 @@ from torch import nn from .convnext import ConvNeXtDecoder -from .LYNXNetDecoder import LYNXNetDecoder from utils import filter_kwargs AUX_DECODERS = { - 'convnext': ConvNeXtDecoder, - 'lynxnet': LYNXNetDecoder + 'convnext': ConvNeXtDecoder } AUX_LOSSES = { - 'convnext': nn.L1Loss, - 'lynxnet': nn.L1Loss + 'convnext': nn.L1Loss } From 40f8488c812819224292162242c1a431609a697f Mon Sep 17 00:00:00 2001 From: yxlllc Date: Sun, 3 Nov 2024 00:10:52 +0800 Subject: [PATCH 5/5] refactor configuration options --- configs/acoustic.yaml | 10 +++-- configs/templates/config_acoustic.yaml | 16 ++++++-- configs/templates/config_variance.yaml | 16 +++++--- configs/variance.yaml | 16 +++++--- deployment/modules/toplevel.py | 32 ++++------------ modules/backbones/LYNXNet.py | 52 ++++++++++++++------------ modules/backbones/__init__.py | 10 +++++ modules/backbones/wavenet.py | 20 +++++----- modules/core/ddpm.py | 4 +- modules/core/reflow.py 
| 4 +- modules/fastspeech/param_adaptor.py | 16 +++++--- modules/toplevel.py | 47 +++++++++++------------ 12 files changed, 130 insertions(+), 113 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 0364b5c15..6efe72367 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -73,10 +73,12 @@ sampling_steps: 20 diff_accelerator: ddim diff_speedup: 10 hidden_size: 256 -residual_layers: 20 -residual_channels: 512 -dilation_cycle_length: 4 # * -backbone_type: 'wavenet' +backbone_type: 'lynxnet' +backbone_args: + num_channels: 1024 + num_layers: 6 + kernel_size: 31 + dropout_rate: 0.0 main_loss_type: l2 main_loss_log_norm: false schedule_type: 'linear' diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 198444bc7..f0edef7a8 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -51,16 +51,24 @@ augmentation_args: range: [0.5, 2.] scale: 0.75 -residual_channels: 512 -residual_layers: 20 - -# shallow diffusion +# diffusion and shallow diffusion diffusion_type: reflow use_shallow_diffusion: true T_start: 0.4 T_start_infer: 0.4 K_step: 300 K_step_infer: 300 +backbone_type: 'lynxnet' +backbone_args: + num_channels: 1024 + num_layers: 6 + kernel_size: 31 + dropout_rate: 0.0 +# backbone_type: 'wavenet' +# backbone_args: +# num_channels: 512 +# num_layers: 20 +# dilation_cycle_length: 4 shallow_diffusion_args: train_aux_decoder: true train_diffusion: true diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index d75667797..842a76395 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -78,15 +78,19 @@ pitch_prediction_args: pitd_clip_min: -12.0 pitd_clip_max: 12.0 repeat_bins: 64 - residual_layers: 20 - residual_channels: 256 - dilation_cycle_length: 5 # * + backbone_type: 'wavenet' + backbone_args: + num_layers: 20 + num_channels: 256 + dilation_cycle_length: 5 
variances_prediction_args: total_repeat_bins: 48 - residual_layers: 10 - residual_channels: 192 - dilation_cycle_length: 4 # * + backbone_type: 'wavenet' + backbone_args: + num_layers: 10 + num_channels: 192 + dilation_cycle_length: 4 lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/configs/variance.yaml b/configs/variance.yaml index 2c6d002da..95a0781be 100644 --- a/configs/variance.yaml +++ b/configs/variance.yaml @@ -68,9 +68,11 @@ pitch_prediction_args: pitd_clip_min: -12.0 pitd_clip_max: 12.0 repeat_bins: 64 - residual_layers: 20 - residual_channels: 256 - dilation_cycle_length: 5 # * + backbone_type: 'wavenet' + backbone_args: + num_layers: 20 + num_channels: 256 + dilation_cycle_length: 5 energy_db_min: -96.0 energy_db_max: -12.0 @@ -89,9 +91,11 @@ tension_smooth_width: 0.12 variances_prediction_args: total_repeat_bins: 48 - residual_layers: 10 - residual_channels: 192 - dilation_cycle_length: 4 # * + backbone_type: 'wavenet' + backbone_args: + num_layers: 10 + num_channels: 192 + dilation_cycle_length: 4 lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/deployment/modules/toplevel.py b/deployment/modules/toplevel.py index 1dd4fe129..e358f25a0 100644 --- a/deployment/modules/toplevel.py +++ b/deployment/modules/toplevel.py @@ -31,12 +31,8 @@ def __init__(self, vocab_size, out_dims): num_feats=1, timesteps=hparams['timesteps'], k_step=hparams['K_step'], - backbone_type=hparams.get('backbone_type', hparams.get('diff_decoder_type')), - backbone_args={ - 'n_layers': hparams['residual_layers'], - 'n_chans': hparams['residual_channels'], - 'n_dilates': hparams['dilation_cycle_length'], - }, + backbone_type=self.backbone_type, + backbone_args=self.backbone_args, spec_min=hparams['spec_min'], spec_max=hparams['spec_max'] ) @@ -46,12 +42,8 @@ def __init__(self, vocab_size, out_dims): num_feats=1, t_start=hparams['T_start'], time_scale_factor=hparams['time_scale_factor'], - backbone_type=hparams.get('backbone_type', hparams.get('diff_decoder_type')), - 
backbone_args={ - 'n_layers': hparams['residual_layers'], - 'n_chans': hparams['residual_channels'], - 'n_dilates': hparams['dilation_cycle_length'], - }, + backbone_type=self.backbone_type, + backbone_args=self.backbone_args, spec_min=hparams['spec_min'], spec_max=hparams['spec_max'] ) @@ -155,12 +147,8 @@ def __init__(self, vocab_size): repeat_bins=pitch_hparams['repeat_bins'], timesteps=hparams['timesteps'], k_step=hparams['K_step'], - backbone_type=hparams.get('backbone_type', hparams.get('diff_decoder_type')), - backbone_args={ - 'n_layers': pitch_hparams['residual_layers'], - 'n_chans': pitch_hparams['residual_channels'], - 'n_dilates': pitch_hparams['dilation_cycle_length'], - } + backbone_type=self.pitch_backbone_type, + backbone_args=self.pitch_backbone_args ) elif self.diffusion_type == 'reflow': self.pitch_predictor = PitchRectifiedFlowONNX( @@ -170,12 +158,8 @@ def __init__(self, vocab_size): cmax=pitch_hparams['pitd_clip_max'], repeat_bins=pitch_hparams['repeat_bins'], time_scale_factor=hparams['time_scale_factor'], - backbone_type=hparams.get('backbone_type', hparams.get('diff_decoder_type')), - backbone_args={ - 'n_layers': pitch_hparams['residual_layers'], - 'n_chans': pitch_hparams['residual_channels'], - 'n_dilates': pitch_hparams['dilation_cycle_length'], - } + backbone_type=self.pitch_backbone_type, + backbone_args=self.pitch_backbone_args ) else: raise ValueError(f"Invalid diffusion type: {self.diffusion_type}") diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py index b4527f98b..ecf97a95c 100644 --- a/modules/backbones/LYNXNet.py +++ b/modules/backbones/LYNXNet.py @@ -54,7 +54,7 @@ def calc_same_padding(kernel_size): pad = kernel_size // 2 return (pad, pad - (kernel_size + 1) % 2) - def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activation='PReLU', dropout=0.): + def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.): super().__init__() inner_dim = dim * 
expansion_factor activation_classes = { @@ -88,11 +88,11 @@ def forward(self, x): class LYNXNetResidualLayer(nn.Module): - def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, in_norm=False, activation='PReLU', dropout=0.): + def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.): super().__init__() self.diffusion_projection = nn.Conv1d(dim, dim, 1) self.conditioner_projection = nn.Conv1d(dim_cond, dim, 1) - self.convmodule = LYNXConvModule(dim=dim, expansion_factor=expansion_factor, kernel_size=kernel_size, in_norm=in_norm, activation=activation, dropout=dropout) + self.convmodule = LYNXConvModule(dim=dim, expansion_factor=expansion_factor, kernel_size=kernel_size, activation=activation, dropout=dropout) def forward(self, x, conditioner, diffusion_step): res_x = x.transpose(1, 2) @@ -106,7 +106,7 @@ def forward(self, x, conditioner, diffusion_step): class LYNXNet(nn.Module): - def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in_norm=False, activation='PReLU', dropout=0.): + def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=2, kernel_size=31, activation='PReLU', dropout=0.): """ LYNXNet(Linear Gated Depthwise Separable Convolution Network) TIPS:You can control the style of the generated results by modifying the 'activation', @@ -115,29 +115,30 @@ def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in - 'ReLU' : Contrary to 'SiLU', Voice will be weakened """ super().__init__() - self.input_projection = nn.Conv1d(in_dims * n_feats, n_chans, 1) + self.in_dims = in_dims + self.n_feats = n_feats + self.input_projection = nn.Conv1d(in_dims * n_feats, num_channels, 1) self.diffusion_embedding = nn.Sequential( - SinusoidalPosEmb(n_chans), - nn.Linear(n_chans, n_chans * 4), + SinusoidalPosEmb(num_channels), + nn.Linear(num_channels, num_channels * 4), nn.GELU(), - nn.Linear(n_chans * 4, n_chans), + nn.Linear(num_channels * 
4, num_channels), ) self.residual_layers = nn.ModuleList( [ LYNXNetResidualLayer( dim_cond=hparams['hidden_size'], - dim=n_chans, - expansion_factor=n_dilates, - kernel_size=31, - in_norm=in_norm, + dim=num_channels, + expansion_factor=expansion_factor, + kernel_size=kernel_size, activation=activation, dropout=dropout ) - for i in range(n_layers) + for i in range(num_layers) ] ) - self.norm = nn.LayerNorm(n_chans) - self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1) + self.norm = nn.LayerNorm(num_channels) + self.output_projection = nn.Conv1d(num_channels, in_dims * n_feats, kernel_size=1) nn.init.zeros_(self.output_projection.weight) def forward(self, spec, diffusion_step, cond): @@ -148,14 +149,10 @@ def forward(self, spec, diffusion_step, cond): :return: """ - # To keep compatibility with DiffSVC, [B, 1, M, T] - x = spec - use_4_dim = False - if x.dim() == 4: - x = x[:, 0] - use_4_dim = True - - assert x.dim() == 3, f"mel must be 3 dim tensor, but got {x.dim()}" + if self.n_feats == 1: + x = spec.squeeze(1) # [B, M, T] + else: + x = spec.flatten(start_dim=1, end_dim=2) # [B, F x M, T] x = self.input_projection(x) # x [B, residual_channel, T] x = F.gelu(x) @@ -171,4 +168,11 @@ def forward(self, spec, diffusion_step, cond): # MLP and GLU x = self.output_projection(x) # [B, 128, T] - return x[:, None] if use_4_dim else x + if self.n_feats == 1: + x = x[:, None, :, :] + else: + # This is the temporary solution since PyTorch 1.13 + # does not support exporting aten::unflatten to ONNX + # x = x.unflatten(dim=1, sizes=(self.n_feats, self.in_dims)) + x = x.reshape(-1, self.n_feats, self.in_dims, x.shape[2]) + return x diff --git a/modules/backbones/__init__.py b/modules/backbones/__init__.py index a91578567..c9cf0b8d5 100644 --- a/modules/backbones/__init__.py +++ b/modules/backbones/__init__.py @@ -1,7 +1,17 @@ +import torch.nn from modules.backbones.wavenet import WaveNet from modules.backbones.LYNXNet import LYNXNet +from utils import 
filter_kwargs BACKBONES = { 'wavenet': WaveNet, 'lynxnet': LYNXNet } + +def build_backbone( + out_dims: int, num_feats: int, + backbone_type: str, backbone_args: dict +) -> torch.nn.Module: + backbone = BACKBONES[backbone_type] + kwargs = filter_kwargs(backbone_args, backbone) + return BACKBONES[backbone_type](out_dims, num_feats, **kwargs) \ No newline at end of file diff --git a/modules/backbones/wavenet.py b/modules/backbones/wavenet.py index 0a1400d30..3ddbb4689 100644 --- a/modules/backbones/wavenet.py +++ b/modules/backbones/wavenet.py @@ -63,27 +63,27 @@ def forward(self, x, conditioner, diffusion_step): class WaveNet(nn.Module): - def __init__(self, in_dims, n_feats, *, n_layers=20, n_chans=256, n_dilates=4): + def __init__(self, in_dims, n_feats, *, num_layers=20, num_channels=256, dilation_cycle_length=4): super().__init__() self.in_dims = in_dims self.n_feats = n_feats - self.input_projection = Conv1d(in_dims * n_feats, n_chans, 1) - self.diffusion_embedding = SinusoidalPosEmb(n_chans) + self.input_projection = Conv1d(in_dims * n_feats, num_channels, 1) + self.diffusion_embedding = SinusoidalPosEmb(num_channels) self.mlp = nn.Sequential( - nn.Linear(n_chans, n_chans * 4), + nn.Linear(num_channels, num_channels * 4), nn.Mish(), - nn.Linear(n_chans * 4, n_chans) + nn.Linear(num_channels * 4, num_channels) ) self.residual_layers = nn.ModuleList([ ResidualBlock( encoder_hidden=hparams['hidden_size'], - residual_channels=n_chans, - dilation=2 ** (i % n_dilates) + residual_channels=num_channels, + dilation=2 ** (i % dilation_cycle_length) ) - for i in range(n_layers) + for i in range(num_layers) ]) - self.skip_projection = Conv1d(n_chans, n_chans, 1) - self.output_projection = Conv1d(n_chans, in_dims * n_feats, 1) + self.skip_projection = Conv1d(num_channels, num_channels, 1) + self.output_projection = Conv1d(num_channels, in_dims * n_feats, 1) nn.init.zeros_(self.output_projection.weight) def forward(self, spec, diffusion_step, cond): diff --git 
a/modules/core/ddpm.py b/modules/core/ddpm.py index d79f21c79..6b0ae4803 100644 --- a/modules/core/ddpm.py +++ b/modules/core/ddpm.py @@ -9,7 +9,7 @@ from torch import nn from tqdm import tqdm -from modules.backbones import BACKBONES +from modules.backbones import build_backbone from utils.hparams import hparams @@ -57,7 +57,7 @@ def __init__(self, out_dims, num_feats=1, timesteps=1000, k_step=1000, backbone_type=None, backbone_args=None, betas=None, spec_min=None, spec_max=None): super().__init__() - self.denoise_fn: nn.Module = BACKBONES[backbone_type](out_dims, num_feats, **backbone_args) + self.denoise_fn: nn.Module = build_backbone(out_dims, num_feats, backbone_type, backbone_args) self.out_dims = out_dims self.num_feats = num_feats diff --git a/modules/core/reflow.py b/modules/core/reflow.py index 2a2b21fcb..f09eb2392 100644 --- a/modules/core/reflow.py +++ b/modules/core/reflow.py @@ -6,7 +6,7 @@ import torch.nn as nn from tqdm import tqdm -from modules.backbones import BACKBONES +from modules.backbones import build_backbone from utils.hparams import hparams @@ -15,7 +15,7 @@ def __init__(self, out_dims, num_feats=1, t_start=0., time_scale_factor=1000, backbone_type=None, backbone_args=None, spec_min=None, spec_max=None): super().__init__() - self.velocity_fn: nn.Module = BACKBONES[backbone_type](out_dims, num_feats, **backbone_args) + self.velocity_fn: nn.Module = build_backbone(out_dims, num_feats, backbone_type, backbone_args) self.out_dims = out_dims self.num_feats = num_feats self.use_shallow_diffusion = hparams.get('use_shallow_diffusion', False) diff --git a/modules/fastspeech/param_adaptor.py b/modules/fastspeech/param_adaptor.py index e5668536b..ace58ff41 100644 --- a/modules/fastspeech/param_adaptor.py +++ b/modules/fastspeech/param_adaptor.py @@ -68,6 +68,14 @@ def build_adaptor(self, cls=MultiVarianceDiffusion): f'Total number of repeat bins must be divisible by number of ' \ f'variance parameters ({len(self.variance_prediction_list)}).' 
repeat_bins = total_repeat_bins // len(self.variance_prediction_list) + backbone_type = variances_hparams.get('backbone_type', + hparams.get('backbone_type', + hparams.get('diff_decoder_type', 'wavenet'))) + backbone_args = variances_hparams.get('backbone_args', { + 'num_layers': variances_hparams.get('residual_layers'), + 'num_channels': variances_hparams.get('residual_channels'), + 'dilation_cycle_length': variances_hparams.get('dilation_cycle_length'), + } if backbone_type == 'wavenet' else None) kwargs = filter_kwargs( { 'ranges': ranges, @@ -75,12 +83,8 @@ def build_adaptor(self, cls=MultiVarianceDiffusion): 'repeat_bins': repeat_bins, 'timesteps': hparams.get('timesteps'), 'time_scale_factor': hparams.get('time_scale_factor'), - 'backbone_type': hparams.get('backbone_type', hparams.get('diff_decoder_type')), - 'backbone_args': { - 'n_layers': variances_hparams['residual_layers'], - 'n_chans': variances_hparams['residual_channels'], - 'n_dilates': variances_hparams['dilation_cycle_length'], - } + 'backbone_type': backbone_type, + 'backbone_args': backbone_args }, cls ) diff --git a/modules/toplevel.py b/modules/toplevel.py index 1976d09a9..9ec16bec1 100644 --- a/modules/toplevel.py +++ b/modules/toplevel.py @@ -53,18 +53,20 @@ def __init__(self, vocab_size, out_dims): aux_decoder_args=self.shallow_args['aux_decoder_args'] ) self.diffusion_type = hparams.get('diffusion_type', 'ddpm') + self.backbone_type = hparams.get('backbone_type', hparams.get('diff_decoder_type', 'wavenet')) + self.backbone_args = hparams.get('backbone_args', { + 'num_layers': hparams.get('residual_layers'), + 'num_channels': hparams.get('residual_channels'), + 'dilation_cycle_length': hparams.get('dilation_cycle_length'), + } if self.backbone_type == 'wavenet' else None) if self.diffusion_type == 'ddpm': self.diffusion = GaussianDiffusion( out_dims=out_dims, num_feats=1, timesteps=hparams['timesteps'], k_step=hparams['K_step'], -
backbone_type=hparams.get('backbone_type', hparams.get('diff_decoder_type')), - backbone_args={ - 'n_layers': hparams['residual_layers'], - 'n_chans': hparams['residual_channels'], - 'n_dilates': hparams['dilation_cycle_length'], - }, + backbone_type=self.backbone_type, + backbone_args=self.backbone_args, spec_min=hparams['spec_min'], spec_max=hparams['spec_max'] ) @@ -74,12 +76,8 @@ def __init__(self, vocab_size, out_dims): num_feats=1, t_start=hparams['T_start'], time_scale_factor=hparams['time_scale_factor'], - backbone_type=hparams.get('backbone_type', hparams.get('diff_decoder_type')), - backbone_args={ - 'n_layers': hparams['residual_layers'], - 'n_chans': hparams['residual_channels'], - 'n_dilates': hparams['dilation_cycle_length'], - }, + backbone_type=self.backbone_type, + backbone_args=self.backbone_args, spec_min=hparams['spec_min'], spec_max=hparams['spec_max'] ) @@ -157,7 +155,14 @@ def __init__(self, vocab_size): self.pitch_retake_embed = Embedding(2, hparams['hidden_size']) pitch_hparams = hparams['pitch_prediction_args'] - + self.pitch_backbone_type = pitch_hparams.get('backbone_type', + hparams.get('backbone_type', + hparams.get('diff_decoder_type', 'wavenet'))) + self.pitch_backbone_args = pitch_hparams.get('backbone_args', { + 'num_layers': pitch_hparams.get('residual_layers'), + 'num_channels': pitch_hparams.get('residual_channels'), + 'dilation_cycle_length': pitch_hparams.get('dilation_cycle_length'), + } if self.pitch_backbone_type == 'wavenet' else None) if self.diffusion_type == 'ddpm': self.pitch_predictor = PitchDiffusion( vmin=pitch_hparams['pitd_norm_min'], @@ -167,12 +172,8 @@ def __init__(self, vocab_size): repeat_bins=pitch_hparams['repeat_bins'], timesteps=hparams['timesteps'], k_step=hparams['K_step'], - backbone_type=hparams.get('backbone_type', hparams.get('diff_decoder_type')), - backbone_args={ - 'n_layers': pitch_hparams['residual_layers'], - 'n_chans': pitch_hparams['residual_channels'], - 'n_dilates': 
pitch_hparams['dilation_cycle_length'], - } + backbone_type=self.pitch_backbone_type, + backbone_args=self.pitch_backbone_args ) elif self.diffusion_type == 'reflow': self.pitch_predictor = PitchRectifiedFlow( @@ -182,12 +183,8 @@ def __init__(self, vocab_size): cmax=pitch_hparams['pitd_clip_max'], repeat_bins=pitch_hparams['repeat_bins'], time_scale_factor=hparams['time_scale_factor'], - backbone_type=hparams.get('backbone_type', hparams.get('diff_decoder_type')), - backbone_args={ - 'n_layers': pitch_hparams['residual_layers'], - 'n_chans': pitch_hparams['residual_channels'], - 'n_dilates': pitch_hparams['dilation_cycle_length'], - } + backbone_type=self.pitch_backbone_type, + backbone_args=self.pitch_backbone_args ) else: raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")