diff --git a/modules/aux_decoder/LYNXNetDecoder.py b/modules/aux_decoder/LYNXNetDecoder.py
index 42b552fbf..4ac5923ee 100644
--- a/modules/aux_decoder/LYNXNetDecoder.py
+++ b/modules/aux_decoder/LYNXNetDecoder.py
@@ -49,6 +49,7 @@ def __init__(
                 activation='SiLU', dropout=dropout_rate) for _ in range(num_layers)
         )
 
+        self.norm = nn.LayerNorm(num_channels)
         self.output_projection = nn.Conv1d(num_channels, out_dims, kernel_size=1)
 
     def forward(self, x, infer=False):
@@ -63,8 +64,10 @@ def forward(self, x, infer=False):
 
         x = x.transpose(1, 2)
         for layer in self.encoder_layers:
             x = layer(x)
+        x = self.norm(x)
         x = x.transpose(1, 2)
+
         x = self.output_projection(x)
         x = x.transpose(1, 2)
-        return x
\ No newline at end of file
+        return x
diff --git a/modules/backbones/LYNXNet.py b/modules/backbones/LYNXNet.py
index ffc5a94e5..b4527f98b 100644
--- a/modules/backbones/LYNXNet.py
+++ b/modules/backbones/LYNXNet.py
@@ -57,7 +57,6 @@ def calc_same_padding(kernel_size):
     def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activation='PReLU', dropout=0.):
         super().__init__()
         inner_dim = dim * expansion_factor
-        _normalize = nn.LayerNorm(dim) if in_norm or dim > 512 else nn.Identity()
         activation_classes = {
             'SiLU': nn.SiLU,
             'ReLU': nn.ReLU,
@@ -73,7 +72,7 @@ def __init__(self, dim, expansion_factor, kernel_size=31, in_norm=False, activat
         else:
             _dropout = nn.Identity()
         self.net = nn.Sequential(
-            _normalize,
+            nn.LayerNorm(dim),
             Transpose((1, 2)),
             nn.Conv1d(dim, inner_dim * 2, 1),
             SwiGLU(dim=1),
@@ -137,9 +136,10 @@ def __init__(self, in_dims, n_feats, *, n_layers=6, n_chans=512, n_dilates=2, in
                 for i in range(n_layers)
             ]
         )
+        self.norm = nn.LayerNorm(n_chans)
         self.output_projection = nn.Conv1d(n_chans, in_dims * n_feats, kernel_size=1)
         nn.init.zeros_(self.output_projection.weight)
-    
+
     def forward(self, spec, diffusion_step, cond):
         """
         :param spec: [B, F, M, T]
@@ -164,6 +164,9 @@ def forward(self, spec, diffusion_step, cond):
 
         for layer in self.residual_layers:
             x = layer(x, cond, diffusion_step)
+
+        # post-norm
+        x = self.norm(x.transpose(1, 2)).transpose(1, 2)
 
         # MLP and GLU
         x = self.output_projection(x)  # [B, 128, T]
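
For context on the shared pattern here: `nn.LayerNorm(C)` normalizes only the last dimension, which is why both files carry the `[B, C, T]` feature map into `[B, T, C]` layout (via `transpose(1, 2)`) before the norm and back afterwards. Below is a minimal, self-contained sketch of the post-norm-plus-1x1-projection pattern this patch introduces; the module name `PostNormProjection` and the toy shapes are illustrative, not part of the repo.

```python
import torch
from torch import nn


class PostNormProjection(nn.Module):
    """Illustrative sketch: LayerNorm applied after the layer stack,
    just before the 1x1 output projection, as in this diff."""

    def __init__(self, n_chans: int, out_dims: int):
        super().__init__()
        # LayerNorm normalizes the last dim, so [B, C, T] input
        # must be transposed to [B, T, C] before applying it.
        self.norm = nn.LayerNorm(n_chans)
        self.output_projection = nn.Conv1d(n_chans, out_dims, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, C, T] -> post-norm over channels -> project C -> out_dims
        x = self.norm(x.transpose(1, 2)).transpose(1, 2)
        return self.output_projection(x)


if __name__ == '__main__':
    y = PostNormProjection(n_chans=512, out_dims=128)(torch.randn(2, 512, 100))
    print(y.shape)  # torch.Size([2, 128, 100])
```

Note the design change in `LYNXNet.py`: the norm inside each block is no longer conditional (`in_norm or dim > 512`) but always `nn.LayerNorm(dim)`, and a single post-norm is added after the residual stack before the zero-initialized output projection.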