From d2643d9cf2ca243b17b7bf968df2c0874fae1830 Mon Sep 17 00:00:00 2001 From: "llc1995@sina.com" Date: Mon, 24 Jul 2023 18:48:19 +0800 Subject: [PATCH 1/9] Add RMVPE pitch extractor --- modules/pe/__init__.py | 7 +- modules/pe/rmvpe.py | 18 ++++ modules/rmvpe/__init__.py | 5 + modules/rmvpe/constants.py | 9 ++ modules/rmvpe/deepunet.py | 189 +++++++++++++++++++++++++++++++++++++ modules/rmvpe/inference.py | 52 ++++++++++ modules/rmvpe/model.py | 60 ++++++++++++ modules/rmvpe/seq.py | 20 ++++ modules/rmvpe/spec.py | 66 +++++++++++++ modules/rmvpe/utils.py | 142 ++++++++++++++++++++++++++++ 10 files changed, 567 insertions(+), 1 deletion(-) create mode 100644 modules/pe/rmvpe.py create mode 100644 modules/rmvpe/__init__.py create mode 100644 modules/rmvpe/constants.py create mode 100644 modules/rmvpe/deepunet.py create mode 100644 modules/rmvpe/inference.py create mode 100644 modules/rmvpe/model.py create mode 100644 modules/rmvpe/seq.py create mode 100644 modules/rmvpe/spec.py create mode 100644 modules/rmvpe/utils.py diff --git a/modules/pe/__init__.py b/modules/pe/__init__.py index 2eac4c5cd..ab0df0058 100644 --- a/modules/pe/__init__.py +++ b/modules/pe/__init__.py @@ -1,9 +1,14 @@ from utils import hparams from .pm import ParselmouthPE - +from .rmvpe import RMVPE def initialize_pe(): pe = hparams.get('pe', 'parselmouth') + pe_ckpt = hparams.get('pe_ckpt', '') if pe == 'parselmouth': return ParselmouthPE() + elif pe == 'rmvpe': + return RMVPE(pe_ckpt) + else: + raise ValueError(f" [x] Unknown f0 extractor: {pe}") \ No newline at end of file diff --git a/modules/pe/rmvpe.py b/modules/pe/rmvpe.py new file mode 100644 index 000000000..484154eb4 --- /dev/null +++ b/modules/pe/rmvpe.py @@ -0,0 +1,18 @@ +import numpy as np +from basics.base_pe import BasePE +from modules.rmvpe.inference import RMVPE as rmvpe +from utils.pitch_utils import interp_f0 + +class RMVPE(BasePE): + def __init__(self, model_path): + self.rmvpe = rmvpe(model_path, hop_length=160) + + def get_pitch(self, waveform, length, hparams, interp_uv=False, speed=1): + hop_size = int(np.round(hparams['hop_size'] * speed)) + time_step = hop_size / hparams['audio_sample_rate'] + f0 = self.rmvpe.infer_from_audio(waveform, sample_rate=hparams['audio_sample_rate']) + f0 = np.array([f0[int(min(int(np.round(n * time_step / 0.01)), len(f0) - 1))] for n in range(length)]) + uv = f0 == 0 + if interp_uv: + f0, uv = interp_f0(f0, uv) + return f0, uv diff --git a/modules/rmvpe/__init__.py b/modules/rmvpe/__init__.py new file mode 100644 index 000000000..1e5425d2b --- /dev/null +++ b/modules/rmvpe/__init__.py @@ -0,0 +1,5 @@ +from .constants import * +from .model import E2E, E2E0 +from .utils import to_local_average_f0, to_viterbi_f0 +from .inference import RMVPE +from .spec import MelSpectrogram \ No newline at end of file diff --git a/modules/rmvpe/constants.py b/modules/rmvpe/constants.py new file mode 100644 index 000000000..525a2a0da --- /dev/null +++ b/modules/rmvpe/constants.py @@ -0,0 +1,9 @@ +SAMPLE_RATE = 16000 + +N_CLASS = 360 + +N_MELS = 128 +MEL_FMIN = 30 +MEL_FMAX = 8000 +WINDOW_LENGTH = 1024 +CONST = 1997.3794084376191 diff --git a/modules/rmvpe/deepunet.py b/modules/rmvpe/deepunet.py new file mode 100644 index 000000000..d0d5d777b --- /dev/null +++ b/modules/rmvpe/deepunet.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn +from .constants import N_MELS + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels, out_channels, momentum=0.01): + super(ConvBlockRes, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + + nn.Conv2d(in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + if in_channels != out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) + self.is_shortcut = True + else: + self.is_shortcut = False + + def forward(self, x): + if self.is_shortcut: + return self.conv(x) + self.shortcut(x) + else: + return self.conv(x) + x + + +class ResEncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): + super(ResEncoderBlock, self).__init__() + self.n_blocks = n_blocks + self.conv = nn.ModuleList() + self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) + self.kernel_size = kernel_size + if self.kernel_size is not None: + self.pool = nn.AvgPool2d(kernel_size=kernel_size) + + def forward(self, x): + for i in range(self.n_blocks): + x = self.conv[i](x) + if self.kernel_size is not None: + return x, self.pool(x) + else: + return x + + +class ResDecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): + super(ResDecoderBlock, self).__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.n_blocks = n_blocks + self.conv1 = nn.Sequential( + nn.ConvTranspose2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=stride, + padding=(1, 1), + output_padding=out_padding, + bias=False), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.conv2 = nn.ModuleList() + self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) + for i in range(n_blocks-1): + self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) + + def forward(self, x, concat_tensor): + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for i in range(self.n_blocks): + x = self.conv2[i](x) + return x + + +class Encoder(nn.Module): + def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): + super(Encoder, self).__init__() + self.n_encoders = n_encoders + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + self.latent_channels = [] + for i in range(self.n_encoders): + self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)) + self.latent_channels.append([out_channels, in_size]) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x): + concat_tensors = [] + x = self.bn(x) + for i in range(self.n_encoders): + _, x = self.layers[i](x) + concat_tensors.append(_) + return x, concat_tensors + + +class Intermediate(nn.Module): + def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): + super(Intermediate, self).__init__() + self.n_inters = n_inters + self.layers = nn.ModuleList() + self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)) + for i in range(self.n_inters-1): + self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) + + def forward(self, x): + for i in range(self.n_inters): + x = self.layers[i](x) + return x + + +class Decoder(nn.Module): + def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): + super(Decoder, self).__init__() + self.layers = nn.ModuleList() + self.n_decoders = n_decoders + for i in range(self.n_decoders): + out_channels = in_channels // 2 + self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) + in_channels = out_channels + + def forward(self, x, concat_tensors): + for i in range(self.n_decoders): + x = self.layers[i](x, concat_tensors[-1-i]) + return x + + +class TimbreFilter(nn.Module): + def __init__(self, latent_rep_channels): + super(TimbreFilter, self).__init__() + self.layers = nn.ModuleList() + for latent_rep in latent_rep_channels: + self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0])) + + def forward(self, x_tensors): + out_tensors = [] + for i, layer in enumerate(self.layers): + out_tensors.append(layer(x_tensors[i])) + return out_tensors + + +class DeepUnet(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super(DeepUnet, self).__init__() + self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.tf = TimbreFilter(self.encoder.latent_channels) + self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + concat_tensors = self.tf(concat_tensors) + x = self.decoder(x, concat_tensors) + return x + + +class DeepUnet0(nn.Module): + def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): + super(DeepUnet0, self).__init__() + self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) + self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) + self.tf = TimbreFilter(self.encoder.latent_channels) + self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) + + def forward(self, x): + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + x = self.decoder(x, concat_tensors) + return x diff --git a/modules/rmvpe/inference.py b/modules/rmvpe/inference.py new file mode 100644 index 000000000..e48f29c9e --- /dev/null +++ b/modules/rmvpe/inference.py @@ -0,0 +1,52 @@ +import numpy as np +import torch +import torch.nn.functional as F +from torchaudio.transforms import Resample +from .constants import * +from .model import E2E0, E2E +from .spec import MelSpectrogram +from .utils import to_local_average_f0, to_viterbi_f0 + +class RMVPE: + def __init__(self, model_path, hop_length=160): + self.resample_kernel = {} + model = E2E0(4, 1, (2, 2)) + ckpt = torch.load(model_path) + model.load_state_dict(ckpt['model'], strict=False) + model.eval() + self.model = model + self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) + self.resample_kernel = {} + + def mel2hidden(self, mel): + with torch.no_grad(): + n_frames = mel.shape[-1] + mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') + hidden = self.model(mel) + return hidden[:, :n_frames] + + def decode(self, hidden, thred=0.03, use_viterbi=False): + if use_viterbi: + f0 = to_viterbi_f0(hidden, thred=thred) + else: + f0 = to_local_average_f0(hidden, thred=thred) + return f0 + + def infer_from_audio(self, audio, sample_rate=16000, device=None, thred=0.03, use_viterbi=False): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + audio = torch.from_numpy(audio).float().unsqueeze(0).to(device) + if sample_rate == 16000: + audio_res = audio + else: + key_str = str(sample_rate) + if key_str not in self.resample_kernel: + self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) + self.resample_kernel[key_str] = self.resample_kernel[key_str].to(device) + audio_res = self.resample_kernel[key_str](audio) + mel_extractor = self.mel_extractor.to(device) + self.model = self.model.to(device) + mel = mel_extractor(audio_res, center=True) + hidden = self.mel2hidden(mel) + f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi) + return f0 \ No newline at end of file diff --git a/modules/rmvpe/model.py b/modules/rmvpe/model.py new file mode 100644 index 000000000..214788a38 --- /dev/null +++ b/modules/rmvpe/model.py @@ -0,0 +1,60 @@ +import torch +from torch import nn +from .deepunet import DeepUnet, DeepUnet0 +from .constants import * +from .spec import MelSpectrogram +from .seq import BiGRU + + +class E2E(nn.Module): + def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, + en_out_channels=16): + super(E2E, self).__init__() + self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), + nn.Linear(512, N_CLASS), + nn.Dropout(0.25), + nn.Sigmoid() + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), + nn.Dropout(0.25), + nn.Sigmoid() + ) + + def forward(self, mel): + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + x = self.fc(x) + return x + + +class E2E0(nn.Module): + def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, + en_out_channels=16): + super(E2E0, self).__init__() + self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), + nn.Linear(512, N_CLASS), + nn.Dropout(0.25), + nn.Sigmoid() + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), + nn.Dropout(0.25), + nn.Sigmoid() + ) + + def forward(self, mel): + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + x = self.fc(x) + return x diff --git a/modules/rmvpe/seq.py b/modules/rmvpe/seq.py new file mode 100644 index 000000000..0d48e49d7 --- /dev/null +++ b/modules/rmvpe/seq.py @@ -0,0 +1,20 @@ +import torch.nn as nn + + +class BiGRU(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiGRU, self).__init__() + self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.gru(x)[0] + + +class BiLSTM(nn.Module): + def __init__(self, input_features, hidden_features, num_layers): + super(BiLSTM, self).__init__() + self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) + + def forward(self, x): + return self.lstm(x)[0] + diff --git a/modules/rmvpe/spec.py b/modules/rmvpe/spec.py new file mode 100644 index 000000000..f91fbc1dc --- /dev/null +++ b/modules/rmvpe/spec.py @@ -0,0 +1,66 @@ +import torch +import numpy as np +import torch.nn.functional as F +from librosa.filters import mel + +class MelSpectrogram(torch.nn.Module): + def __init__( + self, + n_mel_channels, + sampling_rate, + win_length, + hop_length, + n_fft=None, + mel_fmin=0, + mel_fmax=None, + clamp = 1e-5 + ): + super().__init__() + n_fft = win_length if n_fft is None else n_fft + self.hann_window = {} + mel_basis = mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.n_fft = win_length if n_fft is None else n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.clamp = clamp + + def forward(self, audio, keyshift=0, speed=1, center=True): + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(self.n_fft * factor)) + win_length_new = int(np.round(self.win_length * factor)) + hop_length_new = int(np.round(self.hop_length * speed)) + + keyshift_key = str(keyshift)+'_'+str(audio.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) + + fft = torch.stft( + audio, + n_fft=n_fft_new, + hop_length=hop_length_new, + win_length=win_length_new, + window=self.hann_window[keyshift_key], + center=center, + return_complex=True) + magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + + if keyshift != 0: + size = self.n_fft // 2 + 1 + resize = magnitude.size(1) + if resize < size: + magnitude = F.pad(magnitude, (0, 0, 0, size-resize)) + magnitude = magnitude[:, :size, :] * self.win_length / win_length_new + + mel_output = torch.matmul(self.mel_basis, magnitude) + log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) + return log_mel_spec \ No newline at end of file diff --git a/modules/rmvpe/utils.py b/modules/rmvpe/utils.py new file mode 100644 index 000000000..3dd817055 --- /dev/null +++ b/modules/rmvpe/utils.py @@ -0,0 +1,142 @@ +import sys +import numpy as np +import librosa +import torch +from functools import reduce +from .constants import * +from torch.nn.modules.module import _addindent + + +def cycle(iterable): + while True: + for item in iterable: + yield item + + +def summary(model, file=sys.stdout): + def repr(model): + # We treat the extra repr like the sub-module, one item per line + extra_lines = [] + extra_repr = model.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + total_params = 0 + for key, module in model._modules.items(): + mod_str, num_params = repr(module) + mod_str = _addindent(mod_str, 2) + child_lines.append('(' + key + '): ' + mod_str) + total_params += num_params + lines = extra_lines + child_lines + + for name, p in model._parameters.items(): + if hasattr(p, 'shape'): + total_params += reduce(lambda x, y: x * y, p.shape) + + main_str = model._get_name() + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + if file is sys.stdout: + main_str += ', \033[92m{:,}\033[0m params'.format(total_params) + else: + main_str += ', {:,} params'.format(total_params) + return main_str, total_params + + string, count = repr(model) + if file is not None: + if isinstance(file, str): + file = open(file, 'w') + print(string, file=file) + file.flush() + + return count + + +def to_local_average_cents(salience, center=None, thred=0.03): + """ + find the weighted average cents near the argmax bin + """ + + if not hasattr(to_local_average_cents, 'cents_mapping'): + # the bin number-to-cents mapping + to_local_average_cents.cents_mapping = ( + 20 * np.arange(N_CLASS) + CONST) + + if salience.ndim == 1: + if center is None: + center = int(np.argmax(salience)) + start = max(0, center - 4) + end = min(len(salience), center + 5) + salience = salience[start:end] + product_sum = np.sum( + salience * to_local_average_cents.cents_mapping[start:end]) + weight_sum = np.sum(salience) + return product_sum / weight_sum if np.max(salience) > thred else 0 + if salience.ndim == 2: + return np.array([to_local_average_cents(salience[i, :], None, thred) for i in + range(salience.shape[0])]) + + raise Exception("label should be either 1d or 2d ndarray") + +def to_viterbi_cents(salience, thred=0.03): + # Create viterbi transition matrix + if not hasattr(to_viterbi_cents, 'transition'): + xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) + transition = np.maximum(30 - abs(xx - yy), 0) + transition = transition / transition.sum(axis=1, keepdims=True) + to_viterbi_cents.transition = transition + + # Convert to probability + prob = salience.T + prob = prob / prob.sum(axis=0) + + # Perform viterbi decoding + path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64) + + return np.array([to_local_average_cents(salience[i, :], path[i], thred) for i in + range(len(path))]) + +def to_local_average_f0(hidden, center=None, thred=0.03): + idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N] + idx_cents = idx * 20 + CONST # [B=1, N] + if center is None: + center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1] + start = torch.clip(center - 4, min=0) # [B, T, 1] + end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1] + idx_mask = (idx >= start) & (idx < end) # [B, T, N] + weights = hidden * idx_mask # [B, T, N] + product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T] + weight_sum = torch.sum(weights, dim=2) # [B, T] + cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] + f0 = 10 * 2 ** (cents / 1200) + uv = hidden.max(dim=2)[0] < thred # [B, T] + f0 = f0 * ~uv + return f0.squeeze(0).cpu().numpy() + +def to_viterbi_f0(hidden, thred=0.03): + # Create viterbi transition matrix + if not hasattr(to_viterbi_cents, 'transition'): + xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) + transition = np.maximum(30 - abs(xx - yy), 0) + transition = transition / transition.sum(axis=1, keepdims=True) + to_viterbi_cents.transition = transition + + # Convert to probability + prob = hidden.squeeze(0).cpu().numpy() + prob = prob.T + prob = prob / prob.sum(axis=0) + + # Perform viterbi decoding + path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64) + center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device) + + return to_local_average_f0(hidden, center=center, thred=thred) + + \ No newline at end of file From 9c68381aed489d8e0b873fe2b1977b44a5cfc5b3 Mon Sep 17 00:00:00 2001 From: "llc1995@sina.com" Date: Mon, 24 Jul 2023 20:38:09 +0800 Subject: [PATCH 2/9] Improve interpolation for rmvpe --- modules/pe/rmvpe.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/modules/pe/rmvpe.py b/modules/pe/rmvpe.py index 484154eb4..37d16344c 100644 --- a/modules/pe/rmvpe.py +++ b/modules/pe/rmvpe.py @@ -2,17 +2,21 @@ from basics.base_pe import BasePE from modules.rmvpe.inference import RMVPE as rmvpe from utils.pitch_utils import interp_f0 +from utils.infer_utils import resample_align_curve class RMVPE(BasePE): def __init__(self, model_path): self.rmvpe = rmvpe(model_path, hop_length=160) def get_pitch(self, waveform, length, hparams, interp_uv=False, speed=1): - hop_size = int(np.round(hparams['hop_size'] * speed)) - time_step = hop_size / hparams['audio_sample_rate'] f0 = self.rmvpe.infer_from_audio(waveform, sample_rate=hparams['audio_sample_rate']) - f0 = np.array([f0[int(min(int(np.round(n * time_step / 0.01)), len(f0) - 1))] for n in range(length)]) uv = f0 == 0 - if interp_uv: - f0, uv = interp_f0(f0, uv) - return f0, uv + f0, uv = interp_f0(f0, uv) + + hop_size = int(np.round(hparams['hop_size'] * speed)) + time_step = hop_size / hparams['audio_sample_rate'] + f0_res = resample_align_curve(f0, 0.01, time_step, length) + uv_res = resample_align_curve(uv.astype(float), 0.01, time_step, length) > 0.5 + if not interp_uv: + f0_res[uv_res] = 0 + return f0_res, uv_res From 54b6f6feb6d70b012c36af6b675db5d289068bbe Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 25 Jul 2023 15:43:37 +0800 Subject: [PATCH 3/9] Simplify code --- modules/pe/__init__.py | 5 +- modules/pe/rmvpe.py | 22 ----- modules/{ => pe}/rmvpe/__init__.py | 4 +- modules/{ => pe}/rmvpe/constants.py | 0 modules/{ => pe}/rmvpe/deepunet.py | 0 modules/{ => pe}/rmvpe/inference.py | 26 ++++- modules/{ => pe}/rmvpe/model.py | 32 +------ modules/{ => pe}/rmvpe/seq.py | 10 -- modules/{ => pe}/rmvpe/spec.py | 46 ++++----- modules/pe/rmvpe/utils.py | 45 +++++++++ modules/rmvpe/utils.py | 142 ---------------------------- 11 files changed, 97 insertions(+), 235 deletions(-) delete mode 100644 modules/pe/rmvpe.py rename modules/{ => pe}/rmvpe/__init__.py (63%) rename modules/{ => pe}/rmvpe/constants.py (100%) rename modules/{ => pe}/rmvpe/deepunet.py (100%) rename modules/{ => pe}/rmvpe/inference.py (70%) rename modules/{ => pe}/rmvpe/model.py (51%) rename modules/{ => pe}/rmvpe/seq.py (51%) rename modules/{ => pe}/rmvpe/spec.py (73%) create mode 100644 modules/pe/rmvpe/utils.py delete mode 100644 modules/rmvpe/utils.py diff --git a/modules/pe/__init__.py b/modules/pe/__init__.py index ab0df0058..2c0dcf773 100644 --- a/modules/pe/__init__.py +++ b/modules/pe/__init__.py @@ -3,12 +3,13 @@ from .pm import ParselmouthPE from .rmvpe import RMVPE + def initialize_pe(): pe = hparams.get('pe', 'parselmouth') - pe_ckpt = hparams.get('pe_ckpt', '') + pe_ckpt = hparams['pe_ckpt'] if pe == 'parselmouth': return ParselmouthPE() elif pe == 'rmvpe': return RMVPE(pe_ckpt) else: - raise ValueError(f" [x] Unknown f0 extractor: {pe}") \ No newline at end of file + raise ValueError(f" [x] Unknown f0 extractor: {pe}") diff --git a/modules/pe/rmvpe.py b/modules/pe/rmvpe.py deleted file mode 100644 index 37d16344c..000000000 --- a/modules/pe/rmvpe.py +++ /dev/null @@ -1,22 +0,0 @@ -import numpy as np -from basics.base_pe import BasePE -from modules.rmvpe.inference import RMVPE as rmvpe -from utils.pitch_utils import interp_f0 -from utils.infer_utils import resample_align_curve - -class RMVPE(BasePE): - def __init__(self, model_path): - self.rmvpe = rmvpe(model_path, hop_length=160) - - def get_pitch(self, waveform, length, hparams, interp_uv=False, speed=1): - f0 = self.rmvpe.infer_from_audio(waveform, sample_rate=hparams['audio_sample_rate']) - uv = f0 == 0 - f0, uv = interp_f0(f0, uv) - - hop_size = int(np.round(hparams['hop_size'] * speed)) - time_step = hop_size / hparams['audio_sample_rate'] - f0_res = resample_align_curve(f0, 0.01, time_step, length) - uv_res = resample_align_curve(uv.astype(float), 0.01, time_step, length) > 0.5 - if not interp_uv: - f0_res[uv_res] = 0 - return f0_res, uv_res diff --git a/modules/rmvpe/__init__.py b/modules/pe/rmvpe/__init__.py similarity index 63% rename from modules/rmvpe/__init__.py rename to modules/pe/rmvpe/__init__.py index 1e5425d2b..cf71a053e 100644 --- a/modules/rmvpe/__init__.py +++ b/modules/pe/rmvpe/__init__.py @@ -1,5 +1,5 @@ from .constants import * -from .model import E2E, E2E0 +from .model import E2E0 from .utils import to_local_average_f0, to_viterbi_f0 from .inference import RMVPE -from .spec import MelSpectrogram \ No newline at end of file +from .spec import MelSpectrogram diff --git a/modules/rmvpe/constants.py b/modules/pe/rmvpe/constants.py similarity index 100% rename from modules/rmvpe/constants.py rename to modules/pe/rmvpe/constants.py diff --git a/modules/rmvpe/deepunet.py b/modules/pe/rmvpe/deepunet.py similarity index 100% rename from modules/rmvpe/deepunet.py rename to modules/pe/rmvpe/deepunet.py diff --git a/modules/rmvpe/inference.py b/modules/pe/rmvpe/inference.py similarity index 70% rename from modules/rmvpe/inference.py rename to modules/pe/rmvpe/inference.py index e48f29c9e..7fcdd6105 100644 --- a/modules/rmvpe/inference.py +++ b/modules/pe/rmvpe/inference.py @@ -2,11 +2,15 @@ import torch import torch.nn.functional as F from torchaudio.transforms import Resample + +from utils.infer_utils import resample_align_curve +from utils.pitch_utils import interp_f0 from .constants import * -from .model import E2E0, E2E -from .spec import MelSpectrogram +from .model import E2E0 +from .spec import MelSpectrogram from .utils import to_local_average_f0, to_viterbi_f0 + class RMVPE: def __init__(self, model_path, hop_length=160): self.resample_kernel = {} @@ -16,7 +20,6 @@ def __init__(self, model_path, hop_length=160): model.eval() self.model = model self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) - self.resample_kernel = {} def mel2hidden(self, mel): with torch.no_grad(): @@ -29,7 +32,7 @@ def decode(self, hidden, thred=0.03, use_viterbi=False): if use_viterbi: f0 = to_viterbi_f0(hidden, thred=thred) else: - f0 = to_local_average_f0(hidden, thred=thred) + f0 = to_local_average_f0(hidden, thred=thred) return f0 def infer_from_audio(self, audio, sample_rate=16000, device=None, thred=0.03, use_viterbi=False): @@ -49,4 +52,17 @@ def infer_from_audio(self, audio, sample_rate=16000, device=None, thred=0.03, us mel = mel_extractor(audio_res, center=True) hidden = self.mel2hidden(mel) f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi) - return f0 \ No newline at end of file + return f0 + + def get_pitch(self, waveform, length, hparams, interp_uv=False, speed=1): + f0 = self.infer_from_audio(waveform, sample_rate=hparams['audio_sample_rate']) + uv = f0 == 0 + f0, uv = interp_f0(f0, uv) + + hop_size = int(np.round(hparams['hop_size'] * speed)) + time_step = hop_size / hparams['audio_sample_rate'] + f0_res = resample_align_curve(f0, 0.01, time_step, length) + uv_res = resample_align_curve(uv.astype(float), 0.01, time_step, length) > 0.5 + if not interp_uv: + f0_res[uv_res] = 0 + return f0_res, uv_res diff --git a/modules/rmvpe/model.py b/modules/pe/rmvpe/model.py similarity index 51% rename from modules/rmvpe/model.py rename to modules/pe/rmvpe/model.py index 214788a38..8bdeb43d6 100644 --- a/modules/rmvpe/model.py +++ b/modules/pe/rmvpe/model.py @@ -1,38 +1,10 @@ -import torch from torch import nn -from .deepunet import DeepUnet, DeepUnet0 + from .constants import * -from .spec import MelSpectrogram +from .deepunet import DeepUnet, DeepUnet0 from .seq import BiGRU -class E2E(nn.Module): - def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, - en_out_channels=16): - super(E2E, self).__init__() - self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) - self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) - if n_gru: - self.fc = nn.Sequential( - BiGRU(3 * N_MELS, 256, n_gru), - nn.Linear(512, N_CLASS), - nn.Dropout(0.25), - nn.Sigmoid() - ) - else: - self.fc = nn.Sequential( - nn.Linear(3 * N_MELS, N_CLASS), - nn.Dropout(0.25), - nn.Sigmoid() - ) - - def forward(self, mel): - mel = mel.transpose(-1, -2).unsqueeze(1) - x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) - x = self.fc(x) - return x - - class E2E0(nn.Module): def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): diff --git a/modules/rmvpe/seq.py b/modules/pe/rmvpe/seq.py similarity index 51% rename from modules/rmvpe/seq.py rename to modules/pe/rmvpe/seq.py index 0d48e49d7..9c4c8f880 100644 --- a/modules/rmvpe/seq.py +++ b/modules/pe/rmvpe/seq.py @@ -8,13 +8,3 @@ def __init__(self, input_features, hidden_features, num_layers): def forward(self, x): return self.gru(x)[0] - - -class BiLSTM(nn.Module): - def __init__(self, input_features, hidden_features, num_layers): - super(BiLSTM, self).__init__() - self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) - - def forward(self, x): - return self.lstm(x)[0] - diff --git a/modules/rmvpe/spec.py b/modules/pe/rmvpe/spec.py similarity index 73% rename from modules/rmvpe/spec.py rename to modules/pe/rmvpe/spec.py index f91fbc1dc..4a38054f8 100644 --- a/modules/rmvpe/spec.py +++ b/modules/pe/rmvpe/spec.py @@ -3,26 +3,27 @@ import torch.nn.functional as F from librosa.filters import mel + class MelSpectrogram(torch.nn.Module): def __init__( - self, - n_mel_channels, - sampling_rate, - win_length, - hop_length, - n_fft=None, - mel_fmin=0, - mel_fmax=None, - clamp = 1e-5 + self, + n_mel_channels, + sampling_rate, + win_length, + hop_length, + n_fft=None, + mel_fmin=0, + mel_fmax=None, + clamp=1e-5 ): super().__init__() n_fft = win_length if n_fft is None else n_fft self.hann_window = {} mel_basis = mel( sr=sampling_rate, - n_fft=n_fft, - n_mels=n_mel_channels, - fmin=mel_fmin, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, fmax=mel_fmax, htk=True) mel_basis = torch.from_numpy(mel_basis).float() @@ -35,15 +36,15 @@ def __init__( self.clamp = clamp def forward(self, audio, keyshift=0, speed=1, center=True): - factor = 2 ** (keyshift / 12) + factor = 2 ** (keyshift / 12) n_fft_new = int(np.round(self.n_fft * factor)) win_length_new = int(np.round(self.win_length * factor)) hop_length_new = int(np.round(self.hop_length * speed)) - - keyshift_key = str(keyshift)+'_'+str(audio.device) + + keyshift_key = str(keyshift) + '_' + str(audio.device) if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) - + fft = torch.stft( audio, n_fft=n_fft_new, @@ -51,16 +52,17 @@ def forward(self, audio, keyshift=0, speed=1, center=True): win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, - return_complex=True) - magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) - + return_complex=True + ) + magnitude = fft.abs() + if keyshift != 0: size = self.n_fft // 2 + 1 resize = magnitude.size(1) if resize < size: - magnitude = F.pad(magnitude, (0, 0, 0, size-resize)) + magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) magnitude = magnitude[:, :size, :] * self.win_length / win_length_new - + mel_output = torch.matmul(self.mel_basis, magnitude) log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) - return log_mel_spec \ No newline at end of file + return log_mel_spec diff --git a/modules/pe/rmvpe/utils.py b/modules/pe/rmvpe/utils.py new file mode 100644 index 000000000..34cf5ed15 --- /dev/null +++ b/modules/pe/rmvpe/utils.py @@ -0,0 +1,45 @@ +import sys +import numpy as np +import librosa +import torch +from functools import reduce +from .constants import * +from torch.nn.modules.module import _addindent + + +def to_local_average_f0(hidden, center=None, thred=0.03): + idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N] + idx_cents = idx * 20 + CONST # [B=1, N] + if center is None: + center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1] + start = torch.clip(center - 4, min=0) # [B, T, 1] + end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1] + idx_mask = (idx >= start) & (idx < end) # [B, T, N] + weights = hidden * idx_mask # [B, T, N] + product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T] + weight_sum = torch.sum(weights, dim=2) # [B, T] + cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] + f0 = 10 * 2 ** (cents / 1200) + uv = hidden.max(dim=2)[0] < thred # [B, T] + f0 = f0 * ~uv + return f0.squeeze(0).cpu().numpy() + + +def to_viterbi_f0(hidden, thred=0.03): + # Create viterbi transition matrix + if not hasattr(to_viterbi_f0, 'transition'): + xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) + transition = np.maximum(30 - abs(xx - yy), 0) + transition = transition / transition.sum(axis=1, keepdims=True) + to_viterbi_f0.transition = transition + + # Convert to probability + prob = hidden.squeeze(0).cpu().numpy() + prob = prob.T + prob = prob / prob.sum(axis=0) + + # Perform viterbi decoding + path = librosa.sequence.viterbi(prob, to_viterbi_f0.transition).astype(np.int64) + center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device) + + return to_local_average_f0(hidden, center=center, thred=thred) diff --git a/modules/rmvpe/utils.py b/modules/rmvpe/utils.py deleted file mode 100644 index 3dd817055..000000000 --- a/modules/rmvpe/utils.py +++ /dev/null @@ -1,142 +0,0 @@ -import sys -import numpy as np -import librosa -import torch -from functools import reduce -from .constants import * -from torch.nn.modules.module import _addindent - - -def cycle(iterable): - while True: - for item in iterable: - yield item - - -def summary(model, file=sys.stdout): - def repr(model): - # We treat the extra repr like the sub-module, one item per line - extra_lines = [] - extra_repr = model.extra_repr() - # empty string will be split into list [''] - if extra_repr: - extra_lines = extra_repr.split('\n') - child_lines = [] - total_params = 0 - for key, module in model._modules.items(): - mod_str, num_params = repr(module) - mod_str = _addindent(mod_str, 2) - child_lines.append('(' + key + '): ' + mod_str) - total_params += num_params - lines = extra_lines + child_lines - - for name, p in model._parameters.items(): - if hasattr(p, 'shape'): - total_params += reduce(lambda x, y: x * y, p.shape) - - main_str = model._get_name() + '(' - if lines: - # simple one-liner info, which most builtin Modules will use - if len(extra_lines) == 1 and not child_lines: - main_str += extra_lines[0] - else: - main_str += '\n ' + '\n '.join(lines) + '\n' - - main_str += ')' - if file is sys.stdout: - main_str += ', \033[92m{:,}\033[0m params'.format(total_params) - else: - main_str += ', {:,} params'.format(total_params) - return main_str, total_params - - string, count = repr(model) - if file is not None: - if isinstance(file, str): - file = open(file, 'w') - print(string, file=file) - file.flush() - - return count - - -def to_local_average_cents(salience, center=None, thred=0.03): - """ - find the weighted average cents near the argmax bin - """ - - if not hasattr(to_local_average_cents, 'cents_mapping'): - # the bin number-to-cents mapping - to_local_average_cents.cents_mapping = ( - 20 * np.arange(N_CLASS) + CONST) - - if salience.ndim == 1: - if center is None: - center = int(np.argmax(salience)) - start = max(0, center - 4) - end = min(len(salience), center + 5) - salience = salience[start:end] - product_sum = np.sum( - salience * to_local_average_cents.cents_mapping[start:end]) - weight_sum = np.sum(salience) - return product_sum / weight_sum if np.max(salience) > thred else 0 - if salience.ndim == 2: - return np.array([to_local_average_cents(salience[i, :], None, thred) for i in - range(salience.shape[0])]) - - raise Exception("label should be either 1d or 2d ndarray") - -def to_viterbi_cents(salience, thred=0.03): - # Create viterbi transition matrix - if not hasattr(to_viterbi_cents, 'transition'): - xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) - transition = np.maximum(30 - abs(xx - yy), 0) - transition = transition / transition.sum(axis=1, keepdims=True) - to_viterbi_cents.transition = transition - - # Convert to probability - prob = salience.T - prob = prob / prob.sum(axis=0) - - # Perform viterbi decoding - path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64) - - return np.array([to_local_average_cents(salience[i, :], path[i], thred) for i in - range(len(path))]) - -def to_local_average_f0(hidden, center=None, thred=0.03): - idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N] - idx_cents = idx * 20 + CONST # [B=1, N] - if center is None: - center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1] - start = torch.clip(center - 4, min=0) # [B, T, 1] - end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1] - idx_mask = (idx >= start) & (idx < end) # [B, T, N] - weights = hidden * idx_mask # [B, T, N] - product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T] - weight_sum = torch.sum(weights, dim=2) # [B, T] - cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] - f0 = 10 * 2 ** (cents / 1200) - uv = hidden.max(dim=2)[0] < thred # [B, T] - f0 = f0 * ~uv - return f0.squeeze(0).cpu().numpy() - -def to_viterbi_f0(hidden, thred=0.03): - # Create viterbi transition matrix - if not hasattr(to_viterbi_cents, 'transition'): - xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) - transition = np.maximum(30 - abs(xx - yy), 0) - transition = transition / transition.sum(axis=1, keepdims=True) - to_viterbi_cents.transition = transition - - # Convert to probability - prob = hidden.squeeze(0).cpu().numpy() - prob = prob.T - prob = prob / prob.sum(axis=0) - - # Perform viterbi decoding - path = librosa.sequence.viterbi(prob, to_viterbi_cents.transition).astype(np.int64) - center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device) - - return to_local_average_f0(hidden, center=center, thred=thred) - - \ No newline at end of file From 8590b2ca0b3c8cfaf56a46f69aed545145af335b Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 25 Jul 2023 15:43:59 +0800 Subject: [PATCH 4/9] Fix mismatching dtype --- utils/binarizer_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/utils/binarizer_utils.py b/utils/binarizer_utils.py index 786bbf330..db8ba7768 100644 --- a/utils/binarizer_utils.py +++ b/utils/binarizer_utils.py @@ -56,7 +56,7 @@ def get_pitch_parselmouth(wav_data, length, hparams, speed=1, interp_uv=False): f0 = parselmouth.Sound(wav_data, sampling_frequency=hparams['audio_sample_rate']).to_pitch_ac( time_step=time_step, voicing_threshold=0.6, pitch_floor=f0_min, pitch_ceiling=f0_max - ).selected_array['frequency'] + ).selected_array['frequency'].astype(np.float32) f0 = pad_frames(f0, hop_size, wav_data.shape[0], length) uv = f0 == 0 if interp_uv: @@ -95,6 +95,7 @@ def get_breathiness_pyworld(wav_data, f0, length, hparams): fft_size = hparams['fft_size'] x = wav_data.astype(np.double) + f0 = f0.astype(np.double) wav_frames = (x.shape[0] + hop_size - 1) // hop_size f0_frames = f0.shape[0] if f0_frames < wav_frames: From e184df0af303a809ffe70e33cee520536c77ccf3 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 25 Jul 2023 15:54:16 +0800 Subject: [PATCH 5/9] Constraint dtype to float32 --- modules/pe/rmvpe/inference.py | 2 +- utils/binarizer_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/pe/rmvpe/inference.py b/modules/pe/rmvpe/inference.py index 7fcdd6105..c3624b1cb 100644 --- a/modules/pe/rmvpe/inference.py +++ b/modules/pe/rmvpe/inference.py @@ -62,7 +62,7 @@ def get_pitch(self, waveform, length, hparams, interp_uv=False, speed=1): hop_size = int(np.round(hparams['hop_size'] * speed)) time_step = hop_size / hparams['audio_sample_rate'] f0_res = resample_align_curve(f0, 0.01, time_step, length) - uv_res = resample_align_curve(uv.astype(float), 0.01, time_step, length) > 0.5 + uv_res = resample_align_curve(uv.astype(np.float32), 0.01, time_step, length) > 0.5 if not interp_uv: f0_res[uv_res] = 0 return f0_res, uv_res diff --git a/utils/binarizer_utils.py b/utils/binarizer_utils.py index db8ba7768..0124dfc36 100644 --- a/utils/binarizer_utils.py +++ b/utils/binarizer_utils.py @@ -110,7 +110,7 @@ def get_breathiness_pyworld(wav_data, f0, length, hparams): y = pw.synthesize( f0, sp * ap * ap, np.ones_like(ap), sample_rate, frame_period=time_step * 1000 - ) # synthesize the aperiodic part using the parameters + ).astype(np.float32) # synthesize the aperiodic part using the parameters breathiness = get_energy_librosa(y, length, hparams) return breathiness From a38296699e09cd1d0cf73036a71be40f7b0a6ba8 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 25 Jul 2023 16:07:00 +0800 Subject: [PATCH 6/9] Simplify device management --- modules/pe/rmvpe/inference.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/modules/pe/rmvpe/inference.py b/modules/pe/rmvpe/inference.py index c3624b1cb..8df68ed6e 100644 --- a/modules/pe/rmvpe/inference.py +++ b/modules/pe/rmvpe/inference.py @@ -14,19 +14,20 @@ class RMVPE: def __init__(self, model_path, hop_length=160): self.resample_kernel = {} - model = E2E0(4, 1, (2, 2)) - ckpt = torch.load(model_path) - model.load_state_dict(ckpt['model'], strict=False) - model.eval() - self.model = model - self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX) + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.model = E2E0(4, 1, (2, 2)).eval().to(self.device) + ckpt = torch.load(model_path, map_location=self.device) + self.model.load_state_dict(ckpt['model'], strict=False) + self.mel_extractor = MelSpectrogram( + N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX + ).to(self.device) + @torch.no_grad() def mel2hidden(self, mel): - with torch.no_grad(): - n_frames = mel.shape[-1] - mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') - hidden = self.model(mel) - return hidden[:, :n_frames] + n_frames = mel.shape[-1] + mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') + hidden = self.model(mel) + return hidden[:, :n_frames] def decode(self, hidden, thred=0.03, use_viterbi=False): if use_viterbi: @@ -35,21 +36,17 @@ def decode(self, hidden, thred=0.03, use_viterbi=False): f0 = to_local_average_f0(hidden, thred=thred) return f0 - def infer_from_audio(self, audio, sample_rate=16000, device=None, thred=0.03, use_viterbi=False): - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - audio = torch.from_numpy(audio).float().unsqueeze(0).to(device) + def infer_from_audio(self, audio, sample_rate=16000, thred=0.03, use_viterbi=False): + audio = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) if sample_rate == 16000: audio_res = audio else: key_str = str(sample_rate) if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) - self.resample_kernel[key_str] = self.resample_kernel[key_str].to(device) + self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.device) audio_res = self.resample_kernel[key_str](audio) - mel_extractor = self.mel_extractor.to(device) - self.model = self.model.to(device) - mel = mel_extractor(audio_res, center=True) + mel = self.mel_extractor(audio_res, center=True) hidden = self.mel2hidden(mel) f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi) return f0 From ab02b7e73ee80d4cd698fa7dcec4b355da488f20 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 25 Jul 2023 19:08:04 +0800 Subject: [PATCH 7/9] Remove unused code --- modules/pe/rmvpe/deepunet.py | 16 ---------------- modules/pe/rmvpe/model.py | 2 +- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/modules/pe/rmvpe/deepunet.py b/modules/pe/rmvpe/deepunet.py index d0d5d777b..2e50d5e0b 100644 --- a/modules/pe/rmvpe/deepunet.py +++ b/modules/pe/rmvpe/deepunet.py @@ -158,22 +158,6 @@ def forward(self, x_tensors): return out_tensors -class DeepUnet(nn.Module): - def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): - super(DeepUnet, self).__init__() - self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) - self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) - self.tf = TimbreFilter(self.encoder.latent_channels) - self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) - - def forward(self, x): - x, concat_tensors = self.encoder(x) - x = self.intermediate(x) - concat_tensors = self.tf(concat_tensors) - x = self.decoder(x, concat_tensors) - return x - - class DeepUnet0(nn.Module): def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): super(DeepUnet0, self).__init__() diff --git a/modules/pe/rmvpe/model.py b/modules/pe/rmvpe/model.py index 8bdeb43d6..5b2d72cfb 100644 --- a/modules/pe/rmvpe/model.py +++ b/modules/pe/rmvpe/model.py @@ -1,7 +1,7 @@ from torch import nn from .constants import * -from .deepunet import DeepUnet, DeepUnet0 +from .deepunet import DeepUnet0 from .seq import BiGRU From c046ff687c55c3da31a3991108224ff4202af899 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 25 Jul 2023 22:17:58 +0800 Subject: [PATCH 8/9] Add missing inheritance --- modules/pe/rmvpe/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/pe/rmvpe/inference.py b/modules/pe/rmvpe/inference.py index 8df68ed6e..49b785c53 100644 --- a/modules/pe/rmvpe/inference.py +++ b/modules/pe/rmvpe/inference.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from torchaudio.transforms import Resample +from basics.base_pe import BasePE from utils.infer_utils import resample_align_curve from utils.pitch_utils import interp_f0 from .constants import * @@ -11,7 +12,7 @@ from .utils import to_local_average_f0, to_viterbi_f0 -class RMVPE: +class RMVPE(BasePE): def __init__(self, model_path, hop_length=160): self.resample_kernel = {} self.device = 'cuda' if torch.cuda.is_available() else 'cpu' From b09007c16540dd05261ff776108e413dd3408620 Mon Sep 17 00:00:00 2001 From: yqzhishen Date: Tue, 25 Jul 2023 22:18:11 +0800 Subject: [PATCH 9/9] Optimize imports --- modules/pe/rmvpe/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/pe/rmvpe/utils.py b/modules/pe/rmvpe/utils.py index 34cf5ed15..9cdf0b1c3 100644 --- a/modules/pe/rmvpe/utils.py +++ b/modules/pe/rmvpe/utils.py @@ -1,10 +1,8 @@ -import sys -import numpy as np import librosa +import numpy as np import torch -from functools import reduce + from .constants import * -from torch.nn.modules.module import _addindent def to_local_average_f0(hidden, center=None, thred=0.03):