Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 91 additions & 56 deletions src/transformers/models/parakeet/feature_extraction_parakeet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

import numpy as np
import torch

Expand Down Expand Up @@ -96,9 +98,28 @@ def __init__(
)
self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32)

@staticmethod
@functools.cache
def _get_window(win_length: int, device: str) -> torch.Tensor:
return torch.hann_window(win_length, periodic=False, device=device)

@staticmethod
@functools.cache
def _get_mel_filters(feature_size: int, sampling_rate: int, n_fft: int, device: str) -> torch.Tensor:
    """Build (and memoize) the slaney-normalized mel filterbank as a float32 tensor.

    The filterbank spans 0 Hz to Nyquist (``sampling_rate / 2``) with
    ``feature_size`` mel bands. Caching keys on all four arguments, so each
    distinct configuration/device pays the librosa construction cost only once.
    """
    mel_kwargs = dict(
        sr=sampling_rate,
        n_fft=n_fft,
        n_mels=feature_size,
        fmin=0.0,
        fmax=sampling_rate / 2,
        norm="slaney",
    )
    filterbank = librosa.filters.mel(**mel_kwargs)
    return torch.from_numpy(filterbank).to(device=device, dtype=torch.float32)

def _torch_extract_fbank_features(self, waveform, device="cpu"):
# spectrogram
window = torch.hann_window(self.win_length, periodic=False, device=device)
device = str(torch.device(device))
window = self._get_window(self.win_length, device)
stft = torch.stft(
waveform,
self.n_fft,
Expand All @@ -108,21 +129,53 @@ def _torch_extract_fbank_features(self, waveform, device="cpu"):
return_complex=True,
pad_mode="constant",
)
# Let's match the original implementation
# magnitudes = torch.abs(stft) ** 2
magnitudes = torch.view_as_real(stft)
magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1))
magnitudes = magnitudes.pow(2)

# log mel spectrogram
mel_filters = self.mel_filters.to(device)
mel_filters = self._get_mel_filters(self.feature_size, self.sampling_rate, self.n_fft, device)
return self._apply_mel_filters(stft, mel_filters)

@torch.compile(dynamic=True)
def _apply_mel_filters(self, stft_output: torch.Tensor, mel_filters: torch.Tensor) -> torch.Tensor:
    """Project a complex STFT onto the mel filterbank and take the guarded log.

    Args:
        stft_output: complex STFT frames; power is computed as real^2 + imag^2.
        mel_filters: mel filterbank matrix applied via matmul to the power spectrum.

    Returns:
        Log-mel features, permuted so frames come before mel bins
        (batch, mel, frames) -> (batch, frames, mel) via permute(0, 2, 1).
    """
    # Power spectrum without materializing torch.abs(stft)**2 of the complex tensor.
    magnitudes = stft_output.real.square() + stft_output.imag.square()
    mel_spec = mel_filters @ magnitudes
    # LOG_ZERO_GUARD_VALUE keeps log() finite for all-zero (silent/padded) bins.
    mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE)
    return mel_spec.permute(0, 2, 1)

