From a19a2eb05eae12f3f1120a81c19b46217c705b7b Mon Sep 17 00:00:00 2001 From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com> Date: Fri, 13 Dec 2024 16:29:57 +0800 Subject: [PATCH 1/4] Change the injection method of conditions on lynxnet (#225) --- configs/templates/config_acoustic.yaml | 2 +- configs/templates/config_variance.yaml | 2 ++ modules/backbones/lynxnet.py | 25 ++++++++++++++++--------- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index a9453a368..718339f3e 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -63,7 +63,7 @@ backbone_args: num_channels: 1024 num_layers: 6 kernel_size: 31 - dropout_rate: 0.0 + dropout_rate: 0.1 #backbone_type: 'wavenet' #backbone_args: # num_channels: 512 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index daa8e15dc..840ae88a5 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -87,6 +87,7 @@ pitch_prediction_args: # backbone_args: # num_layers: 6 # num_channels: 512 +# dropout_rate: 0.1 variances_prediction_args: total_repeat_bins: 48 @@ -99,6 +100,7 @@ variances_prediction_args: # backbone_args: # num_layers: 6 # num_channels: 384 +# dropout_rate: 0.1 lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 744967c6b..120e1d900 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -10,6 +10,12 @@ from utils.hparams import hparams +class Conv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + nn.init.kaiming_normal_(self.weight) + + class SwiGLU(nn.Module): # Swish-Applies the gated linear unit function. def __init__(self, dim=-1): @@ -39,7 +45,7 @@ def calc_same_padding(kernel_size): pad = kernel_size // 2 return pad, pad - (kernel_size + 1) % 2 - def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.): + def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.1): super().__init__() inner_dim = dim * expansion_factor activation_classes = { @@ -57,7 +63,7 @@ def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dr else: _dropout = nn.Identity() self.net = nn.Sequential( - nn.LayerNorm(dim), + nn.LayerNorm(dim, eps=1e-6), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), SwiGLU(dim=1), @@ -73,7 +79,7 @@ def forward(self, x): class LYNXNetResidualLayer(nn.Module): - def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.): + def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.1): super().__init__() self.diffusion_projection = nn.Conv1d(dim, dim, 1) self.conditioner_projection = nn.Conv1d(dim_cond, dim, 1) @@ -81,8 +87,9 @@ def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation=' activation=activation, dropout=dropout) def forward(self, x, conditioner, diffusion_step): + x = x + self.conditioner_projection(conditioner) res_x = x.transpose(1, 2) - x = x + self.diffusion_projection(diffusion_step) + self.conditioner_projection(conditioner) + x = x + self.diffusion_projection(diffusion_step) x = x.transpose(1, 2) x = self.convmodule(x) # (#batch, dim, length) x = x + res_x @@ -93,7 +100,7 @@ def forward(self, x, conditioner, diffusion_step): class LYNXNet(nn.Module): def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=2, kernel_size=31, - activation='PReLU', dropout=0.): + activation='PReLU', dropout=0.1): """ LYNXNet(Linear Gated Depthwise Separable Convolution Network) TIPS:You can control the style of the generated results by modifying the 'activation', @@ -104,7 +111,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio super().__init__() self.in_dims = in_dims self.n_feats = n_feats - self.input_projection = nn.Conv1d(in_dims * n_feats, num_channels, 1) + self.input_projection = Conv1d(in_dims * n_feats, num_channels, 1) self.diffusion_embedding = nn.Sequential( SinusoidalPosEmb(num_channels), nn.Linear(num_channels, num_channels * 4), @@ -124,8 +131,8 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio for i in range(num_layers) ] ) - self.norm = nn.LayerNorm(num_channels) - self.output_projection = nn.Conv1d(num_channels, in_dims * n_feats, kernel_size=1) + self.norm = nn.LayerNorm(num_channels, eps=1e-6) + self.output_projection = Conv1d(num_channels, in_dims * n_feats, kernel_size=1) nn.init.zeros_(self.output_projection.weight) def forward(self, spec, diffusion_step, cond): @@ -142,7 +149,7 @@ def forward(self, spec, diffusion_step, cond): x = spec.flatten(start_dim=1, end_dim=2) # [B, F x M, T] x = self.input_projection(x) # x [B, residual_channel, T] - x = F.gelu(x) + # x = F.gelu(x) diffusion_step = self.diffusion_embedding(diffusion_step).unsqueeze(-1) From 50bab77df356748a26d6473c56260d688a411e29 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Fri, 13 Dec 2024 17:31:18 +0800 Subject: [PATCH 2/4] update configurations for new-lynxnet --- configs/acoustic.yaml | 3 ++- configs/templates/config_acoustic.yaml | 1 + configs/templates/config_variance.yaml | 1 + modules/backbones/lynxnet.py | 24 ++++++++++++++---------- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 2cbc45303..4658ba366 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -78,7 +78,8 @@ backbone_args: num_channels: 1024 num_layers: 6 kernel_size: 31 - dropout_rate: 0.0 + dropout_rate: 0.1 + strong_cond: true main_loss_type: l2 main_loss_log_norm: false schedule_type: 'linear' diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 718339f3e..1fddd5f93 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -64,6 +64,7 @@ backbone_args: num_layers: 6 kernel_size: 31 dropout_rate: 0.1 + strong_cond: true #backbone_type: 'wavenet' #backbone_args: # num_channels: 512 diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 840ae88a5..4d8df4fd1 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -101,6 +101,7 @@ variances_prediction_args: # num_layers: 6 # num_channels: 384 # dropout_rate: 0.1 +# strong_cond: true lambda_dur_loss: 1.0 lambda_pitch_loss: 1.0 diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 120e1d900..25267235f 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -86,21 +86,23 @@ def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation=' self.convmodule = LYNXConvModule(dim=dim, expansion_factor=expansion_factor, kernel_size=kernel_size, activation=activation, dropout=dropout) - def forward(self, x, conditioner, diffusion_step): - x = x + self.conditioner_projection(conditioner) - res_x = x.transpose(1, 2) + def forward(self, x, conditioner, diffusion_step, front_cond_inject=False): + if front_cond_inject: + x = x + self.conditioner_projection(conditioner) + res_x = x + else: + res_x = x + x = x + self.conditioner_projection(conditioner) x = x + self.diffusion_projection(diffusion_step) x = x.transpose(1, 2) - x = self.convmodule(x) # (#batch, dim, length) - x = x + res_x - x = x.transpose(1, 2) - + x = self.convmodule(x) # (#batch, dim, length) + x = x.transpose(1, 2) + res_x return x # (#batch, length, dim) class LYNXNet(nn.Module): def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=2, kernel_size=31, - activation='PReLU', dropout=0.1): + activation='PReLU', dropout=0.1, strong_cond=False): """ LYNXNet(Linear Gated Depthwise Separable Convolution Network) TIPS:You can control the style of the generated results by modifying the 'activation', @@ -133,6 +135,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio ) self.norm = nn.LayerNorm(num_channels, eps=1e-6) self.output_projection = Conv1d(num_channels, in_dims * n_feats, kernel_size=1) + self.strong_cond = strong_cond nn.init.zeros_(self.output_projection.weight) def forward(self, spec, diffusion_step, cond): @@ -149,12 +152,13 @@ def forward(self, spec, diffusion_step, cond): x = spec.flatten(start_dim=1, end_dim=2) # [B, F x M, T] x = self.input_projection(x) # x [B, residual_channel, T] - # x = F.gelu(x) + if not self.strong_cond: + x = F.gelu(x) diffusion_step = self.diffusion_embedding(diffusion_step).unsqueeze(-1) for layer in self.residual_layers: - x = layer(x, cond, diffusion_step) + x = layer(x, cond, diffusion_step, front_cond_inject=self.strong_cond) # post-norm x = self.norm(x.transpose(1, 2)).transpose(1, 2) From 0c844ec1beba4c34eeb5cb82f369b63fa209bff9 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Fri, 13 Dec 2024 19:36:12 +0800 Subject: [PATCH 3/4] update configurations for new-lynxnet --- configs/templates/config_variance.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 4d8df4fd1..09a8bbd80 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -88,6 +88,7 @@ pitch_prediction_args: # num_layers: 6 # num_channels: 512 # dropout_rate: 0.1 +# strong_cond: true variances_prediction_args: total_repeat_bins: 48 From 74ab9e4ffb046b4060508c0479c7602a499bdf96 Mon Sep 17 00:00:00 2001 From: yxlllc Date: Sat, 14 Dec 2024 01:44:04 +0800 Subject: [PATCH 4/4] update configurations for new-lynxnet --- configs/acoustic.yaml | 2 +- configs/templates/config_acoustic.yaml | 2 +- configs/templates/config_variance.yaml | 4 ++-- modules/backbones/lynxnet.py | 10 +++++----- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml index 4658ba366..6e0e6044c 100644 --- a/configs/acoustic.yaml +++ b/configs/acoustic.yaml @@ -78,7 +78,7 @@ backbone_args: num_channels: 1024 num_layers: 6 kernel_size: 31 - dropout_rate: 0.1 + dropout_rate: 0.0 strong_cond: true main_loss_type: l2 main_loss_log_norm: false diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml index 1fddd5f93..9be272fe0 100644 --- a/configs/templates/config_acoustic.yaml +++ b/configs/templates/config_acoustic.yaml @@ -63,7 +63,7 @@ backbone_args: num_channels: 1024 num_layers: 6 kernel_size: 31 - dropout_rate: 0.1 + dropout_rate: 0.0 strong_cond: true #backbone_type: 'wavenet' #backbone_args: diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml index 09a8bbd80..4a56ac865 100644 --- a/configs/templates/config_variance.yaml +++ b/configs/templates/config_variance.yaml @@ -87,7 +87,7 @@ pitch_prediction_args: # backbone_args: # num_layers: 6 # num_channels: 512 -# dropout_rate: 0.1 +# dropout_rate: 0.0 # strong_cond: true variances_prediction_args: @@ -101,7 +101,7 @@ variances_prediction_args: # backbone_args: # num_layers: 6 # num_channels: 384 -# dropout_rate: 0.1 +# dropout_rate: 0.0 # strong_cond: true lambda_dur_loss: 1.0 diff --git a/modules/backbones/lynxnet.py b/modules/backbones/lynxnet.py index 25267235f..18e7bf497 100644 --- a/modules/backbones/lynxnet.py +++ b/modules/backbones/lynxnet.py @@ -45,7 +45,7 @@ def calc_same_padding(kernel_size): pad = kernel_size // 2 return pad, pad - (kernel_size + 1) % 2 - def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.1): + def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.0): super().__init__() inner_dim = dim * expansion_factor activation_classes = { @@ -63,7 +63,7 @@ def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dr else: _dropout = nn.Identity() self.net = nn.Sequential( - nn.LayerNorm(dim, eps=1e-6), + nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), SwiGLU(dim=1), @@ -79,7 +79,7 @@ def forward(self, x): class LYNXNetResidualLayer(nn.Module): - def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.1): + def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.0): super().__init__() self.diffusion_projection = nn.Conv1d(dim, dim, 1) self.conditioner_projection = nn.Conv1d(dim_cond, dim, 1) @@ -102,7 +102,7 @@ def forward(self, x, conditioner, diffusion_step, front_cond_inject=False): class LYNXNet(nn.Module): def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=2, kernel_size=31, - activation='PReLU', dropout=0.1, strong_cond=False): + activation='PReLU', dropout=0.0, strong_cond=False): """ LYNXNet(Linear Gated Depthwise Separable Convolution Network) TIPS:You can control the style of the generated results by modifying the 'activation', @@ -133,7 +133,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio for i in range(num_layers) ] ) - self.norm = nn.LayerNorm(num_channels, eps=1e-6) + self.norm = nn.LayerNorm(num_channels) self.output_projection = Conv1d(num_channels, in_dims * n_feats, kernel_size=1) self.strong_cond = strong_cond nn.init.zeros_(self.output_projection.weight)