diff --git a/modules/nsf_hifigan/models.py b/modules/nsf_hifigan/models.py index a77eb0a38..cc21039f7 100644 --- a/modules/nsf_hifigan/models.py +++ b/modules/nsf_hifigan/models.py @@ -130,41 +130,20 @@ def _f02uv(self, f0): uv = uv * (f0 > self.voiced_threshold) return uv - def _f02sine(self, f0_values, upp): - """ f0_values: (batchsize, length, dim) + def _f02sine(self, f0, upp): + """ f0: (batchsize, length, dim) where dim indicates fundamental tone and overtones """ - rad_values = (f0_values / self.sampling_rate).fmod(1.) # %1意味着n_har的乘积无法后处理优化 - rand_ini = torch.rand(1, self.dim, device=f0_values.device) - rand_ini[:, 0] = 0 - rad_values[:, 0, :] += rand_ini - is_half = rad_values.dtype is not torch.float32 - tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 - if is_half: - tmp_over_one = tmp_over_one.half() - else: - tmp_over_one = tmp_over_one.float() - tmp_over_one *= upp - tmp_over_one = F.interpolate( - tmp_over_one.transpose(2, 1), scale_factor=upp, - mode='linear', align_corners=True - ).transpose(2, 1) - rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) - tmp_over_one = tmp_over_one.fmod(1.) - diff = F.conv2d( - tmp_over_one.unsqueeze(1), torch.FloatTensor([[[[-1.], [1.]]]]).to(tmp_over_one.device), - stride=(1, 1), padding=0, dilation=(1, 1) - ).squeeze(1) # Equivalent to torch.diff, but able to export ONNX - cumsum_shift = (diff < 0).double() - cumsum_shift = torch.cat(( - torch.zeros((1, 1, self.dim), dtype=torch.double).to(f0_values.device), - cumsum_shift - ), dim=1) - sines = torch.sin(torch.cumsum(rad_values.double() + cumsum_shift, dim=1) * 2 * np.pi) - if is_half: - sines = sines.half() - else: - sines = sines.float() + rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, device=f0.device) + rad2 = torch.fmod(rad[..., -1:].float() + 0.5, 1.0) - 0.5 + rad_acc = rad2.cumsum(dim=1).fmod(1.0).to(f0) + rad += F.pad(rad_acc, (0, 0, 1, -1)) + rad = rad.reshape(f0.shape[0], -1, 1) + rad = torch.multiply(rad, torch.arange(1, self.dim + 1, device=f0.device).reshape(1, 1, -1)) + rand_ini = torch.rand(1, 1, self.dim, device=f0.device) + rand_ini[..., 0] = 0 + rad += rand_ini + sines = torch.sin(2 * np.pi * rad) return sines @torch.no_grad() @@ -176,8 +155,7 @@ def forward(self, f0, upp): output uv: tensor(batchsize=1, length, 1) """ f0 = f0.unsqueeze(-1) - fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1))) - sine_waves = self._f02sine(fn, upp) * self.sine_amp + sine_waves = self._f02sine(f0, upp) * self.sine_amp uv = (f0 > self.voiced_threshold).float() uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3