@torch.compile(dynamic=True)
def _apply_preemphasis(self, input_features: torch.Tensor, audio_lengths: torch.Tensor) -> torch.Tensor:
    """Apply first-order pre-emphasis x[t] - coef * x[t-1], then re-zero padding.

    No-op when ``self.preemphasis`` is None. ``audio_lengths`` gives the number
    of valid samples per batch row; samples beyond that are masked back to 0
    because the difference filter would otherwise leak values into the padding.
    """
    if self.preemphasis is not None:
        # True for time steps inside each waveform's valid length.
        timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze(
            0
        ) < audio_lengths.unsqueeze(1)
        # First sample is kept as-is; the rest get x[t] - coef * x[t-1].
        input_features = torch.cat(
            [input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1
        )
        input_features = input_features.masked_fill(~timemask, 0.0)
    return input_features

@torch.compile(dynamic=True)
def _normalize_mel_features(
    self, mel_features: torch.Tensor, audio_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Mean/variance-normalize mel features per utterance, ignoring padded frames.

    Args:
        mel_features: (batch, frames, mel) log-mel features, padded along frames.
        audio_lengths: number of valid waveform samples per batch row.

    Returns:
        A tuple of (normalized features with padding zeroed, boolean attention
        mask over frames). Variance uses the unbiased (n - 1) denominator.
    """
    # Frames produced by a centered STFT: padding adds n_fft//2 on each side.
    features_lengths = torch.floor_divide(audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length)
    frame_idx = torch.arange(mel_features.shape[1], device=mel_features.device)
    attention_mask = frame_idx[None, :] < features_lengths[:, None]

    frame_mask = attention_mask.unsqueeze(-1)
    valid_counts = attention_mask.sum(dim=1)
    masked_features = mel_features * frame_mask
    mean = (masked_features.sum(dim=1) / valid_counts.unsqueeze(-1)).unsqueeze(1)
    centered_sq = (masked_features - mean) ** 2 * frame_mask
    variance = centered_sq.sum(dim=1) / (valid_counts - 1).unsqueeze(-1)
    std = torch.sqrt(variance).unsqueeze(1)
    normalized = (mel_features - mean) / (std + EPSILON) * frame_mask
    return normalized, attention_mask

def _pad_raw_speech(self, raw_speech: list[torch.Tensor], max_len: int, device: str) -> torch.Tensor:
output = torch.full((len(raw_speech), max_len), self.padding_value, device=device, dtype=torch.float32)
dsts = [output[i, : raw_speech[i].shape[0]] for i in range(len(raw_speech))]
srcs = [s.squeeze(-1) for s in raw_speech]
# single kernel horizontal fusion
torch._foreach_copy_(dsts, srcs)
return output

def __call__(
self,
Expand Down Expand Up @@ -205,11 +258,18 @@ def __call__(
"Failing to do so can result in silent errors that might be hard to debug."
)

device = device if device is not None else "cpu"

# Convert to torch tensor
if isinstance(raw_speech, np.ndarray):
raw_speech = torch.tensor(raw_speech)
elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray):
raw_speech = [torch.tensor(speech) for speech in raw_speech]
raw_speech = torch.as_tensor(raw_speech, device=device)
elif isinstance(raw_speech, (list, tuple)) and len(raw_speech) > 0:
if isinstance(raw_speech[0], np.ndarray):
raw_speech = [torch.as_tensor(speech, device=device) for speech in raw_speech]
elif isinstance(raw_speech[0], (float, int)):
raw_speech = torch.tensor(raw_speech, device=device, dtype=torch.float32)
elif isinstance(raw_speech[0], (list, tuple)):
raw_speech = [torch.tensor(speech, device=device, dtype=torch.float32) for speech in raw_speech]

is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
if is_batched_torch and len(raw_speech.shape) > 2:
Expand All @@ -221,57 +281,32 @@ def __call__(

is_batched_sequence = isinstance(raw_speech, (list, tuple))
if is_batched_sequence:
for speech in raw_speech:
for i, speech in enumerate(raw_speech):
if len(speech.shape) > 1:
logger.warning(
f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
"We will take the mean of the channels to convert to mono."
)
speech = speech.mean(-1)
raw_speech[i] = speech.mean(-1)

if is_batched_torch or is_batched_sequence:
raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
if is_batched_torch:
raw_speech = raw_speech.to(device=device, dtype=torch.float32)
elif is_batched_sequence:
raw_speech = [speech.to(device=device, dtype=torch.float32) for speech in raw_speech]
else:
raw_speech = [raw_speech[:, None].to(torch.float32)]

audio_lengths = [len(speech) for speech in raw_speech]
batched_speech = BatchFeature({"input_features": raw_speech, "audio_lengths": audio_lengths})

padded_inputs = self.pad(
batched_speech,
padding=padding,
max_length=max_length,
truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
)
input_features = padded_inputs.input_features.squeeze(-1)
raw_speech = [raw_speech.to(device=device, dtype=torch.float32)]

# preemphasis
if self.preemphasis is not None:
timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze(
0
) < padded_inputs.audio_lengths.unsqueeze(1)
input_features = torch.cat(
[input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1
)
input_features = input_features.masked_fill(~timemask, 0.0)
audio_lengths = torch.tensor([len(speech) for speech in raw_speech], dtype=torch.long, device=device)

input_features = self._torch_extract_fbank_features(input_features, device)
features_lengths = torch.floor_divide(
padded_inputs.audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length
)
attention_mask = torch.arange(input_features.shape[1], device=device)[None, :] < features_lengths[:, None]
if isinstance(raw_speech, torch.Tensor):
input_features = raw_speech
else:
max_length = max(len(speech) for speech in raw_speech)
input_features = self._pad_raw_speech(raw_speech, max_length, device)

# normalize mel features, ignoring padding
mask = attention_mask.unsqueeze(-1)
input_features_masked = input_features * mask
mean = input_features_masked.sum(dim=1) / features_lengths.unsqueeze(-1)
mean = mean.unsqueeze(1)
variance = ((input_features_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1)
std = torch.sqrt(variance).unsqueeze(1)
input_features = (input_features - mean) / (std + EPSILON)
input_features *= mask
input_features = self._apply_preemphasis(input_features, audio_lengths)
input_features = self._torch_extract_fbank_features(input_features, device)
input_features, attention_mask = self._normalize_mel_features(input_features, audio_lengths)

return BatchFeature(
data={
Expand Down
Loading