From 5f51149de43eadf48eb31ae186d285e419002580 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Wed, 4 Sep 2024 19:59:34 +0800
Subject: [PATCH 1/9] post-norm

---
 modules/backbones/LYNXNet.py | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py
index ffc5a94e5..b69aa220d 100644
--- a/modules/backbones/LYNXNet.py
+++ b/modules/backbones/LYNXNet.py
@@ -57,7 +57,6 @@ def calc_same_padding(kernel_size):
     def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activation='PReLU', dropout=0.):
         super().__init__()
         inner_dim = dim * expansion_factor
-        _normalize = nn.LayerNorm(dim) if in_norm or dim > 512 else nn.Identity()
         activation_classes = {
             'SiLU': nn.SiLU,
             'ReLU': nn.ReLU,
@@ -73,7 +72,7 @@ def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activat
         else:
             _dropout = nn.Identity()
         self.net = nn.Sequential(
-            _normalize,
+            nn.LayerNorm(dim),
             Transpose((1, 2)),
             nn.Conv1d(dim, inner_dim * 2, 1),
             SwiGLU(dim=1),
@@ -137,9 +136,17 @@ def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in
                 for i in range(n_layers)
             ]
         )
-        self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1)
-        nn.init.zeros_(self.output_projection.weight)
-
+        self.norm = nn.LayerNorm(dim)
+        # self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1)
+        # nn.init.zeros_(self.output_projection.weight)
+        _ = nn.Conv1d(dim * mlp_factor, mel_channels, kernel_size=1)
+        nn.init.zeros_(_.weight)
+        self.output_projection = nn.Sequential(
+            nn.Conv1d(dim, dim * mlp_factor, kernel_size=1),
+            nn.GELU(),
+            _,
+        )
+
     def forward(self, spec, diffusion_step, cond):
         """
         :param spec: [B, F, M, T]
@@ -164,6 +171,9 @@ def forward(self, spec, diffusion_step, cond):

         for layer in self.residual_layers:
             x = layer(x, cond, diffusion_step)
+
+        # post-norm
+        x = self.norm(x)

         # MLP and GLU
         x = self.output_projection(x)  # [B, 128, T]

From cbcc018ed1cb2bb3a77f72abf1d709d837ad7fb3 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Wed, 4 Sep 2024 20:04:41 +0800
Subject: [PATCH 2/9] fix

---
 modules/backbones/LYNXNet.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py
index b69aa220d..47cc993f5 100644
--- a/modules/backbones/LYNXNet.py
+++ b/modules/backbones/LYNXNet.py
@@ -136,13 +136,13 @@ def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in
                 for i in range(n_layers)
             ]
         )
-        self.norm = nn.LayerNorm(dim)
+        self.norm = nn.LayerNorm(n_chans)
         # self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1)
         # nn.init.zeros_(self.output_projection.weight)
-        _ = nn.Conv1d(dim * mlp_factor, mel_channels, kernel_size=1)
+        _ = nn.Conv1d(n_chans * 4, in_dims * n_feats, kernel_size=1)
         nn.init.zeros_(_.weight)
         self.output_projection = nn.Sequential(
-            nn.Conv1d(dim, dim * mlp_factor, kernel_size=1),
+            nn.Conv1d(n_chans, n_chans * 4, kernel_size=1),
             nn.GELU(),
             _,
         )
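
Note: patches 1-2 zero-initialize only the weight of the final 1x1 convolution, mirroring what the replaced plain-Conv1d head already did. At initialization the MLP head's output therefore reduces to that convolution's default-initialized bias, constant over batch and time. A minimal standalone sketch of the head as it stands after patch 2 (the sizes n_chans=512, in_dims=128, n_feats=1 are illustrative assumptions, and `last` stands in for the `_` placeholder used in the patch):

import torch
import torch.nn as nn

n_chans, in_dims, n_feats = 512, 128, 1  # illustrative, not repo values

# Zero only the last conv's weight; the hidden GELU layer keeps its
# default init, so its parameters still receive useful gradients.
last = nn.Conv1d(n_chans * 4, in_dims * n_feats, kernel_size=1)
nn.init.zeros_(last.weight)
output_projection = nn.Sequential(
    nn.Conv1d(n_chans, n_chans * 4, kernel_size=1),
    nn.GELU(),
    last,
)

y = output_projection(torch.randn(2, n_chans, 100))  # [B, C, T] -> [B, 128, T]
# With the weight zeroed, y is just last.bias broadcast over batch and time:
assert torch.allclose(y, y[..., :1].expand_as(y))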
From 69d33666f5f669cfd3b176fa0dfdd03f3fa8b29e Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Wed, 4 Sep 2024 22:30:15 +0800
Subject: [PATCH 3/9] add norm+mlp

---
 modules/aux_decoder/LYNXNetDecoder.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/modules/aux_decoder/LYNXNetDecoder.py b/modules/aux_decoder/LYNXNetDecoder.py
index 42b552fbf..102441e38 100644
--- a/modules/aux_decoder/LYNXNetDecoder.py
+++ b/modules/aux_decoder/LYNXNetDecoder.py
@@ -49,7 +49,13 @@ def __init__(
                 activation='SiLU',
                 dropout=dropout_rate)
             for _ in range(num_layers)
         )
-        self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1)
+        self.norm = nn.LayerNorm(num_channels)
+        # self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1)
+        self.output_projection = nn.Sequential(
+            nn.Conv1d(num_channels, num_channels * 4, kernel_size=1),
+            nn.GELU(),
+            nn.Conv1d(num_channels * 4, out_dims, kernel_size=1),
+        )

     def forward(self, x, infer=False):
         """
@@ -64,7 +70,8 @@ def forward(self, x, infer=False):
         x = x.transpose(1, 2)
         for layer in self.encoder_layers:
             x = layer(x)
         x = x.transpose(1, 2)
+        x = self.norm(x)
         x = self.output_projection(x)
         x = x.transpose(1, 2)

-        return x
\ No newline at end of file
+        return x

From 63e4b5c26ce922e5d1f0f7cecd4cf09687101f79 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Wed, 4 Sep 2024 23:01:32 +0800
Subject: [PATCH 4/9] Update LYNXNet.py

---
 modules/backbones/LYNXNet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py
index 47cc993f5..eaf7559d2 100644
--- a/modules/backbones/LYNXNet.py
+++ b/modules/backbones/LYNXNet.py
@@ -173,7 +173,7 @@ def forward(self, spec, diffusion_step, cond):
             x = layer(x, cond, diffusion_step)

         # post-norm
-        x = self.norm(x)
+        x = self.norm(x.transpose(1, 2)).transpose(1, 2)

         # MLP and GLU
         x = self.output_projection(x)  # [B, 128, T]

From 4d5d6b399c1d21c5205ed875fa174bd7b102b152 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Wed, 4 Sep 2024 23:02:01 +0800
Subject: [PATCH 5/9] Update LYNXNetDecoder.py

---
 modules/aux_decoder/LYNXNetDecoder.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/modules/aux_decoder/LYNXNetDecoder.py b/modules/aux_decoder/LYNXNetDecoder.py
index 102441e38..23104d055 100644
--- a/modules/aux_decoder/LYNXNetDecoder.py
+++ b/modules/aux_decoder/LYNXNetDecoder.py
@@ -69,8 +69,9 @@ def forward(self, x, infer=False):
         x = x.transpose(1, 2)
         for layer in self.encoder_layers:
             x = layer(x)
-        x = x.transpose(1, 2)
         x = self.norm(x)
+        x = x.transpose(1, 2)
+
         x = self.output_projection(x)
         x = x.transpose(1, 2)
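
Note: nn.LayerNorm normalizes the last dimension, while Conv1d feature maps are laid out as [B, C, T], so applying self.norm directly to a conv output normalizes the wrong axis and raises whenever T differs from the normalized shape. Patch 4 fixes this in the backbone by wrapping the norm in a transpose pair; patch 5 fixes the decoder by moving the norm to the channels-last side of the existing transpose. A minimal sketch of the transpose-wrapped pattern (sizes illustrative):

import torch
import torch.nn as nn

norm = nn.LayerNorm(512)        # normalizes over the last dimension
x = torch.randn(2, 512, 100)    # Conv1d layout: [B, C, T]

# norm(x) would raise here (last dim is 100, not 512). Putting channels
# last normalizes over C, then the layout is restored to [B, C, T].
y = norm(x.transpose(1, 2)).transpose(1, 2)
assert y.shape == x.shape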
From 02210966000abbff0a8c10adf527d2e30f75f6a7 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Thu, 5 Sep 2024 21:42:52 +0800
Subject: [PATCH 6/9] do not need mlp

---
 modules/aux_decoder/LYNXNetDecoder.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/modules/aux_decoder/LYNXNetDecoder.py b/modules/aux_decoder/LYNXNetDecoder.py
index 23104d055..c5ce6e1fa 100644
--- a/modules/aux_decoder/LYNXNetDecoder.py
+++ b/modules/aux_decoder/LYNXNetDecoder.py
@@ -50,12 +50,12 @@ def __init__(
                 dropout=dropout_rate)
             for _ in range(num_layers)
         )
         self.norm = nn.LayerNorm(num_channels)
-        # self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1)
-        self.output_projection = nn.Sequential(
-            nn.Conv1d(num_channels, num_channels * 4, kernel_size=1),
-            nn.GELU(),
-            nn.Conv1d(num_channels * 4, out_dims, kernel_size=1),
-        )
+        self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1)
+        # self.output_projection = nn.Sequential(
+        #     nn.Conv1d(num_channels, num_channels * 4, kernel_size=1),
+        #     nn.GELU(),
+        #     nn.Conv1d(num_channels * 4, out_dims, kernel_size=1),
+        # )

     def forward(self, x, infer=False):
         """

From c606ce738eacbe71ce468e5285ee2dd71d2b8235 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Thu, 5 Sep 2024 21:43:44 +0800
Subject: [PATCH 7/9] do not need mlp

---
 modules/backbones/LYNXNet.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py
index eaf7559d2..1a5ecd15e 100644
--- a/modules/backbones/LYNXNet.py
+++ b/modules/backbones/LYNXNet.py
@@ -137,15 +137,15 @@ def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in
             ]
         )
         self.norm = nn.LayerNorm(n_chans)
-        # self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1)
-        # nn.init.zeros_(self.output_projection.weight)
-        _ = nn.Conv1d(n_chans * 4, in_dims * n_feats, kernel_size=1)
-        nn.init.zeros_(_.weight)
-        self.output_projection = nn.Sequential(
-            nn.Conv1d(n_chans, n_chans * 4, kernel_size=1),
-            nn.GELU(),
-            _,
-        )
+        self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1)
+        nn.init.zeros_(self.output_projection.weight)
+        # _ = nn.Conv1d(n_chans * 4, in_dims * n_feats, kernel_size=1)
+        # nn.init.zeros_(_.weight)
+        # self.output_projection = nn.Sequential(
+        #     nn.Conv1d(n_chans, n_chans * 4, kernel_size=1),
+        #     nn.GELU(),
+        #     _,
+        # )

     def forward(self, spec, diffusion_step, cond):
         """

From 537b240575f876eba29977ac91b5dec716784b75 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Fri, 6 Sep 2024 19:34:27 +0800
Subject: [PATCH 8/9] Add out norm for LYNXNET

---
 modules/backbones/LYNXNet.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py
index 1a5ecd15e..b4527f98b 100644
--- a/modules/backbones/LYNXNet.py
+++ b/modules/backbones/LYNXNet.py
@@ -139,13 +139,6 @@ def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in
         self.norm = nn.LayerNorm(n_chans)
         self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1)
         nn.init.zeros_(self.output_projection.weight)
-        # _ = nn.Conv1d(n_chans * 4, in_dims * n_feats, kernel_size=1)
-        # nn.init.zeros_(_.weight)
-        # self.output_projection = nn.Sequential(
-        #     nn.Conv1d(n_chans, n_chans * 4, kernel_size=1),
-        #     nn.GELU(),
-        #     _,
-        # )

     def forward(self, spec, diffusion_step, cond):
         """

From 0a64a0e75b688226ed5c28fd8304dc29dec52659 Mon Sep 17 00:00:00 2001
From: KakaruHayate <97896816+KakaruHayate@users.noreply.github.com>
Date: Fri, 6 Sep 2024 19:34:48 +0800
Subject: [PATCH 9/9] Add out norm for LYNXNETDecoder

---
 modules/aux_decoder/LYNXNetDecoder.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/modules/aux_decoder/LYNXNetDecoder.py b/modules/aux_decoder/LYNXNetDecoder.py
index c5ce6e1fa..4ac5923ee 100644
--- a/modules/aux_decoder/LYNXNetDecoder.py
+++ b/modules/aux_decoder/LYNXNetDecoder.py
@@ -51,11 +51,6 @@ def __init__(
         )
         self.norm = nn.LayerNorm(num_channels)
         self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1)
-        # self.output_projection = nn.Sequential(
-        #     nn.Conv1d(num_channels, num_channels * 4, kernel_size=1),
-        #     nn.GELU(),
-        #     nn.Conv1d(num_channels * 4, out_dims, kernel_size=1),
-        # )

     def forward(self, x, infer=False):
         """
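
Note: with the MLP experiment backed out in patches 6-9, the series nets out to a single change in both modules: a post-norm LayerNorm between the residual/encoder stack and the original 1x1 output projection. A standalone sketch of the resulting backbone head (same illustrative sizes as above; the `head` helper is hypothetical and only groups the two calls from forward):

import torch
import torch.nn as nn

n_chans, in_dims, n_feats = 512, 128, 1  # illustrative, not repo values

norm = nn.LayerNorm(n_chans)
output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1)
nn.init.zeros_(output_projection.weight)

def head(x):
    # x: [B, n_chans, T], the output of the residual layers
    x = norm(x.transpose(1, 2)).transpose(1, 2)  # post-norm over channels
    return output_projection(x)                  # [B, in_dims * n_feats, T]

print(head(torch.randn(2, n_chans, 100)).shape)  # torch.Size([2, 128, 100])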