From a9875b94dde7fed23b9a824492bd16c112a263fa Mon Sep 17 00:00:00 2001 From: Alexandre Milesi Date: Tue, 17 Mar 2026 21:42:45 -0700 Subject: [PATCH 1/3] Optimize Parakeet feature extraction on CUDA --- .../parakeet/feature_extraction_parakeet.py | 85 +++++++++++++------ 1 file changed, 59 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/parakeet/feature_extraction_parakeet.py b/src/transformers/models/parakeet/feature_extraction_parakeet.py index c745d02c9629..86d81aa33782 100644 --- a/src/transformers/models/parakeet/feature_extraction_parakeet.py +++ b/src/transformers/models/parakeet/feature_extraction_parakeet.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools + import numpy as np import torch @@ -96,9 +98,28 @@ def __init__( ) self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32) + @staticmethod + @functools.cache + def _get_window(win_length: int, device: str) -> torch.Tensor: + return torch.hann_window(win_length, periodic=False, device=device) + + @staticmethod + @functools.cache + def _get_mel_filters(feature_size: int, sampling_rate: int, n_fft: int, device: str) -> torch.Tensor: + mel_filters = librosa.filters.mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=feature_size, + fmin=0.0, + fmax=sampling_rate / 2, + norm="slaney", + ) + return torch.from_numpy(mel_filters).to(device=device, dtype=torch.float32) + def _torch_extract_fbank_features(self, waveform, device="cpu"): # spectrogram - window = torch.hann_window(self.win_length, periodic=False, device=device) + device = str(torch.device(device)) + window = self._get_window(self.win_length, device) stft = torch.stft( waveform, self.n_fft, @@ -110,12 +131,10 @@ def _torch_extract_fbank_features(self, waveform, device="cpu"): ) # Let's math original implementation # magnitudes = torch.abs(stft) ** 2 - magnitudes = torch.view_as_real(stft) - magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1)) - magnitudes = magnitudes.pow(2) + magnitudes = stft.real.square() + stft.imag.square() # log mel spectrogram - mel_filters = self.mel_filters.to(device) + mel_filters = self._get_mel_filters(self.feature_size, self.sampling_rate, self.n_fft, device) mel_spec = mel_filters @ magnitudes mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE) @@ -205,11 +224,13 @@ def __call__( "Failing to do so can result in silent errors that might be hard to debug." 
) + device_obj = torch.device(device if device is not None else "cpu") + # Convert to torch tensor if isinstance(raw_speech, np.ndarray): - raw_speech = torch.tensor(raw_speech) + raw_speech = torch.as_tensor(raw_speech, device=device_obj) elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray): - raw_speech = [torch.tensor(speech) for speech in raw_speech] + raw_speech = [torch.as_tensor(speech, device=device_obj) for speech in raw_speech] is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1 if is_batched_torch and len(raw_speech.shape) > 2: @@ -230,38 +251,50 @@ def __call__( speech = speech.mean(-1) if is_batched_torch or is_batched_sequence: - raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech] + raw_speech = [speech[:, None].to(device=device_obj, dtype=torch.float32) for speech in raw_speech] else: - raw_speech = [raw_speech[:, None].to(torch.float32)] - - audio_lengths = [len(speech) for speech in raw_speech] - batched_speech = BatchFeature({"input_features": raw_speech, "audio_lengths": audio_lengths}) - - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors="pt", - ) - input_features = padded_inputs.input_features.squeeze(-1) + raw_speech = [raw_speech[:, None].to(device=device_obj, dtype=torch.float32)] + + audio_lengths = torch.tensor([len(speech) for speech in raw_speech], dtype=torch.long, device=device_obj) + if device_obj.type == "cuda": + max_audio_len = int(audio_lengths.max().item()) if raw_speech else 0 + input_features = torch.full( + (len(raw_speech), max_audio_len, raw_speech[0].shape[-1]), + fill_value=self.padding_value, + dtype=torch.float32, + device=device_obj, + ) + for idx, speech in enumerate(raw_speech): + input_features[idx, : len(speech)] = speech + input_features = input_features.squeeze(-1) + else: + batched_speech = BatchFeature({"input_features": raw_speech, "audio_lengths": audio_lengths.tolist()}) + padded_inputs = self.pad( + batched_speech, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ) + input_features = padded_inputs.input_features.squeeze(-1).to(device_obj) + audio_lengths = padded_inputs.audio_lengths.to(device_obj) # preemphasis if self.preemphasis is not None: timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze( 0 - ) < padded_inputs.audio_lengths.unsqueeze(1) + ) < audio_lengths.unsqueeze(1) input_features = torch.cat( [input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1 ) input_features = input_features.masked_fill(~timemask, 0.0) - input_features = self._torch_extract_fbank_features(input_features, device) + input_features = self._torch_extract_fbank_features(input_features, device_obj) features_lengths = torch.floor_divide( - padded_inputs.audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length + audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length ) - attention_mask = torch.arange(input_features.shape[1], device=device)[None, :] < features_lengths[:, None] + attention_mask = torch.arange(input_features.shape[1], device=input_features.device)[None, :] < features_lengths[:, None] # normalize mel features, ignoring padding mask = attention_mask.unsqueeze(-1) From 61e5e7dbfcff273ed7ea23d92a90ba19038e4349 Mon Sep 17 00:00:00 2001 From: Alexandre Milesi 
Date: Wed, 25 Mar 2026 16:27:22 -0700 Subject: [PATCH 2/3] Add dynamic compile to Parakeet feature extraction --- .../parakeet/feature_extraction_parakeet.py | 128 +++++++++--------- 1 file changed, 63 insertions(+), 65 deletions(-) diff --git a/src/transformers/models/parakeet/feature_extraction_parakeet.py b/src/transformers/models/parakeet/feature_extraction_parakeet.py index 86d81aa33782..bdb6846a3c60 100644 --- a/src/transformers/models/parakeet/feature_extraction_parakeet.py +++ b/src/transformers/models/parakeet/feature_extraction_parakeet.py @@ -129,19 +129,49 @@ def _torch_extract_fbank_features(self, waveform, device="cpu"): return_complex=True, pad_mode="constant", ) - # Let's math original implementation - # magnitudes = torch.abs(stft) ** 2 - magnitudes = stft.real.square() + stft.imag.square() - - # log mel spectrogram mel_filters = self._get_mel_filters(self.feature_size, self.sampling_rate, self.n_fft, device) + return self._apply_mel_filters(stft, mel_filters) + + @torch.compile(dynamic=True) + def _apply_mel_filters(self, stft_output: torch.Tensor, mel_filters: torch.Tensor) -> torch.Tensor: + magnitudes = stft_output.real.square() + stft_output.imag.square() mel_spec = mel_filters @ magnitudes mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE) + return mel_spec.permute(0, 2, 1) + + @torch.compile(dynamic=True) + def _apply_preemphasis(self, input_features: torch.Tensor, audio_lengths: torch.Tensor) -> torch.Tensor: + if self.preemphasis is not None: + timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze( + 0 + ) < audio_lengths.unsqueeze(1) + input_features = torch.cat( + [input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1 + ) + input_features = input_features.masked_fill(~timemask, 0.0) + return input_features + + @torch.compile(dynamic=True) + def _normalize_mel_features(self, mel_features: torch.Tensor, audio_lengths: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # normalize mel features, ignoring padding + features_lengths = torch.floor_divide(audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length) + attention_mask = torch.arange(mel_features.shape[1], device=mel_features.device)[None, :] < features_lengths[:, None] - # (batch_size, num_mel_filters, num_frames) -> (batch_size, num_frames, num_mel_filters) - mel_spec = mel_spec.permute(0, 2, 1) + mask = attention_mask.unsqueeze(-1) + lengths = attention_mask.sum(dim=1) + mel_features_masked = mel_features * mask + mean = (mel_features_masked.sum(dim=1) / lengths.unsqueeze(-1)).unsqueeze(1) + variance = ((mel_features_masked - mean) ** 2 * mask).sum(dim=1) / (lengths - 1).unsqueeze(-1) + std = torch.sqrt(variance).unsqueeze(1) + return (mel_features - mean) / (std + EPSILON) * mask, attention_mask - return mel_spec + def _pad_raw_speech(self, raw_speech: list[torch.Tensor], max_len: int, device: str) -> torch.Tensor: + output = torch.full((len(raw_speech), max_len), self.padding_value, device=device, dtype=torch.float32) + dsts = [output[i, :raw_speech[i].shape[0]] for i in range(len(raw_speech))] + srcs = [s.squeeze(-1) for s in raw_speech] + # single kernel horizontal fusion + torch._foreach_copy_(dsts, srcs) + return output def __call__( self, @@ -224,13 +254,18 @@ def __call__( "Failing to do so can result in silent errors that might be hard to debug." 
) - device_obj = torch.device(device if device is not None else "cpu") + device = device if device is not None else "cpu" # Convert to torch tensor if isinstance(raw_speech, np.ndarray): - raw_speech = torch.as_tensor(raw_speech, device=device_obj) - elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray): - raw_speech = [torch.as_tensor(speech, device=device_obj) for speech in raw_speech] + raw_speech = torch.as_tensor(raw_speech, device=device) + elif isinstance(raw_speech, (list, tuple)) and len(raw_speech) > 0: + if isinstance(raw_speech[0], np.ndarray): + raw_speech = [torch.as_tensor(speech, device=device) for speech in raw_speech] + elif isinstance(raw_speech[0], (float, int)): + raw_speech = torch.tensor(raw_speech, device=device, dtype=torch.float32) + elif isinstance(raw_speech[0], (list, tuple)): + raw_speech = [torch.tensor(speech, device=device, dtype=torch.float32) for speech in raw_speech] is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1 if is_batched_torch and len(raw_speech.shape) > 2: @@ -242,69 +277,32 @@ def __call__( is_batched_sequence = isinstance(raw_speech, (list, tuple)) if is_batched_sequence: - for speech in raw_speech: + for i, speech in enumerate(raw_speech): if len(speech.shape) > 1: logger.warning( f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " "We will take the mean of the channels to convert to mono." ) - speech = speech.mean(-1) + raw_speech[i] = speech.mean(-1) - if is_batched_torch or is_batched_sequence: - raw_speech = [speech[:, None].to(device=device_obj, dtype=torch.float32) for speech in raw_speech] + if is_batched_torch: + raw_speech = raw_speech.to(device=device, dtype=torch.float32) + elif is_batched_sequence: + raw_speech = [speech.to(device=device, dtype=torch.float32) for speech in raw_speech] else: - raw_speech = [raw_speech[:, None].to(device=device_obj, dtype=torch.float32)] - - audio_lengths = torch.tensor([len(speech) for speech in raw_speech], dtype=torch.long, device=device_obj) - if device_obj.type == "cuda": - max_audio_len = int(audio_lengths.max().item()) if raw_speech else 0 - input_features = torch.full( - (len(raw_speech), max_audio_len, raw_speech[0].shape[-1]), - fill_value=self.padding_value, - dtype=torch.float32, - device=device_obj, - ) - for idx, speech in enumerate(raw_speech): - input_features[idx, : len(speech)] = speech - input_features = input_features.squeeze(-1) - else: - batched_speech = BatchFeature({"input_features": raw_speech, "audio_lengths": audio_lengths.tolist()}) - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors="pt", - ) - input_features = padded_inputs.input_features.squeeze(-1).to(device_obj) - audio_lengths = padded_inputs.audio_lengths.to(device_obj) + raw_speech = [raw_speech.to(device=device, dtype=torch.float32)] - # preemphasis - if self.preemphasis is not None: - timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze( - 0 - ) < audio_lengths.unsqueeze(1) - input_features = torch.cat( - [input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1 - ) - input_features = input_features.masked_fill(~timemask, 0.0) + audio_lengths = torch.tensor([len(speech) for speech in raw_speech], dtype=torch.long, device=device) - input_features = self._torch_extract_fbank_features(input_features, device_obj) - 
features_lengths = torch.floor_divide( - audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length - ) - attention_mask = torch.arange(input_features.shape[1], device=input_features.device)[None, :] < features_lengths[:, None] + if isinstance(raw_speech, torch.Tensor): + input_features = raw_speech + else: + max_length = max(len(speech) for speech in raw_speech) + input_features = self._pad_raw_speech(raw_speech, max_length, device) - # normalize mel features, ignoring padding - mask = attention_mask.unsqueeze(-1) - input_features_masked = input_features * mask - mean = input_features_masked.sum(dim=1) / features_lengths.unsqueeze(-1) - mean = mean.unsqueeze(1) - variance = ((input_features_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1) - std = torch.sqrt(variance).unsqueeze(1) - input_features = (input_features - mean) / (std + EPSILON) - input_features *= mask + input_features = self._apply_preemphasis(input_features, audio_lengths) + input_features = self._torch_extract_fbank_features(input_features, device) + input_features, attention_mask = self._normalize_mel_features(input_features, audio_lengths) return BatchFeature( data={ From 6703a6d11cb3f2c400f4e74ccc06fa850e7ca3ea Mon Sep 17 00:00:00 2001 From: milesial Date: Thu, 2 Apr 2026 22:08:08 -0700 Subject: [PATCH 3/3] Fix Parakeet feature extractor formatting Signed-off-by: milesial --- .../models/parakeet/feature_extraction_parakeet.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/parakeet/feature_extraction_parakeet.py b/src/transformers/models/parakeet/feature_extraction_parakeet.py index bdb6846a3c60..b28231b81f53 100644 --- a/src/transformers/models/parakeet/feature_extraction_parakeet.py +++ b/src/transformers/models/parakeet/feature_extraction_parakeet.py @@ -152,10 +152,14 @@ def _apply_preemphasis(self, input_features: torch.Tensor, audio_lengths: torch. return input_features @torch.compile(dynamic=True) - def _normalize_mel_features(self, mel_features: torch.Tensor, audio_lengths: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _normalize_mel_features( + self, mel_features: torch.Tensor, audio_lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: # normalize mel features, ignoring padding features_lengths = torch.floor_divide(audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length) - attention_mask = torch.arange(mel_features.shape[1], device=mel_features.device)[None, :] < features_lengths[:, None] + attention_mask = ( + torch.arange(mel_features.shape[1], device=mel_features.device)[None, :] < features_lengths[:, None] + ) mask = attention_mask.unsqueeze(-1) lengths = attention_mask.sum(dim=1) @@ -167,7 +171,7 @@ def _normalize_mel_features(self, mel_features: torch.Tensor, audio_lengths: tor def _pad_raw_speech(self, raw_speech: list[torch.Tensor], max_len: int, device: str) -> torch.Tensor: output = torch.full((len(raw_speech), max_len), self.padding_value, device=device, dtype=torch.float32) - dsts = [output[i, :raw_speech[i].shape[0]] for i in range(len(raw_speech))] + dsts = [output[i, : raw_speech[i].shape[0]] for i in range(len(raw_speech))] srcs = [s.squeeze(-1) for s in raw_speech] # single kernel horizontal fusion torch._foreach_copy_(dsts, srcs)
         return output
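For reference, a minimal sketch of how the optimized path in this series would be exercised. The checkpoint name is a placeholder and loading through AutoFeatureExtractor is assumed; the `device` argument, the `input_features`/`attention_mask` outputs, and the (batch, num_frames, num_mel_filters) layout come from the patches above:

    import numpy as np
    from transformers import AutoFeatureExtractor

    extractor = AutoFeatureExtractor.from_pretrained("<parakeet-checkpoint>")  # placeholder name

    # Batch of variable-length mono waveforms at the extractor's sampling rate.
    rng = np.random.default_rng(0)
    waveforms = [rng.standard_normal(int(16000 * sec)).astype(np.float32) for sec in (1.0, 2.5)]

    # device="cuda" keeps padding (torch._foreach_copy_), preemphasis, the STFT, and the
    # masked mel normalization on the GPU. The first call pays the one-time
    # torch.compile(dynamic=True) warm-up for the decorated helpers; later calls reuse the
    # compiled graphs plus the functools.cache'd Hann window and mel filter bank.
    features = extractor(waveforms, sampling_rate=16000, device="cuda")

    print(features["input_features"].shape)  # (batch, num_frames, num_mel_filters)
    print(features["attention_mask"].shape)  # (batch, num_frames)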