diff --git a/.gitignore b/.gitignore
index 75f5a9998310..6ed5479ab0c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -176,3 +176,7 @@ tags
# Cursor IDE files
.cursor/
test-results/
+src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py
+.gitignore
+tests/test_wav2vec2_whisper.py
+run_preprocessing_tests.sh
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index c66b077cac36..82b20dc39684 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -355,6 +355,7 @@
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
+ _import_structure["audio_processing_backends"] = ["NumpyAudioBackend", "NumpyBackend", "TorchAudioBackend", "TorchBackend"]
_import_structure["model_debugging_utils"] = [
"model_addition_debugger_context",
]
@@ -477,6 +478,10 @@
if TYPE_CHECKING:
# All modeling imports
# Models
+ from .audio_processing_backends import NumpyAudioBackend as NumpyAudioBackend
+ from .audio_processing_backends import NumpyBackend as NumpyBackend
+ from .audio_processing_backends import TorchAudioBackend as TorchAudioBackend
+ from .audio_processing_backends import TorchBackend as TorchBackend
from .backbone_utils import BackboneConfigMixin, BackboneMixin
from .cache_utils import Cache as Cache
from .cache_utils import DynamicCache as DynamicCache
diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py
new file mode 100644
index 000000000000..20a4c8a1f4c8
--- /dev/null
+++ b/src/transformers/audio_processing_backends.py
@@ -0,0 +1,702 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+
+from .audio_processing_utils import BaseAudioProcessor
+from .audio_utils import SpectrogramConfig, amplitude_to_db, mel_filter_bank, power_to_db
+from .utils import PaddingStrategy, is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+ import torch
+
+
+# ── Torch frequency conversion utilities (used by TorchAudioBackend._mel_filter_bank) ──
+
+
+def _torch_hertz_to_mel_scalar(freq: float, mel_scale: str = "htk") -> float:
+ if mel_scale == "htk":
+ return 2595.0 * math.log10(1.0 + freq / 700.0)
+ elif mel_scale == "kaldi":
+ return 1127.0 * math.log(1.0 + freq / 700.0)
+ f_sp = 200.0 / 3
+ min_log_hz = 1000.0
+ min_log_mel = (min_log_hz - 0.0) / f_sp
+ logstep = math.log(6.4) / 27.0
+ if freq >= min_log_hz:
+ return min_log_mel + math.log(freq / min_log_hz) / logstep
+ return (freq - 0.0) / f_sp
+
+
+def _torch_hertz_to_mel(freq: "torch.Tensor", mel_scale: str = "htk") -> "torch.Tensor":
+ if mel_scale == "htk":
+ return 2595.0 * torch.log10(1.0 + freq / 700.0)
+ elif mel_scale == "kaldi":
+ return 1127.0 * torch.log(1.0 + freq / 700.0)
+ f_sp = 200.0 / 3
+ min_log_hertz = 1000.0
+ min_log_mel = min_log_hertz / f_sp
+ logstep = 27.0 / torch.log(torch.tensor(6.4))
+ mels = freq / f_sp
+ log_region = freq >= min_log_hertz
+ mels[log_region] = min_log_mel + torch.log(freq[log_region] / min_log_hertz) * logstep
+ return mels
+
+
+def _torch_mel_to_hertz(mels: "torch.Tensor", mel_scale: str = "htk") -> "torch.Tensor":
+ if mel_scale == "htk":
+ return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
+ elif mel_scale == "kaldi":
+ return 700.0 * (torch.exp(mels / 1127.0) - 1.0)
+ f_sp = 200.0 / 3
+ min_log_hz = 1000.0
+ min_log_mel = (min_log_hz - 0.0) / f_sp
+ logstep = math.log(6.4) / 27.0
+ freq = 0.0 + f_sp * mels
+ log_region = mels >= min_log_mel
+ freq[log_region] = min_log_hz * torch.exp(logstep * (mels[log_region] - min_log_mel))
+ return freq
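+
+
+# Illustrative round-trip (a sketch, not part of the module): hertz -> mel -> hertz
+# is the identity up to float error, e.g.
+#
+#     freqs = torch.tensor([100.0, 440.0, 4000.0])
+#     mels = _torch_hertz_to_mel(freqs, mel_scale="slaney")
+#     torch.testing.assert_close(_torch_mel_to_hertz(mels, mel_scale="slaney"), freqs)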
+
+
+def _torch_triangular_filter_bank(fft_freqs, filter_freqs, computation_dtype=None):
+    """Compute a triangular mel filter bank (torch analogue of the numpy triangular filters)."""
+    filter_diff = filter_freqs[1:] - filter_freqs[:-1]
+    slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1)
+    down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+    up_slopes = slopes[:, 2:] / filter_diff[1:]
+    return torch.clamp(torch.minimum(down_slopes, up_slopes), min=0)
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# NumpyAudioBackend
+# ═══════════════════════════════════════════════════════════════════════════════
+
+
+class NumpyAudioBackend(BaseAudioProcessor):
+ """NumPy backend for portable CPU-only audio processing."""
+
+ @property
+ def backend(self) -> str:
+ return "numpy"
+
+ # ── Audio input processing ────────────────────────────────────────────
+
+ def _process_audio(self, audio_el):
+ if not isinstance(audio_el, np.ndarray):
+ audio_el = np.asarray(audio_el)
+ if audio_el.ndim > 1:
+ if self.force_mono and audio_el.shape[0] > 1:
+ audio_el = audio_el.mean(axis=0)
+ elif audio_el.shape[0] == 1:
+ audio_el = np.squeeze(audio_el, axis=0)
+ else:
+ raise ValueError("Audio has more than one channel but force_mono is False")
+ return audio_el
+
+ # ── Padding & batching ────────────────────────────────────────────────
+
+ def _pad_single(self, audio: np.ndarray, max_length: int) -> np.ndarray:
+ current_length = audio.shape[-1]
+ if current_length >= max_length:
+ return audio
+ pad_length = max_length - current_length
+ if self.padding_side == "right":
+ pad_width = [(0, 0)] * (audio.ndim - 1) + [(0, pad_length)]
+ elif self.padding_side == "left":
+ pad_width = [(0, 0)] * (audio.ndim - 1) + [(pad_length, 0)]
+ else:
+ raise ValueError(f"Invalid padding side: {self.padding_side}")
+ return np.pad(audio, pad_width, mode="constant", constant_values=self.padding_value)
+
+ def _to_batch(self, audio):
+ batch = np.stack(audio)
+ if self.add_channel_dim:
+ batch = batch[:, np.newaxis, :]
+ return batch
+
+ def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of):
+ padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
+ if truncation and max_length is not None:
+ features = [f[:max_length] for f in features]
+ actual_lengths = [f.shape[0] for f in features]
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = max(actual_lengths)
+ padding_strategy = PaddingStrategy.MAX_LENGTH
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+ if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
+ features = [
+ np.pad(f, [(0, max_length - f.shape[0])] + [(0, 0)] * (f.ndim - 1),
+ mode="constant", constant_values=self.padding_value)
+ if f.shape[0] < max_length else f
+ for f in features
+ ]
+ return features, [(0, length) for length in actual_lengths]
+
+ def _stack_features(self, features):
+ return np.stack(features)
+
+ # ── Masking ───────────────────────────────────────────────────────────
+
+ def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config):
+ use_audio_mask = self.mask_level == "audio"
+ if do_extract_spectrogram and not use_audio_mask:
+ spec_cfg = spectrogram_config or self.spectrogram_config
+ audio_lengths = np.array([end - start for start, end in audio_ranges])
+ features_lengths = self._get_features_lengths(audio_lengths, spec_cfg)
+ n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True)
+ mask = (np.arange(n_features)[None, :] < features_lengths[:, None]).astype(np.int32)
+ return {"audio_features_mask": mask}
+ mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32)
+ for i, (start, end) in enumerate(audio_ranges):
+ mask[i, start:end] = 1
+ return {("audio_features_mask" if do_extract_spectrogram else "audio_values_mask"): mask}
+
+ def _get_feature_mask(self, feature_ranges, padded_length):
+ mask = np.zeros((len(feature_ranges), padded_length), dtype=np.int32)
+ for i, (start, end) in enumerate(feature_ranges):
+ mask[i, start:end] = 1
+ return {"audio_features_mask": mask}
+
+ # ── STFT pipeline ─────────────────────────────────────────────────────
+
+ def _create_stft_window(self, win_length, stft_cfg, audio):
+ N = win_length + 1 if stft_cfg.periodic else win_length
+ fac = np.linspace(-np.pi, np.pi, N)
+ name = stft_cfg.window_fn
+ if name in ("hann", "hann_window"):
+ w = 0.5 + 0.5 * np.cos(fac)
+ elif name in ("hamming", "hamming_window"):
+ w = 0.54 + 0.46 * np.cos(fac)
+ elif name == "boxcar":
+ w = np.ones(N)
+ elif name == "povey":
+ w = (0.5 + 0.5 * np.cos(fac)) ** 0.85
+ else:
+ raise ValueError(f"Unknown window function '{name}'")
+ return w[:win_length] if stft_cfg.periodic else w
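+
+    # Illustrative check (a sketch, not part of the class): in the symmetric
+    # (periodic=False) case this formula reproduces NumPy's reference Hann window:
+    #
+    #     w = 0.5 + 0.5 * np.cos(np.linspace(-np.pi, np.pi, 400))
+    #     assert np.allclose(w, np.hanning(400))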
+
+ def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_framing):
+ if needs_manual_framing and win_length < n_fft:
+ return window, win_length
+ if win_length < n_fft:
+ left_pad = (n_fft - win_length) // 2
+ right_pad = n_fft - win_length - left_pad
+ window = np.pad(window, (left_pad, right_pad))
+ return window, n_fft
+
+ @staticmethod
+ def _np_frame(x, frame_length, hop_length):
+ """Create overlapping frames using stride tricks (replaces librosa.util.frame)."""
+ n_frames = 1 + (x.shape[-1] - frame_length) // hop_length
+ strides = x.strides[:-1] + (x.strides[-1] * hop_length, x.strides[-1])
+ shape = x.shape[:-1] + (n_frames, frame_length)
+ return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
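+
+    # Illustrative shapes (a sketch): a 16000-sample signal framed with
+    # frame_length=400 and hop_length=160 yields 1 + (16000 - 400) // 160 = 98 frames:
+    #
+    #     x = np.zeros(16000, dtype=np.float32)
+    #     NumpyAudioBackend._np_frame(x, 400, 160).shape  # (98, 400)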
+
+ def _frame_waveform(self, waveform, frame_length, hop_length, n_fft, center, pad_mode):
+ squeezed = waveform.ndim == 1
+ if squeezed:
+ waveform = waveform[np.newaxis, :]
+
+ if center:
+ start_k = int(np.ceil(n_fft // 2 / hop_length))
+ tail_k = (waveform.shape[-1] + n_fft // 2 - n_fft) // hop_length + 1
+
+ if tail_k <= start_k:
+ # Short audio: simple center-pad and index-based framing
+ waveform = np.pad(waveform, ((0, 0), (frame_length // 2, frame_length // 2)), mode=pad_mode)
+ num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length
+ frame_starts = np.arange(num_frames) * hop_length
+ frames = waveform[:, frame_starts[:, np.newaxis] + np.arange(frame_length)]
+ else:
+ # Long audio: split into pre (left-padded), middle (no pad), post (right-padded)
+ # to handle edge effects from center padding correctly
+ padding = [(0, 0) for _ in range(waveform.ndim)]
+
+ padding[-1] = (frame_length // 2, 0)
+ y_pre = np.pad(waveform[..., : (start_k - 1) * hop_length - n_fft // 2 + n_fft + 1], padding, mode=pad_mode)
+ y_frames_pre = self._np_frame(y_pre, frame_length, hop_length)[..., :start_k, :]
+
+ padding[-1] = (0, frame_length // 2)
+ y_post = np.pad(waveform[..., tail_k * hop_length - n_fft // 2 :], padding, mode=pad_mode)
+ y_frames_post = self._np_frame(y_post, frame_length, hop_length)
+
+ start = start_k * hop_length - n_fft // 2
+ y_frames_middle = self._np_frame(np.ascontiguousarray(waveform[..., start:]), frame_length, hop_length)
+
+ num_frames = y_frames_pre.shape[-2] + y_frames_middle.shape[-2] + y_frames_post.shape[-2]
+ frames = np.concatenate([y_frames_pre, y_frames_middle, y_frames_post], axis=-2)
+ else:
+ # Non-centered: simple index-based framing
+ num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length
+ frame_starts = np.arange(num_frames) * hop_length
+ frames = waveform[:, frame_starts[:, np.newaxis] + np.arange(frame_length)]
+
+ if squeezed:
+ frames = frames.squeeze(0)
+ return frames, num_frames
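+
+    # Illustrative frame count (a sketch): with center=True and an even
+    # frame_length equal to n_fft, the signal is padded by frame_length // 2 on
+    # each side, so 16000 samples with frame_length=400 and hop_length=160
+    # produce 1 + 16000 // 160 = 101 frames (the short- and long-audio paths
+    # above agree on this count).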
+
+ def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg):
+ frames, _ = self._frame_waveform(audio, frame_length, hop_length, n_fft, stft_cfg.center, stft_cfg.pad_mode)
+ compute_dtype = np.result_type(audio.dtype, window.dtype)
+ return frames.astype(compute_dtype, copy=False)
+
+ def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs):
+ if spectrogram_config.remove_dc_offset:
+ frames = frames - frames.mean(axis=-1, keepdims=True)
+ preemphasis = spectrogram_config.preemphasis
+ if preemphasis is not None:
+ preemph_src = preemphasis * frames[..., :-1]
+ frames[..., 1:] = frames[..., 1:] - preemph_src
+ frames[..., 0] = frames[..., 0] * (1 - preemphasis)
+ return frames
+
+ def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None):
+ frames = frames * window
+ spec = np.fft.rfft(frames, n=n_fft, axis=-1).astype(np.complex64)
+ if stft_cfg.normalized:
+ spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype)
+ return np.moveaxis(spec, -1, -2)
+
+ def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg):
+ frames, _ = self._frame_waveform(audio, frame_length, hop_length, n_fft, stft_cfg.center, stft_cfg.pad_mode)
+ compute_dtype = np.result_type(audio.dtype, window.dtype)
+ frames = frames.astype(compute_dtype, copy=False) * window
+ spec = np.fft.rfft(frames, n=n_fft, axis=-1).astype(np.complex64)
+ if stft_cfg.normalized:
+ spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype)
+ return np.moveaxis(spec, -1, -2)
+
+ def _compute_magnitudes(self, stft_out, power, spectrogram_config=None):
+ # computation_dtype signals that upstream FE used float64 magnitudes
+ if spectrogram_config and spectrogram_config.computation_dtype:
+ return np.abs(stft_out, dtype=np.float64) ** power
+ return np.abs(stft_out) ** power
+
+ # ── Mel scale & normalization ─────────────────────────────────────────
+
+ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig):
+ stft_cfg = spectrogram_config.stft_config
+ mel_cfg = spectrogram_config.mel_scale_config
+ # float32 dtype matches librosa's per-band rounding; computation_dtype keeps float64
+ filter_dtype = None if spectrogram_config.computation_dtype else np.float32
+ return mel_filter_bank(
+ num_frequency_bins=1 + stft_cfg.n_fft // 2,
+ num_mel_filters=mel_cfg.n_mels,
+ min_frequency=mel_cfg.f_min,
+ max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2,
+ sampling_rate=self.sample_rate,
+ norm=mel_cfg.norm,
+ mel_scale=mel_cfg.mel_scale,
+ triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space,
+ dtype=filter_dtype,
+ )
+
+    def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs):
+        mel_filters = self.mel_filters.astype(features.dtype, copy=False)
+        if spectrogram_config.mel_scale_config.matmul_order == "features_first":
+            # (..., freq, time) -> (..., time, freq) @ (freq, n_mels), mirroring the torch backend
+            mel_spec = np.matmul(np.swapaxes(features, -2, -1), mel_filters)
+        else:
+            mel_spec = np.matmul(mel_filters.T, features)
+        return np.maximum(spectrogram_config.mel_floor, mel_spec)
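+
+    # Illustrative shapes (assuming `mel_filters` of shape (freq, n_mels) from
+    # `mel_filter_bank`): a (freq, time) spectrogram maps to (time, n_mels) with
+    # matmul_order="features_first" and to (n_mels, time) otherwise.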
+
+ def _normalize_magnitude(self, features, *, spectrogram_config,
+ reference=1.0, min_value=1e-10, db_range=None,
+ dtype=np.float32, **kwargs):
+ log_mel = spectrogram_config.log_mode
+ if log_mel is None:
+ return features.astype(dtype)
+
+ mel_floor = spectrogram_config.mel_floor
+ result = np.maximum(mel_floor, features)
+
+ if log_mel == "log":
+ result = np.log(result).astype(dtype)
+ elif log_mel == "log10":
+ result = np.log10(result).astype(dtype)
+ elif log_mel == "dB":
+ power = spectrogram_config.stft_config.power
+ if power == 1.0:
+ result = amplitude_to_db(result, reference, min_value, db_range).astype(dtype)
+ elif power == 2.0:
+ result = power_to_db(result, reference, min_value, db_range).astype(dtype)
+ else:
+ raise ValueError(f"Cannot use log_mel option 'dB' with power {power}")
+ else:
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
+ return result
+
+ # ── Kaldi fbank helper ────────────────────────────────────────────────
+
+ def _kaldi_fbank(self, waveform, num_mel_bins, sample_frequency=None, **kwargs):
+ """Extract kaldi-compatible fbank features using torchaudio (or fallback to base pipeline).
+
+ Returns numpy array of shape (time, num_mel_bins).
+ """
+ from .utils import is_speech_available
+
+ if sample_frequency is None:
+ sample_frequency = self.sample_rate
+
+ if is_speech_available():
+ import torch
+ import torchaudio.compliance.kaldi as ta_kaldi
+
+ waveform_tensor = torch.from_numpy(np.asarray(waveform)).unsqueeze(0)
+ fbank = ta_kaldi.fbank(waveform_tensor, num_mel_bins=num_mel_bins,
+ sample_frequency=sample_frequency, **kwargs)
+ return fbank.numpy()
+
+ waveform = np.squeeze(waveform)
+ features = self.extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config)
+ return features[0].T
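+
+    # Illustrative usage (a sketch; extra kwargs are forwarded to
+    # torchaudio.compliance.kaldi.fbank when torchaudio is available):
+    #
+    #     feats = processor._kaldi_fbank(waveform, num_mel_bins=80,
+    #                                    frame_length=25.0, frame_shift=10.0)
+    #     feats.shape  # (num_frames, 80)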
+
+
+# ═══════════════════════════════════════════════════════════════════════════════
+# TorchAudioBackend
+# ═══════════════════════════════════════════════════════════════════════════════
+
+
+class TorchAudioBackend(BaseAudioProcessor):
+ """Torch backend for audio processing."""
+
+ @property
+ def backend(self) -> str:
+ return "torch"
+
+ # ── Audio input processing ────────────────────────────────────────────
+
+ def _process_audio(self, audio_el):
+ if isinstance(audio_el, np.ndarray):
+ audio_el = torch.from_numpy(audio_el)
+ if audio_el.ndim > 1:
+ if self.force_mono and audio_el.shape[0] > 1:
+ audio_el = audio_el.mean(dim=0)
+ elif audio_el.shape[0] == 1:
+ audio_el = audio_el.squeeze(0)
+ else:
+ raise ValueError("Audio has more than one channel but force_mono is False")
+ return audio_el
+
+ # ── Padding & batching ────────────────────────────────────────────────
+
+ def _pad_single(self, audio, max_length):
+ current_length = audio.shape[-1]
+ if current_length >= max_length:
+ return audio
+ if self.padding_side == "right":
+ pad_args = (0, max_length - current_length)
+ elif self.padding_side == "left":
+ pad_args = (max_length - current_length, 0)
+ else:
+ raise ValueError(f"Invalid padding side: {self.padding_side}")
+ return torch.nn.functional.pad(audio, pad_args, "constant", self.padding_value)
+
+ def _to_batch(self, audio):
+ batch = torch.stack(audio)
+ if self.add_channel_dim:
+ batch = batch.unsqueeze(1)
+ return batch
+
+ def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of):
+ padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
+ if truncation and max_length is not None:
+ features = [f[:max_length] for f in features]
+ actual_lengths = [f.shape[0] for f in features]
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = max(actual_lengths)
+ padding_strategy = PaddingStrategy.MAX_LENGTH
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+ if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
+ padded = []
+ for f in features:
+ if f.shape[0] < max_length:
+ pad_args = [0, 0] * (f.ndim - 1) + [0, max_length - f.shape[0]]
+ f = torch.nn.functional.pad(f, pad_args, "constant", self.padding_value)
+ padded.append(f)
+ features = padded
+ return features, [(0, length) for length in actual_lengths]
+
+ def _stack_features(self, features):
+ return torch.stack(features)
+
+ # ── Masking ───────────────────────────────────────────────────────────
+
+ def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config):
+ use_audio_mask = self.mask_level == "audio"
+ if do_extract_spectrogram and not use_audio_mask:
+ spec_cfg = spectrogram_config or self.spectrogram_config
+ audio_lengths = torch.tensor([end - start for start, end in audio_ranges])
+ features_lengths = self._get_features_lengths(audio_lengths, spec_cfg)
+ n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True)
+ mask = (torch.arange(n_features)[None, :] < features_lengths[:, None]).to(torch.int32)
+ return {"audio_features_mask": mask}
+ mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32)
+ for i, (start, end) in enumerate(audio_ranges):
+ mask[i, start:end] = 1
+ return {("audio_features_mask" if do_extract_spectrogram else "audio_values_mask"): mask}
+
+ def _get_feature_mask(self, feature_ranges, padded_length):
+ mask = torch.zeros((len(feature_ranges), padded_length), dtype=torch.int32)
+ for i, (start, end) in enumerate(feature_ranges):
+ mask[i, start:end] = 1
+ return {"audio_features_mask": mask}
+
+ # ── STFT pipeline ─────────────────────────────────────────────────────
+
+ def _needs_manual_framing(self, spectrogram_config):
+ return super()._needs_manual_framing(spectrogram_config) or spectrogram_config.stft_config.left_align_fft
+
+ def _create_stft_window(self, win_length, stft_cfg, audio):
+ dtype = getattr(torch, stft_cfg.window_dtype) if stft_cfg.window_dtype else audio.dtype
+ wkwargs = {**(stft_cfg.wkwargs or {}), "dtype": dtype}
+ name = stft_cfg.window_fn
+ if name in ("hann", "hann_window"):
+ window = torch.hann_window(win_length, periodic=stft_cfg.periodic, **wkwargs)
+ elif name in ("hamming", "hamming_window"):
+ window = torch.hamming_window(win_length, periodic=stft_cfg.periodic, **wkwargs)
+ elif name == "boxcar":
+            window = torch.ones(win_length, dtype=dtype)
+ elif name == "povey":
+ window = torch.hann_window(win_length, periodic=stft_cfg.periodic, **wkwargs).pow(0.85)
+ else:
+ raise ValueError(f"Unknown window function '{name}'")
+ return window.to(device=audio.device)
+
+ def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_framing):
+ if needs_manual_framing and win_length < n_fft:
+ return window, win_length
+ if win_length < n_fft:
+ left_pad = (n_fft - win_length) // 2
+ right_pad = n_fft - win_length - left_pad
+ window = torch.nn.functional.pad(window, (left_pad, right_pad))
+ return window, n_fft
+
+ def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg):
+ if stft_cfg.center:
+ audio = torch.nn.functional.pad(audio, (frame_length // 2, frame_length // 2), mode=stft_cfg.pad_mode)
+ return audio.unfold(-1, frame_length, hop_length)
+
+ def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs):
+ if spectrogram_config.remove_dc_offset:
+ frames = frames - frames.mean(dim=-1, keepdim=True)
+ preemphasis = spectrogram_config.preemphasis
+ if preemphasis is not None:
+ frames = torch.cat([
+ frames[..., :1] * (1 - preemphasis),
+ frames[..., 1:] - preemphasis * frames[..., :-1],
+ ], dim=-1)
+ return frames
+
+ def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None):
+ frames = frames * window
+ if frame_length < n_fft:
+ frames = torch.nn.functional.pad(frames, (0, n_fft - frame_length))
+ spec = torch.fft.rfft(frames, n=n_fft)
+ if stft_cfg.normalized:
+ spec = spec / window.pow(2.0).sum().sqrt()
+ return spec.transpose(-2, -1)
+
+ def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg):
+ stft_out = torch.stft(
+ audio, n_fft=n_fft, hop_length=hop_length, win_length=frame_length,
+ window=window, center=stft_cfg.center, pad_mode=stft_cfg.pad_mode,
+ normalized=False, return_complex=True,
+ )
+ if stft_cfg.normalized:
+ stft_out = stft_out / window.pow(2.0).sum().sqrt()
+ return stft_out
+
+ def _cast_stft_output(self, magnitudes, spectrogram_config):
+ if spectrogram_config.computation_dtype:
+ return magnitudes
+ return magnitudes.float()
+
+ def _compute_magnitudes(self, stft_out, power, spectrogram_config=None):
+ return stft_out.abs() ** power
+
+ # ── Mel scale & normalization ─────────────────────────────────────────
+
+ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig):
+ stft_cfg = spectrogram_config.stft_config
+ mel_cfg = spectrogram_config.mel_scale_config
+ computation_dtype = getattr(torch, mel_cfg.computation_dtype) if mel_cfg.computation_dtype else None
+ num_frequency_bins = 1 + stft_cfg.n_fft // 2
+ num_mel_filters = mel_cfg.n_mels
+ min_frequency = mel_cfg.f_min
+ max_frequency = mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2
+ n_fft = (num_frequency_bins - 1) * 2
+
+ if mel_cfg.triangularize_in_mel_space and mel_cfg.bands_to_zero == 0:
+ # Kaldi-exact path: matches torchaudio.compliance.kaldi.get_mel_banks
+ mel_filters = self._kaldi_exact_mel_banks(
+ num_mel_filters, num_frequency_bins, min_frequency, max_frequency,
+ self.sample_rate, n_fft,
+ )
+ elif mel_cfg.triangularize_in_mel_space:
+ mel_filters = self._kaldi_mel_banks_with_zero_bands(
+ num_mel_filters, num_frequency_bins, min_frequency, max_frequency,
+ self.sample_rate, n_fft, mel_cfg, computation_dtype,
+ )
+ else:
+ mel_filters = self._standard_mel_banks(
+ num_mel_filters, num_frequency_bins, min_frequency, max_frequency,
+ self.sample_rate, n_fft, mel_cfg, computation_dtype,
+ )
+
+ # Cast back when mel computation_dtype doesn't match spectrogram computation_dtype
+ if computation_dtype is not None and not spectrogram_config.computation_dtype:
+ mel_filters = mel_filters.to(torch.get_default_dtype())
+ return mel_filters
+
+ @staticmethod
+ def _kaldi_exact_mel_banks(num_mel_filters, num_frequency_bins, min_frequency,
+ max_frequency, sampling_rate, n_fft):
+ """Matches torchaudio.compliance.kaldi.get_mel_banks exactly."""
+ num_fft_bins = n_fft // 2
+ fft_bin_width = sampling_rate / n_fft
+ mel_low = 1127.0 * math.log(1.0 + min_frequency / 700.0)
+ mel_high = 1127.0 * math.log(1.0 + max_frequency / 700.0)
+ mel_delta = (mel_high - mel_low) / (num_mel_filters + 1)
+
+ bin_idx = torch.arange(num_mel_filters).unsqueeze(1)
+ left_mel = mel_low + bin_idx * mel_delta
+ center_mel = mel_low + (bin_idx + 1.0) * mel_delta
+ right_mel = mel_low + (bin_idx + 2.0) * mel_delta
+
+ mel = 1127.0 * (1.0 + fft_bin_width * torch.arange(num_fft_bins) / 700.0).log()
+ mel = mel.unsqueeze(0)
+
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
+ banks = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
+ banks = torch.nn.functional.pad(banks, (0, 1), mode="constant", value=0)
+ return banks.T
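+
+    # Illustrative shape (a sketch): the (num_mel_filters, n_fft // 2) banks are
+    # padded with a zero column for the Nyquist bin and transposed, so the
+    # returned filter bank has shape (num_frequency_bins, num_mel_filters).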
+
+ @staticmethod
+ def _kaldi_mel_banks_with_zero_bands(num_mel_filters, num_frequency_bins, min_frequency,
+ max_frequency, sampling_rate, n_fft, mel_cfg, computation_dtype):
+ """Kaldi-style with bands_to_zero > 0."""
+ mel_min = _torch_hertz_to_mel_scalar(min_frequency, mel_scale=mel_cfg.mel_scale)
+ mel_max = _torch_hertz_to_mel_scalar(max_frequency, mel_scale=mel_cfg.mel_scale)
+ mel_delta = (mel_max - mel_min) / (num_mel_filters + 1)
+ bin_idx = torch.arange(num_mel_filters, dtype=computation_dtype).unsqueeze(1)
+ left_mel = mel_min + bin_idx * mel_delta
+ center_mel = mel_min + (bin_idx + 1.0) * mel_delta
+ right_mel = mel_min + (bin_idx + 2.0) * mel_delta
+
+ fft_bin_width = sampling_rate / n_fft
+ hz_freqs = fft_bin_width * torch.arange(mel_cfg.bands_to_zero, num_frequency_bins, dtype=computation_dtype)
+ mel = _torch_hertz_to_mel(hz_freqs, mel_scale=mel_cfg.mel_scale).unsqueeze(0)
+
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
+ zero = torch.zeros(1, dtype=computation_dtype)
+ mel_filters = torch.max(zero, torch.min(up_slope, down_slope)).T
+ if mel_cfg.bands_to_zero > 0:
+ mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, mel_cfg.bands_to_zero, 0))
+ return mel_filters
+
+ @staticmethod
+ def _standard_mel_banks(num_mel_filters, num_frequency_bins, min_frequency,
+ max_frequency, sampling_rate, n_fft, mel_cfg, computation_dtype):
+ """Standard (non-kaldi) triangular mel filter bank."""
+ mel_min = _torch_hertz_to_mel_scalar(min_frequency, mel_scale=mel_cfg.mel_scale)
+ mel_max = _torch_hertz_to_mel_scalar(max_frequency, mel_scale=mel_cfg.mel_scale)
+ mel_freqs = torch.linspace(mel_min, mel_max, num_mel_filters + 2, dtype=computation_dtype)
+ filter_freqs = _torch_mel_to_hertz(mel_freqs, mel_scale=mel_cfg.mel_scale)
+
+ if mel_cfg.frequency_bin_mode == "rfft":
+ fft_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate)
+ else:
+ fft_freqs = torch.linspace(0, sampling_rate // 2, num_frequency_bins)
+ if computation_dtype is not None:
+ fft_freqs = fft_freqs.to(computation_dtype)
+
+ filter_diff = filter_freqs[1:] - filter_freqs[:-1]
+ slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1)
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
+ mel_filters = torch.clamp(torch.minimum(down_slopes, up_slopes), min=0)
+
+ if mel_cfg.norm == "slaney":
+ enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
+ mel_filters = mel_filters * enorm.unsqueeze(0)
+
+ if mel_cfg.bands_to_zero > 0:
+ mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, mel_cfg.bands_to_zero, 0))
+ return mel_filters
+
+ def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs):
+ mel_filters = self.mel_filters.to(device=features.device)
+ if spectrogram_config.mel_scale_config.matmul_order == "features_first":
+ mel_spec = torch.matmul(features.transpose(-2, -1), mel_filters)
+ else:
+ # F.linear matches torchaudio's MelScale implementation exactly
+ mel_spec = torch.nn.functional.linear(features.transpose(-2, -1), mel_filters.T).transpose(-2, -1)
+ return torch.clamp(mel_spec, min=spectrogram_config.mel_floor)
+
+ def _normalize_magnitude(self, features, *, spectrogram_config,
+ reference=1.0, min_value=1e-10, db_range=None,
+ dtype=None, **kwargs):
+ log_mel = spectrogram_config.log_mode
+ mel_floor = spectrogram_config.mel_floor
+ power = spectrogram_config.stft_config.power
+ if dtype is None:
+ dtype = torch.float32
+
+ if log_mel is None:
+ return features
+
+ result = torch.clamp(features, min=mel_floor)
+
+ if log_mel == "log":
+ result = torch.log(result).to(dtype)
+ elif log_mel == "log10":
+ result = torch.log10(result).to(dtype)
+ elif log_mel == "dB":
+ if reference <= 0.0:
+ raise ValueError("reference must be greater than zero")
+ if min_value <= 0.0:
+ raise ValueError("min_value must be greater than zero")
+ reference = max(min_value, reference)
+ multiplier = 10.0 if power == 2.0 else 20.0 if power == 1.0 else None
+ if multiplier is None:
+ raise ValueError(f"Cannot use log_mel option 'dB' with power {power}")
+ log_ref = torch.log10(torch.tensor(reference, dtype=result.dtype, device=result.device))
+ result = torch.clamp(result, min=min_value)
+ result = multiplier * (torch.log10(result) - log_ref)
+ if db_range is not None:
+ if db_range <= 0.0:
+ raise ValueError("db_range must be greater than zero")
+                max_vals = result.amax(dim=(-2, -1), keepdim=True) if result.ndim > 2 else result.max()
+ result = torch.clamp(result, min=max_vals - db_range)
+ result = result.to(dtype)
+ else:
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
+
+ if spectrogram_config.skip_last_frame:
+ result = result[..., :-1]
+
+        return result
+
+
+# Short aliases matching the names exported from `transformers.__init__`
+NumpyBackend = NumpyAudioBackend
+TorchBackend = TorchAudioBackend
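+
+
+# Illustrative end-to-end usage (a minimal sketch; `MyProcessor` is hypothetical):
+#
+#     class MyProcessor(NumpyAudioBackend):
+#         sample_rate = 16000
+#         force_mono = True
+#
+#     processor = MyProcessor()
+#     batch = processor(np.zeros(16000, dtype=np.float32), sampling_rate=16000)
+#     batch["audio_values"].shape  # (1, 16000)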
diff --git a/src/transformers/audio_processing_base.py b/src/transformers/audio_processing_base.py
new file mode 100644
index 000000000000..6fba8be02082
--- /dev/null
+++ b/src/transformers/audio_processing_base.py
@@ -0,0 +1,148 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import warnings
+from typing import Any, ClassVar, TypeVar
+
+from .audio_utils import is_valid_audio, load_audio
+from .feature_extraction_utils import BatchFeature as BaseBatchFeature
+from .preprocessing_base import PreprocessingMixin
+from .utils import (
+ FEATURE_EXTRACTOR_NAME,
+ copy_func,
+ logging,
+)
+
+
+_LEGACY_KEY_MAP = {
+ "input_features": "audio_features",
+ "input_values": "audio_values",
+ "audio_input_features": "audio_features",
+}
+
+
+AudioProcessorType = TypeVar("AudioProcessorType", bound="AudioProcessingMixin")
+
+
+logger = logging.get_logger(__name__)
+
+
+class BatchFeature(BaseBatchFeature):
+ r"""
+ Holds the output of the audio processor specific `__call__` methods.
+
+ This class is derived from a python dictionary and can be used as a dictionary.
+
+ Args:
+ data (`dict`):
+            Dictionary of lists/arrays/tensors returned by the __call__ method ('audio_values', 'audio_features',
+            etc.).
+        tensor_type (`Union[None, str, TensorType]`, *optional*):
+            You can give a `tensor_type` here to convert the lists of integers into PyTorch/NumPy tensors at
+            initialization.
+ """
+
+ _warned_keys: ClassVar[set] = set()
+
+ def __getitem__(self, item):
+ if isinstance(item, str) and item not in self.data:
+ new_key = self._resolve_legacy_key(item)
+ if new_key is not None and new_key in self.data:
+ if item not in BatchFeature._warned_keys:
+ warnings.warn(
+ f"Accessing '{item}' is deprecated, use '{new_key}' instead.",
+ FutureWarning,
+ stacklevel=2,
+ )
+ BatchFeature._warned_keys.add(item)
+ return self.data[new_key]
+ return super().__getitem__(item)
+
+ def __contains__(self, item):
+ if item in self.data:
+ return True
+ new_key = self._resolve_legacy_key(item)
+ return new_key is not None and new_key in self.data
+
+ def _resolve_legacy_key(self, old_key):
+ if old_key in ("attention_mask", "padding_mask"):
+ if "audio_features_mask" in self.data:
+ return "audio_features_mask"
+ if "audio_values_mask" in self.data:
+ return "audio_values_mask"
+ return None
+ return _LEGACY_KEY_MAP.get(old_key)
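+
+    # Illustrative legacy access (a sketch): old keys resolve to the new names
+    # with a one-time FutureWarning:
+    #
+    #     feat = BatchFeature(data={"audio_features": features})
+    #     feat["input_features"]  # returns feat.data["audio_features"]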
+
+
+class AudioProcessingMixin(PreprocessingMixin):
+ """
+ This is an audio processor mixin used to provide saving/loading functionality for audio processors.
+ """
+
+ _config_name = FEATURE_EXTRACTOR_NAME
+ _type_key = "audio_processor_type"
+ _nested_config_keys = ["audio_processor", "feature_extractor"]
+ _auto_class_default = "AutoFeatureExtractor"
+ _file_type_label = "audio processor"
+ _excluded_dict_keys = {"mel_filters", "window"}
+ _extra_init_pops = ["feature_extractor_type"]
+ _config_filename_kwarg = "audio_processor_filename"
+ _subfolder_default = ""
+
+ @classmethod
+ def get_audio_processor_dict(
+ cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating an
+ audio processor of type [`~audio_processing_base.AudioProcessingMixin`] using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+ subfolder (`str`, *optional*, defaults to `""`):
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+ specify the folder name here.
+ audio_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`):
+ The name of the file in the model directory to use for the audio processor config.
+
+ Returns:
+ `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the audio processor object.
+ """
+ return cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ def fetch_audio(self, audio_url_or_urls: str | list[str] | list[list[str]], sampling_rate: int | None = None):
+ """
+        Convert a single URL or a list of URLs into the corresponding `np.ndarray` objects.
+
+        If a single URL is passed, the return value is a single object; if a list is passed, a list of objects is
+        returned.
+ """
+ if sampling_rate is None:
+ sampling_rate = getattr(self, "sample_rate", 16000)
+ if isinstance(audio_url_or_urls, list):
+ return [self.fetch_audio(x, sampling_rate=sampling_rate) for x in audio_url_or_urls]
+ elif isinstance(audio_url_or_urls, str):
+ return load_audio(audio_url_or_urls, sampling_rate=sampling_rate)
+ elif is_valid_audio(audio_url_or_urls):
+ return audio_url_or_urls
+ else:
+ raise TypeError(f"only a single or a list of entries is supported but got type={type(audio_url_or_urls)}")
+
+
+AudioProcessingMixin.push_to_hub = copy_func(AudioProcessingMixin.push_to_hub)
+if AudioProcessingMixin.push_to_hub.__doc__ is not None:
+ AudioProcessingMixin.push_to_hub.__doc__ = AudioProcessingMixin.push_to_hub.__doc__.format(
+ object="audio processor", object_class="AutoFeatureExtractor", object_files="audio processor file"
+ )
diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py
new file mode 100644
index 000000000000..33f4eaff143d
--- /dev/null
+++ b/src/transformers/audio_processing_utils.py
@@ -0,0 +1,633 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import fields, replace
+from typing import TypedDict, Unpack
+
+import numpy as np
+from huggingface_hub.dataclasses import validate_typed_dict
+
+from .audio_processing_base import AudioProcessingMixin, BatchFeature
+from .audio_utils import AudioInput, SpectrogramConfig, make_list_of_audio
+from .tokenization_utils_base import TruncationStrategy
+from .utils import PaddingStrategy, TensorType, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class AudioKwargs(TypedDict, total=False):
+ sampling_rate: int | None
+ spectrogram_config: dict | SpectrogramConfig | None
+ do_extract_spectrogram: bool | None
+ do_resample: bool | None
+ return_tensors: str | TensorType | None
+ padding: bool | str | PaddingStrategy | None
+ max_length: int | None
+ truncation: bool | str | TruncationStrategy | None
+ pad_to_multiple_of: int | None
+
+
+class BaseAudioProcessor(AudioProcessingMixin):
+ model_input_names = ["audio"]
+ valid_kwargs = AudioKwargs
+ unused_kwargs = None
+
+ # global defaults
+    sample_rate: int | None = None
+    force_mono: bool | None = None
+ add_channel_dim: bool = False
+
+ # padding defaults
+ padding = True
+ padding_side = "right"
+ padding_value = 0.0
+ max_length = None
+ truncation = None
+ pad_to_multiple_of = None
+
+ return_padding_mask = True
+ mask_level = None # None = auto (features for spectrogram, audio for raw), "audio" = always audio-level
+ spectrogram_config = None
+ do_extract_spectrogram = None
+
+ def __init__(
+ self,
+ sample_rate: int | None = None,
+ force_mono: bool | None = None,
+ **kwargs,
+ ):
+ if sample_rate is not None:
+ self.sample_rate = sample_rate
+ if self.sample_rate is None:
+ raise ValueError(
+ f"`sample_rate` must be set either as a class attribute on {self.__class__.__name__} "
+ "or passed to __init__."
+ )
+
+ if force_mono is not None:
+ self.force_mono = force_mono
+ if self.force_mono is None:
+ raise ValueError(
+ f"`force_mono` must be set either as a class attribute on {self.__class__.__name__} "
+ "or passed to __init__."
+ )
+
+ super().__init__(**kwargs)
+
+ # Standardize init attributes (coerce dicts to config dataclasses)
+ attributes = {key: getattr(self, key) for key in self._valid_kwargs_names}
+ attributes = self._standardize_kwargs(**attributes)
+ for key, value in attributes.items():
+ setattr(self, key, value)
+
+ # Pre-compute mel filters from spectrogram_config
+ if self.spectrogram_config is not None:
+ if self.spectrogram_config.mel_scale_config is not None and not hasattr(self, "mel_filters"):
+ self.mel_filters = self._mel_filter_bank(self.spectrogram_config)
+ self._cached_stft_window = None
+
+ def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature:
+ return self.preprocess(audio, *args, **kwargs)
+
+ def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature:
+ """
+ Preprocess an audio or a batch of audio.
+ """
+ # Perform type validation on received kwargs
+ validate_typed_dict(self.valid_kwargs, kwargs)
+
+ # Set default kwargs from self.
+ for kwarg_name in self._valid_kwargs_names:
+ kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
+
+ # Standardize kwargs (coerce dicts to config dataclasses)
+ kwargs = self._standardize_kwargs(**kwargs)
+
+ # Validate kwargs
+ self._validate_preprocess_kwargs(**kwargs)
+
+ return self._preprocess_audio_like_inputs(audio, *args, **kwargs)
+
+    def _preprocess_audio_like_inputs(
+        self,
+        audio: AudioInput,
+        *args,
+        sampling_rate: int | None = None,
+        **kwargs: Unpack[AudioKwargs],
+    ) -> BatchFeature:
+        audio = self._prepare_audio_like_inputs(audio=audio, sample_rate=sampling_rate)
+ return self._preprocess(audio, *args, **kwargs)
+
+ def _to_batch(self, audio):
+ """Stack a list of audio arrays/tensors into a batch. Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config):
+ """Build attention mask dict from audio_ranges. Returns a dict of {key: mask} to merge into output.
+ Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _preprocess(
+ self,
+ audio,
+ padding,
+ max_length,
+ truncation,
+ pad_to_multiple_of,
+ return_tensors,
+ spectrogram_config=None,
+ do_extract_spectrogram=None,
+ do_batch_spectrogram=None,
+ **kwargs,
+ ) -> BatchFeature:
+ if do_batch_spectrogram is None:
+ do_batch_spectrogram = getattr(self, "do_batch_spectrogram", True)
+ if do_extract_spectrogram and not do_batch_spectrogram:
+ # Per-waveform extraction path: extract → postprocess → pad features → mask
+ features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs)
+ feature_lengths = [f.shape[0] for f in features]
+ features = self._postprocess_features(features, feature_lengths)
+ features, feature_ranges = self._pad_features(
+ features, padding, max_length, truncation, pad_to_multiple_of
+ )
+ output = {"audio_features": self._stack_features(features)}
+ if self.return_padding_mask:
+ padded_length = features[0].shape[0]
+ output.update(self._get_feature_mask(feature_ranges, padded_length))
+ output = self._postprocess_output(output, feature_ranges=feature_ranges, **kwargs)
+ else:
+ # Standard path: pad audio → optionally batch → extract/passthrough
+ audio, audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of)
+ padded_length = audio[0].shape[-1]
+
+ if do_extract_spectrogram:
+ audio = self._to_batch(audio) if do_batch_spectrogram else audio
+ feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, audio_ranges=audio_ranges, **kwargs)
+ output = {"audio_features": feature}
+ else:
+ output = {"audio_values": self._to_batch(audio)}
+
+ if self.return_padding_mask:
+ output.update(self._get_mask(
+ audio_ranges, padded_length, do_extract_spectrogram=do_extract_spectrogram, spectrogram_config=spectrogram_config
+ ))
+ output = self._postprocess_output(output, audio_ranges=audio_ranges, **kwargs)
+
+ return BatchFeature(data=output, tensor_type=return_tensors)
+
+ def _postprocess_features(self, features, feature_lengths):
+ """Hook: per-utterance feature processing after extraction, before feature-level padding.
+
+ Override for normalization that must happen on unpadded features
+ (e.g., SeamlessM4t mean/variance normalization).
+ """
+ return features
+
+ def _postprocess_output(self, output, audio_ranges=None, feature_ranges=None, **kwargs):
+ """Hook: augment or modify the output dict after main processing.
+
+ Override to add custom fields (e.g., audio_embed_sizes) or
+ post-hoc normalization on the stacked/batched output.
+ """
+ return output
+
+ def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of):
+ """Pad a list of 2D feature arrays along the time axis (axis 0).
+ Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _stack_features(self, features):
+ """Stack a list of feature arrays/tensors into a batch.
+ Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _get_feature_mask(self, feature_ranges, padded_length):
+ """Build attention mask dict from feature_ranges.
+ Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _prepare_audio_like_inputs(self, audio: AudioInput, *args, sample_rate: int | None = None, **kwargs) -> list:
+ """
+ Prepare audio-like inputs for processing by structuring and then converting each
+ audio item via `process_audio`.
+
+ Analogous to `_prepare_image_like_inputs` in the image processing pipeline.
+ """
+ audio = self._prepare_audio_structure(audio, sample_rate=sample_rate)
+ audio = [self.process_audio(audio_el) for audio_el in audio]
+ return audio
+
+ def _prepare_audio_structure(self, audio: AudioInput, sample_rate: int | None = None) -> list:
+ """
+ Prepare the audio structure for processing: fetch URL inputs, validate sample rate,
+ and flatten into a list of audio arrays.
+
+ Analogous to `_prepare_images_structure` in the image processing pipeline.
+ """
+ is_url_input = isinstance(audio, str) or (
+ isinstance(audio, (list, tuple)) and all(isinstance(el, str) for el in audio)
+ )
+
+ if is_url_input:
+ # URL inputs: load directly at the correct sample rate
+ audio = self.fetch_audio(audio)
+ else:
+ # Array inputs: validate that the user-provided sample rate matches the model's
+ if sample_rate is not None:
+ if sample_rate != self.sample_rate:
+ raise ValueError(
+ f"The model corresponding to this audio processor: {self.__class__.__name__} was trained using a"
+ f" sample rate of {self.sample_rate}. Please make sure that the provided `audio` input"
+ f" was sampled with {self.sample_rate} and not {sample_rate}."
+ )
+ else:
+ logger.warning(
+ f"It is strongly recommended to pass the `sample_rate` argument to `{self.__class__.__name__}()`. "
+ "Failing to do so can result in silent errors that might be hard to debug."
+ )
+
+ audio = make_list_of_audio(audio)
+ return audio
+
+ def _process_audio(self, *args, **kwargs):
+ """
+ Process a single raw audio input into the backend's working format.
+
+ Implemented by backend subclasses (e.g., `TorchAudioBackend`). Converts a raw input
+ (NumPy array) to the backend's internal format (e.g., `torch.Tensor`), handles
+ mono conversion if needed.
+ """
+ raise NotImplementedError
+
+ def process_audio(self, *args, **kwargs):
+ return self._process_audio(*args, **kwargs)
+
+ def pad(
+ self,
+        audio: AudioInput,  # TODO: this type hint makes it unclear that an iterable is expected here
+ padding: bool | str | PaddingStrategy = True,
+ max_length: int | None = None,
+ truncation: bool = False,
+ pad_to_multiple_of: int | None = None,
+ ) -> tuple[list, list[tuple[int, int]]]:
+ padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
+
+ if truncation:
+ if max_length is None:
+ # TODO: maybe this check should happen in the _validate_preprocess_kwargs method
+ raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.")
+ trunc_length = max_length
+ if pad_to_multiple_of is not None and (trunc_length % pad_to_multiple_of != 0):
+ trunc_length = ((trunc_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+ audio = [self._truncate_single(audio_el, max_length=trunc_length) for audio_el in audio]
+
+ if padding_strategy == PaddingStrategy.LONGEST:
+ max_length = max(audio_el.shape[-1] for audio_el in audio)
+ padding_strategy = PaddingStrategy.MAX_LENGTH
+
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+ actual_lengths = [audio_el.shape[-1] for audio_el in audio]
+
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD:
+ audio = [self._pad_single(audio_el, max_length=max_length) for audio_el in audio]
+
+ audio_ranges = []
+ for i, length in enumerate(actual_lengths):
+ padded_length = audio[i].shape[-1]
+ if self.padding_side == "left":
+ audio_ranges.append((padded_length - length, padded_length))
+ else:
+ audio_ranges.append((0, length))
+
+ return audio, audio_ranges
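+
+    # Illustrative behaviour (a sketch): padding two clips of lengths 100 and
+    # 160 with padding=True (longest) right-pads the first to 160 and returns
+    # audio_ranges [(0, 100), (0, 160)] marking the valid samples.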
+
+ def _truncate_single(self, audio_el, max_length: int):
+ """Truncate a single audio element to max_length along the time axis."""
+ if audio_el.shape[-1] > max_length:
+ return audio_el[..., :max_length]
+ return audio_el
+
+ def _pad_single(self, audio, max_length: int) -> AudioInput:
+ """
+ Pad a single input (on left/right) up to predefined length or max length in the batch.
+
+ Implemented by backend subclasses.
+ """
+ raise NotImplementedError
+
+ def extract_spectrogram(self, audio, *, spectrogram_config: SpectrogramConfig | None = None, **kwargs):
+ """
+ Extract spectrogram features from audio.
+
+        The numpy and torch backends implement the underlying steps of this method; extraction
+        is batched by default, but can be run per waveform instead.
+        This can extract a plain spectrogram, or a mel spectrogram if a mel config is provided.
+
+ Any extra kwargs whose names match ``SpectrogramConfig`` fields will
+ override the corresponding value on the config for this call.
+
+ Note: Models that bypass the base STFT pipeline entirely (e.g., GraniteSpeech
+ using torchaudio.transforms.MelSpectrogram, or MusicgenMelody using chroma
+ features) can set ``do_extract_spectrogram=True`` without providing a
+ ``spectrogram_config``. They must override this method completely.
+ """
+ if spectrogram_config is None:
+ spectrogram_config = self.spectrogram_config
+
+ config_field_names = {f.name for f in fields(SpectrogramConfig)}
+ overrides = {k: kwargs.pop(k) for k in list(kwargs) if k in config_field_names}
+ if overrides:
+ spectrogram_config = replace(spectrogram_config, **overrides)
+
+ if isinstance(audio, list):
+ features = [
+ self._extract_spectrogram(a, spectrogram_config=spectrogram_config, **kwargs)
+ for a in audio
+ ]
+ if spectrogram_config.mel_scale_config is not None:
+ features = [
+ self._apply_mel_scale(f, spectrogram_config=spectrogram_config, **kwargs)
+ for f in features
+ ]
+ features = [
+ self._normalize_magnitude(f, spectrogram_config=spectrogram_config, **kwargs)
+ for f in features
+ ]
+ else:
+ features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs)
+ if spectrogram_config.mel_scale_config is not None:
+ features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs)
+ features = self._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs)
+
+ return features
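+
+    # Illustrative per-call override (a sketch): kwargs matching SpectrogramConfig
+    # fields build a one-off copy of the config, e.g.
+    #
+    #     processor.extract_spectrogram(audio, preemphasis=0.97)
+    #
+    # runs this call with preemphasis=0.97 without mutating self.spectrogram_config.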
+
+ # ── Spectrogram extraction pipeline ──────────────────────────────────
+ #
+ # The full feature-extraction pipeline executed by `extract_spectrogram`:
+ #
+ # 1. _extract_spectrogram (STFT → power/magnitude spectrogram)
+ # a. _stft – orchestrates steps b–g (overridable for fully custom STFTs)
+ # b. _needs_manual_framing – decide framing strategy (hook)
+ # c. _create_stft_window – create the STFT window (backend)
+ # d. _prepare_window_and_framing– pad/reshape window, decide frame length (backend)
+ # e. manual path (needs_manual_framing=True):
+ # _frame_audio – center pad + frame extraction (backend)
+ # _apply_frame_processing – per-frame conditioning (hook)
+ # _window_and_fft – window + zero-pad + FFT + normalize → complex (backend)
+ # native path (needs_manual_framing=False):
+ # _native_stft – native STFT returning complex output (backend)
+ # f. _compute_magnitudes – complex → real magnitudes (backend, shared by both paths)
+ # g. _cast_stft_output – cast output dtype (hook, no-op by default)
+ # 2. _apply_mel_scale (mel filterbank projection)
+ # 3. _normalize_magnitude (log / dB scaling, optional per-utterance norm)
+ #
+ # Backend subclasses (NumpyAudioBackend, TorchAudioBackend) implement the
+ # full pipeline. Model-specific processors can override individual hooks
+ # (_apply_frame_processing) or the entire _stft when the base STFT path
+ # is insufficient.
+ #
+ # ``audio_ranges`` is passed through as a kwarg from ``_preprocess`` so that
+ # model-specific overrides (e.g., Parakeet waveform-level preemphasis,
+ # Phi4 boundary masking) can access original audio lengths without stashing
+ # state on ``self``.
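+    #
+    # Example hook override (a sketch; `MyProcessor` is hypothetical):
+    #
+    #     class MyProcessor(TorchAudioBackend):
+    #         def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs):
+    #             frames = super()._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs)
+    #             return frames * 2.0  # extra per-frame scaling after the standard steps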
+
+ def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs):
+ """Orchestrate the STFT pipeline.
+
+ Runs the sub-steps listed above in order. Override this only when the
+ pipeline ordering itself needs to change. Otherwise, override individual hooks.
+ """
+ return self._stft(audio, spectrogram_config=spectrogram_config, **kwargs)
+
+ def _stft(self, audio, *, spectrogram_config, **kwargs):
+ """Compute the STFT and return a power/magnitude spectrogram.
+
+ Orchestrates the sub-steps listed in the pipeline documentation above.
+ Backend subclasses implement the individual leaf methods; model-specific
+ processors can override this entirely for a fully custom STFT
+ (e.g., Gemma3n's unfold-based STFT with extra-sample framing).
+ """
+ stft_cfg = spectrogram_config.stft_config
+ n_fft = stft_cfg.n_fft
+ win_length = stft_cfg.win_length or n_fft
+ hop_length = stft_cfg.hop_length or win_length // 2
+ needs_manual_framing = self._needs_manual_framing(spectrogram_config)
+
+ if spectrogram_config.computation_dtype:
+ dtype_str = spectrogram_config.computation_dtype
+ if isinstance(audio, np.ndarray):
+ audio = audio.astype(dtype_str)
+ else:
+ import torch
+ audio = audio.to(getattr(torch, dtype_str))
+ if spectrogram_config.waveform_scale is not None:
+ audio = audio * spectrogram_config.waveform_scale
+
+ # Cache window on first call; reuse on subsequent calls with same config
+ if self._cached_stft_window is not None and spectrogram_config is self.spectrogram_config:
+ window, frame_length = self._cached_stft_window
+ else:
+ window = self._create_stft_window(win_length, stft_cfg, audio)
+ window, frame_length = self._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing)
+ if spectrogram_config is self.spectrogram_config:
+ self._cached_stft_window = (window, frame_length)
+
+ if needs_manual_framing:
+ audio_dtype = audio.dtype
+ frames = self._frame_audio(audio, window, frame_length, hop_length, n_fft, stft_cfg)
+ frames = self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs)
+ stft_out = self._window_and_fft(frames, window, frame_length, n_fft, stft_cfg, audio_dtype=audio_dtype)
+ else:
+ stft_out = self._native_stft(audio, window, frame_length, hop_length, n_fft, stft_cfg)
+
+ magnitudes = self._compute_magnitudes(stft_out, stft_cfg.power, spectrogram_config=spectrogram_config)
+ return self._cast_stft_output(magnitudes, spectrogram_config)
+
+ def _create_stft_window(self, win_length, stft_cfg, audio):
+ """Create the STFT window. Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_framing):
+ """Pad/reshape window and determine frame length. Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg):
+ """Extract overlapping frames from the audio signal.
+
+ Handles center padding and dtype promotion. Returns frames of shape
+ (..., num_frames, frame_length). Implemented by backend subclasses.
+ """
+ raise NotImplementedError
+
+    def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None):
+ """Apply window, zero-pad, FFT, and normalize. Returns complex STFT of shape (..., freq, time).
+ Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg):
+ """Native STFT (e.g. torch.stft). Returns complex output. Implemented by backend subclasses."""
+ raise NotImplementedError
+
+ def _compute_magnitudes(self, stft_out, power, spectrogram_config=None):
+ """Convert complex STFT output to a real-valued magnitude spectrogram.
+ Implemented by backend subclasses. Overridable for custom magnitude computation (e.g. Parakeet)."""
+ raise NotImplementedError
+
+ def _cast_stft_output(self, magnitudes, spectrogram_config):
+ """Cast STFT output to the desired output dtype. Default: no-op."""
+ return magnitudes
+
+ def _needs_manual_framing(self, spectrogram_config):
+ """Whether the STFT requires manual framing (unfold-based) instead of a native STFT.
+
+ Manual framing is needed when per-frame processing must happen between
+ frame extraction and windowing (e.g. per-frame preemphasis, DC offset removal,
+ or left-aligned FFT padding).
+
+ Override in model-specific processors that handle preemphasis at the
+ waveform level (in ``_stft``) and don't need per-frame processing.
+ """
+ return (
+ (spectrogram_config.preemphasis is not None)
+ or spectrogram_config.remove_dc_offset
+ )
+
+ def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs):
+ """Hook: per-frame signal conditioning after frame extraction.
+
+ Called after framing, before windowing and FFT. Default backend
+ implementations apply dither, DC-offset removal, and standard
+ preemphasis.
+
+ Override for non-standard frame processing, e.g. HTK-style
+ preemphasis (Gemma3n).
+ """
+ raise NotImplementedError
+
+ def _apply_mel_scale(self, *args, **kwargs):
+ """Apply mel filterbank to spectrogram features."""
+ raise NotImplementedError
+
+ def _normalize_magnitude(self, *args, **kwargs):
+ """Apply magnitude normalization (log, log10, or dB scaling) to spectrogram features."""
+ raise NotImplementedError
+
+ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig):
+ raise NotImplementedError
+
+ def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False):
+ """
+ Convert raw audio sample lengths to the number of feature frames after spectrogram extraction.
+
+ By default returns `audio_lengths // hop_length`, which gives the number of valid (non-padding)
+ feature frames for centered STFT. When `include_center_frame=True` and the STFT uses centering,
+ adds 1 to account for the extra frame produced by centered STFT.
+
+ Override this method in subclasses that use non-standard STFT configurations (e.g., unfold-based
+ or non-centered STFT).
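+
+        Example (illustrative): 1 second of 16 kHz audio (16000 samples) with
+        `hop_length=160` yields `16000 // 160 = 100` frames, or 101 when
+        `include_center_frame=True` and the STFT is centered.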
+ """
+        # mirror the win/hop fallbacks used in `_stft` (`hop_length` may be None in StftConfig)
+        stft_cfg = spectrogram_config.stft_config
+        win_length = stft_cfg.win_length or stft_cfg.n_fft
+        hop_length = stft_cfg.hop_length or win_length // 2
+        lengths = audio_lengths // hop_length
+ if include_center_frame and spectrogram_config.stft_config.center:
+ lengths = lengths + 1
+ return lengths
+
+ def _get_padding_strategies(self, padding=False, max_length=None):
+ """Find the correct padding strategy."""
+ if padding is not False:
+ if padding is True:
+ padding_strategy = PaddingStrategy.LONGEST
+ elif not isinstance(padding, PaddingStrategy):
+ padding_strategy = PaddingStrategy(padding)
+            else:
+ padding_strategy = padding
+ else:
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
+
+ if max_length is None:
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
+ raise ValueError(
+ f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
+ )
+
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
+ raise ValueError(
+ "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
+ " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
+ )
+
+ return padding_strategy
+
+    def _standardize_kwargs(self, **kwargs) -> dict:
+        """Coerce dict configs to their dataclass form."""
+        if isinstance(kwargs.get("spectrogram_config"), dict):
+            kwargs["spectrogram_config"] = SpectrogramConfig.from_dict(kwargs["spectrogram_config"])
+ if kwargs.get("spectrogram_config") is not None and kwargs.get("do_extract_spectrogram") is None:
+ kwargs["do_extract_spectrogram"] = True
+ return kwargs
+
+ def _validate_preprocess_kwargs(
+ self,
+ sample_rate: int | None = None,
+ max_length: int | None = None,
+ truncation: bool | None = None,
+ pad_to_multiple_of: int | None = None,
+        return_tensors: "str | TensorType | None" = None,  # stringified: `TensorType` is not imported at runtime here
+ **kwargs,
+ ):
+ """Validate the kwargs for the preprocess method."""
+ if truncation and max_length is None:
+            raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.")
+
+ def to_dict(self):
+ output = super().to_dict()
+ # Serialize config dataclasses to plain dicts for JSON persistence
+ for key in ("spectrogram_config",):
+ if key in output and hasattr(output[key], "to_dict"):
+ output[key] = output[key].to_dict()
+
+ # Filter out None values that are class defaults
+ filtered_dict = {}
+ for key, value in output.items():
+ if value is None:
+ class_default = getattr(type(self), key, "NOT_FOUND")
+ # Keep None if user explicitly set it (class default is non-None)
+ if class_default != "NOT_FOUND" and class_default is not None:
+ filtered_dict[key] = value
+ else:
+ filtered_dict[key] = value
+
+ return filtered_dict
diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py
index 85b56634afe7..14e70e12e45c 100644
--- a/src/transformers/audio_utils.py
+++ b/src/transformers/audio_utils.py
@@ -22,6 +22,7 @@
import os
import warnings
from collections.abc import Sequence
+from dataclasses import dataclass, field, fields
from io import BytesIO
from typing import TYPE_CHECKING, Any, Union
@@ -57,6 +58,124 @@
AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]]
+
+
+@dataclass(frozen=True)
+class StftConfig:
+ """Configuration for Short-Time Fourier Transform.
+
+ Uses torchaudio parameter naming conventions. See
+ `torchaudio.transforms.MelSpectrogram` for reference.
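+
+    Example (illustrative round trip):
+
+    ```python
+    cfg = StftConfig(n_fft=512)
+    d = cfg.to_dict()  # None-valued fields (win_length, hop_length, ...) are dropped
+    assert StftConfig.from_dict(d) == cfg  # unknown keys are ignored, defaults are restored
+    ```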
+ """
+
+ n_fft: int = 400
+ win_length: int | None = None
+ hop_length: int | None = None
+ window_fn: str = "hann_window"
+ wkwargs: dict | None = None
+ power: float = 2.0
+ center: bool = True
+ pad_mode: str = "reflect"
+ normalized: bool = False
+ onesided: bool | None = None
+ periodic: bool = True
+ left_align_fft: bool = False
+ window_dtype: str | None = None
+
+ def to_dict(self) -> dict:
+ return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None}
+
+ @classmethod
+ def from_dict(cls, d: dict) -> "StftConfig":
+ valid_keys = {f.name for f in fields(cls)}
+ return cls(**{k: v for k, v in d.items() if k in valid_keys})
+
+
+@dataclass(frozen=True)
+class MelScaleConfig:
+ """Configuration for mel filterbank.
+
+ Uses torchaudio parameter naming conventions. See
+ `torchaudio.transforms.MelSpectrogram` for reference.
+ """
+
+ n_mels: int = 128
+ f_min: float = 0.0
+ f_max: float | None = None
+ mel_scale: str = "htk"
+ norm: str | None = None
+ triangularize_in_mel_space: bool = False
+ frequency_bin_mode: str = "rfft"
+ computation_dtype: str | None = None
+ bands_to_zero: int = 0
+ matmul_order: str = "filters_first"
+
+ def to_dict(self) -> dict:
+ return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None}
+
+ @classmethod
+ def from_dict(cls, d: dict) -> "MelScaleConfig":
+ valid_keys = {f.name for f in fields(cls)}
+ return cls(**{k: v for k, v in d.items() if k in valid_keys})
+
+
+@dataclass(frozen=True)
+class SpectrogramConfig:
+ """Configuration for spectrogram extraction, composed of STFT and mel scale sub-configs."""
+
+ stft_config: StftConfig = field(default_factory=StftConfig)
+ mel_scale_config: MelScaleConfig | None = None
+ log_mode: str = "log10"
+ chunk_length: int | None = None
+ preemphasis: float | None = None
+ remove_dc_offset: bool = False
+ mel_floor: float = 1e-10
+ waveform_scale: float | None = None
+ computation_dtype: str | None = None
+ skip_last_frame: bool = False
+
+ def __getitem__(self, key):
+ if hasattr(self, key):
+ return getattr(self, key)
+ raise KeyError(f"Key {key} not found in SpectrogramConfig.")
+
+ def __iter__(self):
+ for f in fields(self):
+ val = getattr(self, f.name)
+ if val is not None:
+ if hasattr(val, "to_dict"):
+ yield f.name, val.to_dict()
+ else:
+ yield f.name, val
+
+ def __eq__(self, other):
+ if isinstance(other, dict):
+ return dict(self) == other
+ if isinstance(other, SpectrogramConfig):
+ return tuple(getattr(self, f.name) for f in fields(self)) == tuple(
+ getattr(other, f.name) for f in fields(self)
+ )
+ return NotImplemented
+
+ def to_dict(self) -> dict:
+ return dict(self)
+
+ @classmethod
+ def from_dict(cls, d: dict) -> "SpectrogramConfig":
+ stft_config = StftConfig.from_dict(d["stft_config"]) if "stft_config" in d else StftConfig()
+ mel_scale_config = MelScaleConfig.from_dict(d["mel_scale_config"]) if "mel_scale_config" in d else None
+ return cls(
+ stft_config=stft_config,
+ mel_scale_config=mel_scale_config,
+ log_mode=d.get("log_mode", "log10"),
+ chunk_length=d.get("chunk_length"),
+ preemphasis=d.get("preemphasis"),
+ remove_dc_offset=d.get("remove_dc_offset", False),
+ mel_floor=d.get("mel_floor", 1e-10),
+            waveform_scale=d.get("waveform_scale"),
+            computation_dtype=d.get("computation_dtype"),
+            skip_last_frame=d.get("skip_last_frame", False),
+ )
+
+
def load_audio(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray:
"""
Loads `audio` to an np.ndarray object.
@@ -282,10 +401,11 @@ def hertz_to_mel(freq: float | np.ndarray, mel_scale: str = "htk") -> float | np
elif mel_scale == "kaldi":
return 1127.0 * np.log(1.0 + (freq / 700.0))
+ f_sp = 200.0 / 3
min_log_hertz = 1000.0
- min_log_mel = 15.0
+ min_log_mel = min_log_hertz / f_sp
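+    # 1000 Hz maps exactly to mel 15.0 (= 1000 / (200 / 3)), the linear-to-log knee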
logstep = 27.0 / np.log(6.4)
- mels = 3.0 * freq / 200.0
+ mels = freq / f_sp
if isinstance(freq, np.ndarray):
log_region = freq >= min_log_hertz
@@ -318,10 +438,11 @@ def mel_to_hertz(mels: float | np.ndarray, mel_scale: str = "htk") -> float | np
elif mel_scale == "kaldi":
return 700.0 * (np.exp(mels / 1127.0) - 1.0)
+ f_sp = 200.0 / 3
min_log_hertz = 1000.0
- min_log_mel = 15.0
+ min_log_mel = min_log_hertz / f_sp
logstep = np.log(6.4) / 27.0
- freq = 200.0 * mels / 3.0
+ freq = f_sp * mels
if isinstance(mels, np.ndarray):
log_region = mels >= min_log_mel
@@ -459,6 +580,7 @@ def mel_filter_bank(
norm: str | None = None,
mel_scale: str = "htk",
triangularize_in_mel_space: bool = False,
+ dtype: np.dtype | None = None,
) -> np.ndarray:
"""
Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
@@ -527,7 +649,20 @@ def mel_filter_bank(
# frequencies of FFT bins in Hz
fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
- mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
+ if dtype is not None:
+ # Per-band computation matching librosa's precision path: compute slopes in float64,
+ # cast each band to dtype immediately. This replicates librosa's per-row assignment
+ # to a dtype-initialized array, which produces different rounding than computing all
+ # bands in float64 and casting at the end.
+ filter_diff = np.diff(filter_freqs)
+ ramps = np.subtract.outer(filter_freqs, fft_freqs) # (num_mel_filters+2, num_frequency_bins)
+ mel_filters = np.zeros((num_frequency_bins, num_mel_filters), dtype=dtype)
+ for i in range(num_mel_filters):
+ lower = -ramps[i] / filter_diff[i]
+ upper = ramps[i + 2] / filter_diff[i + 1]
+ mel_filters[:, i] = np.maximum(0, np.minimum(lower, upper)).astype(dtype)
+ else:
+ mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
if norm is not None and norm == "slaney":
# Slaney-style mel is scaled to be approx constant energy per channel
@@ -620,428 +755,6 @@ def window_function(
return padded_window
-# Note: This method processes a single waveform. For batch processing, use spectrogram_batch().
-def spectrogram(
- waveform: np.ndarray,
- window: np.ndarray,
- frame_length: int,
- hop_length: int,
- fft_length: int | None = None,
- power: float | None = 1.0,
- center: bool = True,
- pad_mode: str = "reflect",
- onesided: bool = True,
- dither: float = 0.0,
- preemphasis: float | None = None,
- mel_filters: np.ndarray | None = None,
- mel_floor: float = 1e-10,
- log_mel: str | None = None,
- reference: float = 1.0,
- min_value: float = 1e-10,
- db_range: float | None = None,
- remove_dc_offset: bool = False,
- dtype: np.dtype = np.float32,
-) -> np.ndarray:
- """
- Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
-
- This function can create the following kinds of spectrograms:
-
- - amplitude spectrogram (`power = 1.0`)
- - power spectrogram (`power = 2.0`)
- - complex-valued spectrogram (`power = None`)
- - log spectrogram (use `log_mel` argument)
- - mel spectrogram (provide `mel_filters`)
- - log-mel spectrogram (provide `mel_filters` and `log_mel`)
-
- How this works:
-
- 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
- - hop_length` samples.
- 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
- 3. The DFT is taken of each windowed frame.
- 4. The results are stacked into a spectrogram.
-
- We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
-
- - The analysis frame. This is the size of the time slices that the input waveform is split into.
- - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
- - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
-
- In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
- padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
- typically the next power of two.
-
- Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
- `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
- can be constructed.
-
- Args:
- waveform (`np.ndarray` of shape `(length,)`):
- The input waveform. This must be a single real-valued, mono waveform.
- window (`np.ndarray` of shape `(frame_length,)`):
- The windowing function to apply, including zero-padding if necessary. The actual window length may be
- shorter than `frame_length`, but we're assuming the array has already been zero-padded.
- frame_length (`int`):
- The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
- allow smaller sizes.
- hop_length (`int`):
- The stride between successive analysis frames in samples.
- fft_length (`int`, *optional*):
- The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
- For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
- power (`float`, *optional*, defaults to 1.0):
- If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
- complex numbers.
- center (`bool`, *optional*, defaults to `True`):
- Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
- `t` will start at time `t * hop_length`.
- pad_mode (`str`, *optional*, defaults to `"reflect"`):
- Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
- (pad with edge values), `"reflect"` (pads with mirrored values).
- onesided (`bool`, *optional*, defaults to `True`):
- If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
- frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
- dither (`float`, *optional*, defaults to 0.0):
- Adds dithering. In other words, adds a small Gaussian noise to each frame.
- E.g. use 4.0 to add dithering with a normal distribution centered
- around 0.0 with standard deviation 4.0, 0.0 means no dithering.
- Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
- values for signals with hard-zero sections, when VAD cutoff is present in the signal.
- preemphasis (`float`, *optional*)
- Coefficient for a low-pass filter that applies pre-emphasis before the DFT.
- mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
- The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram.
- mel_floor (`float`, *optional*, defaults to 1e-10):
- Minimum value of mel frequency banks.
- log_mel (`str`, *optional*):
- How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
- the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
- used when `power` is not `None`.
- reference (`float`, *optional*, defaults to 1.0):
- Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
- the loudest part to 0 dB. Must be greater than zero.
- min_value (`float`, *optional*, defaults to `1e-10`):
- The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
- `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
- amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
- db_range (`float`, *optional*):
- Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
- peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
- remove_dc_offset (`bool`, *optional*):
- Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
- order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
- dtype (`np.dtype`, *optional*, defaults to `np.float32`):
- Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
- `np.complex64`.
-
- Returns:
- `nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
- `(num_mel_filters, length)` for a mel spectrogram.
- """
- window_length = len(window)
-
- if fft_length is None:
- fft_length = frame_length
-
- if frame_length > fft_length:
- raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
-
- if window_length != frame_length:
- raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
-
- if hop_length <= 0:
- raise ValueError("hop_length must be greater than zero")
-
- if waveform.ndim != 1:
- raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
-
- if np.iscomplexobj(waveform):
- raise ValueError("Complex-valued input waveforms are not currently supported")
-
- if power is None and mel_filters is not None:
- raise ValueError(
- "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram."
- "Specify `power` to fix this issue."
- )
-
- # center pad the waveform
- if center:
- padding = [(int(frame_length // 2), int(frame_length // 2))]
- waveform = np.pad(waveform, padding, mode=pad_mode)
-
- # promote to float64, since np.fft uses float64 internally
- waveform = waveform.astype(np.float64)
- window = window.astype(np.float64)
-
- # split waveform into frames of frame_length size
- num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
-
- num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
- spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
-
- # rfft is faster than fft
- fft_func = np.fft.rfft if onesided else np.fft.fft
- buffer = np.zeros(fft_length)
-
- timestep = 0
- for frame_idx in range(num_frames):
- buffer[:frame_length] = waveform[timestep : timestep + frame_length]
-
- if dither != 0.0:
- buffer[:frame_length] += dither * np.random.randn(frame_length)
-
- if remove_dc_offset:
- buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
-
- if preemphasis is not None:
- buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
- buffer[0] *= 1 - preemphasis
-
- buffer[:frame_length] *= window
-
- spectrogram[frame_idx] = fft_func(buffer)
- timestep += hop_length
-
- # note: ** is much faster than np.power
- if power is not None:
- spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
-
- spectrogram = spectrogram.T
-
- if mel_filters is not None:
- spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
-
- if power is not None and log_mel is not None:
- if log_mel == "log":
- spectrogram = np.log(spectrogram)
- elif log_mel == "log10":
- spectrogram = np.log10(spectrogram)
- elif log_mel == "dB":
- if power == 1.0:
- spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
- elif power == 2.0:
- spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
- else:
- raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
- else:
- raise ValueError(f"Unknown log_mel option: {log_mel}")
-
- spectrogram = np.asarray(spectrogram, dtype)
-
- return spectrogram
-
-
-def spectrogram_batch(
- waveform_list: list[np.ndarray],
- window: np.ndarray,
- frame_length: int,
- hop_length: int,
- fft_length: int | None = None,
- power: float | None = 1.0,
- center: bool = True,
- pad_mode: str = "reflect",
- onesided: bool = True,
- dither: float = 0.0,
- preemphasis: float | None = None,
- mel_filters: np.ndarray | None = None,
- mel_floor: float = 1e-10,
- log_mel: str | None = None,
- reference: float = 1.0,
- min_value: float = 1e-10,
- db_range: float | None = None,
- remove_dc_offset: bool = False,
- dtype: np.dtype = np.float32,
-) -> list[np.ndarray]:
- """
- Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing.
- This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting.
-
- It supports generating various types of spectrograms:
-
- - amplitude spectrogram (`power = 1.0`)
- - power spectrogram (`power = 2.0`)
- - complex-valued spectrogram (`power = None`)
- - log spectrogram (use `log_mel` argument)
- - mel spectrogram (provide `mel_filters`)
- - log-mel spectrogram (provide `mel_filters` and `log_mel`)
-
- How this works:
-
- 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
- - hop_length` samples.
- 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
- 3. The DFT is taken of each windowed frame.
- 4. The results are stacked into a spectrogram.
-
- We make a distinction between the following "blocks" of sample data, each of which may have a different lengths:
-
- - The analysis frame. This is the size of the time slices that the input waveform is split into.
- - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
- - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
-
- In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
- padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
- typically the next power of two.
-
- Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`.
-
- Args:
- waveform_list (`list[np.ndarray]` with arrays of shape `(length,)`):
- The list of input waveforms, each a single-channel (mono) signal.
- window (`np.ndarray` of shape `(frame_length,)`):
- The windowing function to apply, including zero-padding if necessary.
- frame_length (`int`):
- The length of each frame for analysis.
- hop_length (`int`):
- The step size between successive frames.
- fft_length (`int`, *optional*):
- The size of the FFT buffer, defining frequency bin resolution.
- power (`float`, *optional*, defaults to 1.0):
- Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
- center (`bool`, *optional*, defaults to `True`):
- Whether to center-pad the waveform frames.
- pad_mode (`str`, *optional*, defaults to `"reflect"`):
- The padding strategy when `center` is `True`.
- onesided (`bool`, *optional*, defaults to `True`):
- If True, returns a one-sided spectrogram for real input signals.
- dither (`float`, *optional*, defaults to 0.0):
- Adds dithering. In other words, adds a small Gaussian noise to each frame.
- E.g. use 4.0 to add dithering with a normal distribution centered
- around 0.0 with standard deviation 4.0, 0.0 means no dithering.
- preemphasis (`float`, *optional*):
- Applies a pre-emphasis filter to each frame.
- mel_filters (`np.ndarray`, *optional*):
- Mel filter bank for converting to mel spectrogram.
- mel_floor (`float`, *optional*, defaults to 1e-10):
- Floor value for mel spectrogram to avoid log(0).
- log_mel (`str`, *optional*):
- Specifies log scaling strategy; options are None, "log", "log10", "dB".
- reference (`float`, *optional*, defaults to 1.0):
- Reference value for dB conversion in log_mel.
- min_value (`float`, *optional*, defaults to 1e-10):
- Minimum floor value for log scale conversions.
- db_range (`float`, *optional*):
- Dynamic range for dB scale spectrograms.
- remove_dc_offset (`bool`, *optional*):
- Whether to remove the DC offset from each frame.
- dtype (`np.dtype`, *optional*, defaults to `np.float32`):
- Data type of the output spectrogram.
-
- Returns:
- list[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
- """
- window_length = len(window)
-
- if fft_length is None:
- fft_length = frame_length
-
- if frame_length > fft_length:
- raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
-
- if window_length != frame_length:
- raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
-
- if hop_length <= 0:
- raise ValueError("hop_length must be greater than zero")
-
- # Check the dimensions of the waveform , and if waveform is complex
- for waveform in waveform_list:
- if waveform.ndim != 1:
- raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
- if np.iscomplexobj(waveform):
- raise ValueError("Complex-valued input waveforms are not currently supported")
- # Center pad the waveform
- if center:
- padding = [(int(frame_length // 2), int(frame_length // 2))]
- waveform_list = [
- np.pad(
- waveform,
- padding,
- mode=pad_mode,
- )
- for waveform in waveform_list
- ]
- original_waveform_lengths = [
- len(waveform) for waveform in waveform_list
- ] # these lengths will be used to remove padding later
-
- # Batch pad the waveform
- max_length = max(original_waveform_lengths)
- padded_waveform_batch = np.array(
- [
- np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
- for waveform in waveform_list
- ],
- dtype=dtype,
- )
-
- # Promote to float64, since np.fft uses float64 internally
- padded_waveform_batch = padded_waveform_batch.astype(np.float64)
- window = window.astype(np.float64)
-
- # Split waveform into frames of frame_length size
- num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
- # these lengths will be used to remove padding later
- true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
- num_batches = padded_waveform_batch.shape[0]
-
- num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
- spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)
-
- # rfft is faster than fft
- fft_func = np.fft.rfft if onesided else np.fft.fft
- buffer = np.zeros((num_batches, fft_length))
-
- for frame_idx in range(num_frames):
- timestep = frame_idx * hop_length
- buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
-
- if dither != 0.0:
- buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
-
- if remove_dc_offset:
- buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
-
- if preemphasis is not None:
- buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
- buffer[:, 0] *= 1 - preemphasis
-
- buffer[:, :frame_length] *= window
-
- spectrogram[:, frame_idx] = fft_func(buffer)
-
- # Note: ** is much faster than np.power
- if power is not None:
- spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
-
- # Apply mel filters if provided
- if mel_filters is not None:
- result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
- spectrogram = np.maximum(mel_floor, result)
-
- # Convert to log scale if specified
- if power is not None and log_mel is not None:
- if log_mel == "log":
- spectrogram = np.log(spectrogram)
- elif log_mel == "log10":
- spectrogram = np.log10(spectrogram)
- elif log_mel == "dB":
- if power == 1.0:
- spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
- elif power == 2.0:
- spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
- else:
- raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
- else:
- raise ValueError(f"Unknown log_mel option: {log_mel}")
-
- spectrogram = np.asarray(spectrogram, dtype)
-
- spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]
-
- return spectrogram_list
-
def power_to_db(
spectrogram: np.ndarray,
@@ -1094,55 +807,6 @@ def power_to_db(
return spectrogram
-def power_to_db_batch(
- spectrogram: np.ndarray,
- reference: float = 1.0,
- min_value: float = 1e-10,
- db_range: float | None = None,
-) -> np.ndarray:
- """
- Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`,
- using basic logarithm properties for numerical stability.
-
- This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram.
-
- Args:
- spectrogram (`np.ndarray`):
- The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
- Note that a power spectrogram has the amplitudes squared!
- reference (`float`, *optional*, defaults to 1.0):
- Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
- the loudest part to 0 dB. Must be greater than zero.
- min_value (`float`, *optional*, defaults to `1e-10`):
- The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
- `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
- db_range (`float`, *optional*):
- Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
- peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
-
- Returns:
- `np.ndarray`: the batch of spectrograms in decibels
- """
- if reference <= 0.0:
- raise ValueError("reference must be greater than zero")
- if min_value <= 0.0:
- raise ValueError("min_value must be greater than zero")
-
- reference = max(min_value, reference)
-
- spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
- spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
-
- if db_range is not None:
- if db_range <= 0.0:
- raise ValueError("db_range must be greater than zero")
- # Apply db_range clipping per batch item
- max_values = spectrogram.max(axis=(1, 2), keepdims=True)
- spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
-
- return spectrogram
-
-
def amplitude_to_db(
spectrogram: np.ndarray,
reference: float = 1.0,
@@ -1192,46 +856,3 @@ def amplitude_to_db(
return spectrogram
-def amplitude_to_db_batch(
- spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: float | None = None
-) -> np.ndarray:
- """
- Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`,
- using basic logarithm properties for numerical stability.
-
- The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram.
-
- Args:
- spectrogram (`np.ndarray`):
- The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
- reference (`float`, *optional*, defaults to 1.0):
- Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
- the loudest part to 0 dB. Must be greater than zero.
- min_value (`float`, *optional*, defaults to `1e-5`):
- The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
- `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
- db_range (`float`, *optional*):
- Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
- peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
-
- Returns:
- `np.ndarray`: the batch of spectrograms in decibels
- """
- if reference <= 0.0:
- raise ValueError("reference must be greater than zero")
- if min_value <= 0.0:
- raise ValueError("min_value must be greater than zero")
-
- reference = max(min_value, reference)
-
- spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
- spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
-
- if db_range is not None:
- if db_range <= 0.0:
- raise ValueError("db_range must be greater than zero")
- # Apply db_range clipping per batch item
- max_values = spectrogram.max(axis=(1, 2), keepdims=True)
- spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
-
- return spectrogram
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index f1b66f752da4..b30a056f7794 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -15,20 +15,15 @@
Feature extraction saving/loading class for common feature extractors.
"""
-import copy
-import json
import os
from collections import UserDict
from typing import TYPE_CHECKING, Any, TypeVar, Union
import numpy as np
-from huggingface_hub import create_repo, is_offline_mode
-from .dynamic_module_utils import custom_object_save
+from .preprocessing_base import PreprocessingMixin
from .utils import (
FEATURE_EXTRACTOR_NAME,
- PROCESSOR_NAME,
- PushToHubMixin,
TensorType,
_is_tensor_or_array_like,
copy_func,
@@ -38,9 +33,7 @@
is_torch_dtype,
logging,
requires_backends,
- safe_load_json_file,
)
-from .utils.hub import cached_file
if TYPE_CHECKING:
@@ -263,170 +256,21 @@ def maybe_to(v):
return self
-class FeatureExtractionMixin(PushToHubMixin):
+class FeatureExtractionMixin(PreprocessingMixin):
"""
This is a feature extraction mixin used to provide saving/loading functionality for sequential and audio feature
extractors.
"""
- _auto_class = None
-
- def __init__(self, **kwargs):
- """Set elements of `kwargs` as attributes."""
- # Pop "processor_class", it should not be saved in feature extractor config
- kwargs.pop("processor_class", None)
- # Additional attributes without default values
- for key, value in kwargs.items():
- try:
- setattr(self, key, value)
- except AttributeError as err:
- logger.error(f"Can't set {key} with value {value} for {self}")
- raise err
-
- @classmethod
- def from_pretrained(
- cls: type[SpecificFeatureExtractorType],
- pretrained_model_name_or_path: str | os.PathLike,
- cache_dir: str | os.PathLike | None = None,
- force_download: bool = False,
- local_files_only: bool = False,
- token: str | bool | None = None,
- revision: str = "main",
- **kwargs,
- ) -> SpecificFeatureExtractorType:
- r"""
- Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
- derived class of [`SequenceFeatureExtractor`].
-
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- This can be either:
-
- - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
- huggingface.co.
- - a path to a *directory* containing a feature extractor file saved using the
- [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
- `./my_model_directory/`.
- - a path or url to a saved feature extractor JSON *file*, e.g.,
- `./my_model_directory/preprocessor_config.json`.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
- standard cache should not be used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force to (re-)download the feature extractor files and override the cached versions
- if they exist.
- proxies (`dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
- token (`str` or `bool`, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `hf auth login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
-
-
-
-
- To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
-
-
-
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
- If `False`, then this function returns just the final feature extractor object. If `True`, then this
- functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
- consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
- `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
- kwargs (`dict[str, Any]`, *optional*):
- The values in kwargs of any keys which are feature extractor attributes will be used to override the
- loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
- controlled by the `return_unused_kwargs` keyword parameter.
-
- Returns:
- A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`].
-
- Examples:
-
- ```python
- # We can't instantiate directly the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor* so let's show the examples on a
- # derived class: *Wav2Vec2FeatureExtractor*
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
- "facebook/wav2vec2-base-960h"
- ) # Download feature_extraction_config from huggingface.co and cache.
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
- "./test/saved_model/"
- ) # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')*
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json")
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
- "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False
- )
- assert feature_extractor.return_attention_mask is False
- feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained(
- "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True
- )
- assert feature_extractor.return_attention_mask is False
- assert unused_kwargs == {"foo": False}
- ```"""
- kwargs["cache_dir"] = cache_dir
- kwargs["force_download"] = force_download
- kwargs["local_files_only"] = local_files_only
- kwargs["revision"] = revision
-
- if token is not None:
- kwargs["token"] = token
-
- feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
-
- return cls.from_dict(feature_extractor_dict, **kwargs)
-
- def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
- """
- Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
- [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
-
- Args:
- save_directory (`str` or `os.PathLike`):
- Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
- push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
- repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
- namespace).
- kwargs (`dict[str, Any]`, *optional*):
- Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
- """
- if os.path.isfile(save_directory):
- raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
-
- os.makedirs(save_directory, exist_ok=True)
-
- if push_to_hub:
- commit_message = kwargs.pop("commit_message", None)
- repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
- repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
- files_timestamps = self._get_files_timestamps(save_directory)
-
- # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
- # loaded from the Hub.
- if self._auto_class is not None:
- custom_object_save(self, save_directory, config=self)
-
- # If we save using the predefined names, we can load using `from_pretrained`
- output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
-
- self.to_json_file(output_feature_extractor_file)
- logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
-
- if push_to_hub:
- self._upload_modified_files(
- save_directory,
- repo_id,
- files_timestamps,
- commit_message=commit_message,
- token=kwargs.get("token"),
- )
-
- return [output_feature_extractor_file]
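+    # Class-level hooks for the shared `PreprocessingMixin` save/load path (they mirror
+    # the legacy behavior removed above): the config filename to write, the JSON type
+    # key, the nested keys to look up inside a `processor_config.json`, and attributes
+    # excluded from `to_dict` serialization.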
+ _config_name = FEATURE_EXTRACTOR_NAME
+ _type_key = "feature_extractor_type"
+ _nested_config_keys = ["feature_extractor", "audio_processor"]
+ _auto_class_default = "AutoFeatureExtractor"
+ _file_type_label = "feature extractor"
+ _excluded_dict_keys = {"mel_filters", "window"}
+ _extra_init_pops = []
+ _config_filename_kwarg = None
+ _subfolder_default = None
@classmethod
def get_feature_extractor_dict(
@@ -443,104 +287,7 @@ def get_feature_extractor_dict(
Returns:
`tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object.
"""
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- subfolder = kwargs.pop("subfolder", None)
- token = kwargs.pop("token", None)
- local_files_only = kwargs.pop("local_files_only", False)
- revision = kwargs.pop("revision", None)
-
- from_pipeline = kwargs.pop("_from_pipeline", None)
- from_auto_class = kwargs.pop("_from_auto", False)
-
- user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
- if from_pipeline is not None:
- user_agent["using_pipeline"] = from_pipeline
-
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
-
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- is_local = os.path.isdir(pretrained_model_name_or_path)
- if os.path.isdir(pretrained_model_name_or_path):
- feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
- if os.path.isfile(pretrained_model_name_or_path):
- resolved_feature_extractor_file = pretrained_model_name_or_path
- resolved_processor_file = None
- is_local = True
- else:
- feature_extractor_file = FEATURE_EXTRACTOR_NAME
- try:
- # Load from local folder or from cache or download from model Hub and cache
- resolved_processor_file = cached_file(
- pretrained_model_name_or_path,
- filename=PROCESSOR_NAME,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- user_agent=user_agent,
- revision=revision,
- subfolder=subfolder,
- _raise_exceptions_for_missing_entries=False,
- )
- resolved_feature_extractor_file = cached_file(
- pretrained_model_name_or_path,
- filename=feature_extractor_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- user_agent=user_agent,
- revision=revision,
- subfolder=subfolder,
- _raise_exceptions_for_missing_entries=False,
- )
- except OSError:
- # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
- # the original exception.
- raise
- except Exception:
- # For any other exception, we throw a generic error.
- raise OSError(
- f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
- " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
- f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
- f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
- )
-
- # Load feature_extractor dict. Priority goes as (nested config if found -> image processor config)
- # We are downloading both configs because almost all models have a `processor_config.json` but
- # not all of these are nested. We need to check if it was saved recebtly as nested or if it is legacy style
- feature_extractor_dict = None
- if resolved_processor_file is not None:
- processor_dict = safe_load_json_file(resolved_processor_file)
- if "feature_extractor" in processor_dict or "audio_processor" in processor_dict:
- feature_extractor_dict = processor_dict.get("feature_extractor", processor_dict.get("audio_processor"))
-
- if resolved_feature_extractor_file is not None and feature_extractor_dict is None:
- feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file)
-
- if feature_extractor_dict is None:
- raise OSError(
- f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
- " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
- f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
- f" directory containing a {feature_extractor_file} file"
- )
-
- if is_local:
- logger.info(f"loading configuration file {resolved_feature_extractor_file}")
- else:
- logger.info(
- f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
- )
-
- return feature_extractor_dict, kwargs
+ return cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
@classmethod
def from_dict(
@@ -581,89 +328,6 @@ def from_dict(
else:
return feature_extractor
- def to_dict(self) -> dict[str, Any]:
- """
- Serializes this instance to a Python dictionary. Returns:
- `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- output = copy.deepcopy(self.__dict__)
- output["feature_extractor_type"] = self.__class__.__name__
- if "mel_filters" in output:
- del output["mel_filters"]
- if "window" in output:
- del output["window"]
- return output
-
- @classmethod
- def from_json_file(cls, json_file: str | os.PathLike) -> "FeatureExtractionMixin":
- """
- Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to
- a JSON file of parameters.
-
- Args:
- json_file (`str` or `os.PathLike`):
- Path to the JSON file containing the parameters.
-
- Returns:
- A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
- object instantiated from that JSON file.
- """
- with open(json_file, encoding="utf-8") as reader:
- text = reader.read()
- feature_extractor_dict = json.loads(text)
- return cls(**feature_extractor_dict)
-
- def to_json_string(self) -> str:
- """
- Serializes this instance to a JSON string.
-
- Returns:
- `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
- """
- dictionary = self.to_dict()
-
- for key, value in dictionary.items():
- if isinstance(value, np.ndarray):
- dictionary[key] = value.tolist()
-
- return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
-
- def to_json_file(self, json_file_path: str | os.PathLike):
- """
- Save this instance to a JSON file.
-
- Args:
- json_file_path (`str` or `os.PathLike`):
- Path to the JSON file in which this feature_extractor instance's parameters will be saved.
- """
- with open(json_file_path, "w", encoding="utf-8") as writer:
- writer.write(self.to_json_string())
-
- def __repr__(self):
- return f"{self.__class__.__name__} {self.to_json_string()}"
-
- @classmethod
- def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
- """
- Register this class with a given auto class. This should only be used for custom feature extractors as the ones
- in the library are already mapped with `AutoFeatureExtractor`.
-
-
-
- Args:
- auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
- The auto class to register this new feature extractor with.
- """
- if not isinstance(auto_class, str):
- auto_class = auto_class.__name__
-
- import transformers.models.auto as auto_module
-
- if not hasattr(auto_module, auto_class):
- raise ValueError(f"{auto_class} is not a valid auto class.")
-
- cls._auto_class = auto_class
-
FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
if FeatureExtractionMixin.push_to_hub.__doc__ is not None:
diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py
index 72db8fcc9bec..79d2f7bf2aec 100644
--- a/src/transformers/image_processing_base.py
+++ b/src/transformers/image_processing_base.py
@@ -12,26 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
-import json
import os
from typing import Any, TypeVar
-import numpy as np
-from huggingface_hub import create_repo, is_offline_mode
-
-from .dynamic_module_utils import custom_object_save
from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .image_utils import is_valid_image, load_image
+from .preprocessing_base import PreprocessingMixin
from .utils import (
IMAGE_PROCESSOR_NAME,
- PROCESSOR_NAME,
- PushToHubMixin,
copy_func,
logging,
- safe_load_json_file,
)
-from .utils.hub import cached_file
ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")
@@ -58,175 +49,21 @@ class BatchFeature(BaseBatchFeature):
# TODO: (Amy) - factor out the common parts of this and the feature extractor
-class ImageProcessingMixin(PushToHubMixin):
+class ImageProcessingMixin(PreprocessingMixin):
"""
This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
extractors.
"""
- _auto_class = None
-
- def __init__(self, **kwargs):
- """Set elements of `kwargs` as attributes."""
- # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
- # `XXXImageProcessor`, this attribute and its value are misleading.
- kwargs.pop("feature_extractor_type", None)
- # Pop "processor_class", should not be saved with image processing config anymore
- kwargs.pop("processor_class", None)
- # Additional attributes without default values
- for key, value in kwargs.items():
- try:
- setattr(self, key, value)
- except AttributeError as err:
- logger.error(f"Can't set {key} with value {value} for {self}")
- raise err
-
- @classmethod
- def from_pretrained(
- cls: type[ImageProcessorType],
- pretrained_model_name_or_path: str | os.PathLike,
- cache_dir: str | os.PathLike | None = None,
- force_download: bool = False,
- local_files_only: bool = False,
- token: str | bool | None = None,
- revision: str = "main",
- **kwargs,
- ) -> ImageProcessorType:
- r"""
- Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
-
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- This can be either:
-
- - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
- huggingface.co.
- - a path to a *directory* containing a image processor file saved using the
- [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
- `./my_model_directory/`.
- - a path or url to a saved image processor JSON *file*, e.g.,
- `./my_model_directory/preprocessor_config.json`.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model image processor should be cached if the
- standard cache should not be used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force to (re-)download the image processor files and override the cached versions if
- they exist.
- proxies (`dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
- token (`str` or `bool`, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
- the token generated when running `hf auth login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
-
-
-
-
- To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
-
-
-
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
- If `False`, then this function returns just the final image processor object. If `True`, then this
- functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
- consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
- `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
- subfolder (`str`, *optional*, defaults to `""`):
- In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
- specify the folder name here.
- kwargs (`dict[str, Any]`, *optional*):
- The values in kwargs of any keys which are image processor attributes will be used to override the
- loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
- controlled by the `return_unused_kwargs` keyword parameter.
-
- Returns:
- A image processor of type [`~image_processing_utils.ImageProcessingMixin`].
-
- Examples:
-
- ```python
- # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a
- # derived class: *CLIPImageProcessor*
- image_processor = CLIPImageProcessor.from_pretrained(
- "openai/clip-vit-base-patch32"
- ) # Download image_processing_config from huggingface.co and cache.
- image_processor = CLIPImageProcessor.from_pretrained(
- "./test/saved_model/"
- ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
- image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
- image_processor = CLIPImageProcessor.from_pretrained(
- "openai/clip-vit-base-patch32", do_normalize=False, foo=False
- )
- assert image_processor.do_normalize is False
- image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
- "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
- )
- assert image_processor.do_normalize is False
- assert unused_kwargs == {"foo": False}
- ```"""
- kwargs["cache_dir"] = cache_dir
- kwargs["force_download"] = force_download
- kwargs["local_files_only"] = local_files_only
- kwargs["revision"] = revision
-
- if token is not None:
- kwargs["token"] = token
-
- image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
-
- return cls.from_dict(image_processor_dict, **kwargs)
-
- def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
- """
- Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
- [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
-
- Args:
- save_directory (`str` or `os.PathLike`):
- Directory where the image processor JSON file will be saved (will be created if it does not exist).
- push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
- repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
- namespace).
- kwargs (`dict[str, Any]`, *optional*):
- Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
- """
- if os.path.isfile(save_directory):
- raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
-
- os.makedirs(save_directory, exist_ok=True)
-
- if push_to_hub:
- commit_message = kwargs.pop("commit_message", None)
- repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
- repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
- files_timestamps = self._get_files_timestamps(save_directory)
-
- # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
- # loaded from the Hub.
- if self._auto_class is not None:
- custom_object_save(self, save_directory, config=self)
-
- # If we save using the predefined names, we can load using `from_pretrained`
- output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
-
- self.to_json_file(output_image_processor_file)
- logger.info(f"Image processor saved in {output_image_processor_file}")
-
- if push_to_hub:
- self._upload_modified_files(
- save_directory,
- repo_id,
- files_timestamps,
- commit_message=commit_message,
- token=kwargs.get("token"),
- )
-
- return [output_image_processor_file]
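+ # Class-level hooks for the generic `_get_config_dict` helper (presumably defined
+ # on a shared base mixin; see `get_image_processor_dict` below), which replaces
+ # the hub download/parsing logic deleted above.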
+ _config_name = IMAGE_PROCESSOR_NAME
+ _type_key = "image_processor_type"
+ _nested_config_keys = ["image_processor"]
+ _auto_class_default = "AutoImageProcessor"
+ _file_type_label = "image processor"
+ _excluded_dict_keys = set()
+ _extra_init_pops = ["feature_extractor_type"]
+ _config_filename_kwarg = "image_processor_filename"
+ _subfolder_default = ""
@classmethod
def get_image_processor_dict(
@@ -248,227 +85,7 @@ def get_image_processor_dict(
Returns:
`tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
"""
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- local_files_only = kwargs.pop("local_files_only", False)
- revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", "")
- image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
-
- from_pipeline = kwargs.pop("_from_pipeline", None)
- from_auto_class = kwargs.pop("_from_auto", False)
-
- user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
- if from_pipeline is not None:
- user_agent["using_pipeline"] = from_pipeline
-
- if is_offline_mode() and not local_files_only:
- logger.info("Offline mode: forcing local_files_only=True")
- local_files_only = True
-
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
- is_local = os.path.isdir(pretrained_model_name_or_path)
- if os.path.isdir(pretrained_model_name_or_path):
- image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
- if os.path.isfile(pretrained_model_name_or_path):
- resolved_image_processor_file = pretrained_model_name_or_path
- resolved_processor_file = None
- is_local = True
- else:
- image_processor_file = image_processor_filename
- try:
- resolved_processor_file = cached_file(
- pretrained_model_name_or_path,
- filename=PROCESSOR_NAME,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- user_agent=user_agent,
- revision=revision,
- subfolder=subfolder,
- _raise_exceptions_for_missing_entries=False,
- )
- resolved_image_processor_file = cached_file(
- pretrained_model_name_or_path,
- filename=image_processor_file,
- cache_dir=cache_dir,
- force_download=force_download,
- proxies=proxies,
- local_files_only=local_files_only,
- token=token,
- user_agent=user_agent,
- revision=revision,
- subfolder=subfolder,
- _raise_exceptions_for_missing_entries=False,
- )
- except OSError:
- # Re-raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
- # the original exception.
- raise
- except Exception:
- # For any other exception, we throw a generic error.
- raise OSError(
- f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
- " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
- f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
- f" directory containing a {image_processor_filename} file"
- )
-
- # Load image_processor dict. Priority goes to the nested config if found, then to the image processor config.
- # We download both configs because almost all models have a `processor_config.json`, but not all of these are
- # nested. We need to check whether it was saved recently as a nested config or in the legacy style.
- image_processor_dict = None
- if resolved_processor_file is not None:
- processor_dict = safe_load_json_file(resolved_processor_file)
- if "image_processor" in processor_dict:
- image_processor_dict = processor_dict["image_processor"]
-
- if resolved_image_processor_file is not None and image_processor_dict is None:
- image_processor_dict = safe_load_json_file(resolved_image_processor_file)
-
- if image_processor_dict is None:
- raise OSError(
- f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
- " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
- f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
- f" directory containing a {image_processor_filename} file"
- )
-
- if is_local:
- logger.info(f"loading configuration file {resolved_image_processor_file}")
- else:
- logger.info(
- f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
- )
-
- return image_processor_dict, kwargs
-
- @classmethod
- def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs):
- """
- Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
-
- Args:
- image_processor_dict (`dict[str, Any]`):
- Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
- retrieved from a pretrained checkpoint by leveraging the
- [`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
- kwargs (`dict[str, Any]`):
- Additional parameters from which to initialize the image processor object.
-
- Returns:
- [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
- parameters.
- """
- image_processor_dict = image_processor_dict.copy()
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
- image_processor_dict.update({k: v for k, v in kwargs.items() if k in cls.valid_kwargs.__annotations__})
- image_processor = cls(**image_processor_dict)
-
- # Apply extra kwargs to instance (BC for remote code, e.g. phi4_multimodal)
- extra_keys = []
- for key in reversed(list(kwargs.keys())):
- if hasattr(image_processor, key) and key not in cls.valid_kwargs.__annotations__:
- setattr(image_processor, key, kwargs.pop(key, None))
- extra_keys.append(key)
- if extra_keys:
- logger.warning_once(
- f"Image processor {cls.__name__}: kwargs {extra_keys} were applied for backward compatibility. "
- f"To avoid this warning, add them to valid_kwargs: create a custom TypedDict extending "
- f"ImagesKwargs with these keys and set it as the `valid_kwargs` class attribute."
- )
-
- logger.info(f"Image processor {image_processor}")
- if return_unused_kwargs:
- return image_processor, kwargs
- else:
- return image_processor
-
- def to_dict(self) -> dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
-
- Returns:
- `dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
- """
- output = copy.deepcopy(self.__dict__)
- output["image_processor_type"] = self.__class__.__name__
-
- return output
-
- @classmethod
- def from_json_file(cls, json_file: str | os.PathLike):
- """
- Instantiates an image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON
- file of parameters.
-
- Args:
- json_file (`str` or `os.PathLike`):
- Path to the JSON file containing the parameters.
-
- Returns:
- An image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
- instantiated from that JSON file.
- """
- with open(json_file, encoding="utf-8") as reader:
- text = reader.read()
- image_processor_dict = json.loads(text)
- return cls(**image_processor_dict)
-
- def to_json_string(self) -> str:
- """
- Serializes this instance to a JSON string.
-
- Returns:
- `str`: String containing all the attributes that make up this image processor instance in JSON format.
- """
- dictionary = self.to_dict()
-
- for key, value in dictionary.items():
- if isinstance(value, np.ndarray):
- dictionary[key] = value.tolist()
-
- return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
-
- def to_json_file(self, json_file_path: str | os.PathLike):
- """
- Save this instance to a JSON file.
-
- Args:
- json_file_path (`str` or `os.PathLike`):
- Path to the JSON file in which this image_processor instance's parameters will be saved.
- """
- with open(json_file_path, "w", encoding="utf-8") as writer:
- writer.write(self.to_json_string())
-
- def __repr__(self):
- return f"{self.__class__.__name__} {self.to_json_string()}"
-
- @classmethod
- def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
- """
- Register this class with a given auto class. This should only be used for custom image processors as the ones
- in the library are already mapped with `AutoImageProcessor`.
-
- Args:
- auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor"`):
- The auto class to register this new image processor with.
- """
- if not isinstance(auto_class, str):
- auto_class = auto_class.__name__
-
- import transformers.models.auto as auto_module
-
- if not hasattr(auto_module, auto_class):
- raise ValueError(f"{auto_class} is not a valid auto class.")
-
- cls._auto_class = auto_class
+ return cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
def fetch_images(self, image_url_or_urls: str | list[str] | list[list[str]]):
"""
diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py
index 9756866b6333..9fb1d9761ee1 100644
--- a/src/transformers/image_processing_utils.py
+++ b/src/transformers/image_processing_utils.py
@@ -14,7 +14,6 @@
import math
from collections.abc import Iterable
-from copy import deepcopy
from functools import partial
from typing import Any
@@ -193,20 +192,11 @@ class MyImageProcessor(TorchvisionBackend):
def __init__(self, **kwargs: Unpack[ImagesKwargs]):
super().__init__(**kwargs)
- attributes = {}
- for key in self.valid_kwargs.__annotations__:
- kwarg = kwargs.pop(key, None)
- if kwarg is not None:
- attributes[key] = kwarg
- else:
- attributes[key] = deepcopy(getattr(self, key, None))
+ attributes = {key: getattr(self, key) for key in self._valid_kwargs_names}
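+ # NOTE: `_valid_kwargs_names` and the per-kwarg attributes are presumably set by
+ # the base-class __init__ now, so this only re-reads and re-standardizes them.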
attributes = self._standardize_kwargs(**attributes)
for key, value in attributes.items():
setattr(self, key, value)
- # get valid kwargs names
- self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
-
def __call__(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature:
"""Preprocess an image or a batch of images."""
return self.preprocess(images, *args, **kwargs)
diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py
new file mode 100644
index 000000000000..7d9ba2cddec7
--- /dev/null
+++ b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py
@@ -0,0 +1,72 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...audio_processing_backends import NumpyAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend):
+ sample_rate = 16000
+ force_mono = True
+ return_padding_mask = False
+ do_batch_spectrogram = False
+
+ max_length_frames = 1024
+ do_normalize = True
+
+ # AudioSet normalization constants
+ ast_mean = -4.2677393
+ ast_std = 4.5689974
+
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=512,
+ win_length=400,
+ hop_length=160,
+ window_fn="hann_window",
+ power=2.0,
+ center=False,
+ periodic=False,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=128,
+ f_min=20.0,
+ f_max=8000.0,
+ mel_scale="kaldi",
+ triangularize_in_mel_space=True,
+ ),
+ log_mode="log",
+ preemphasis=0.97,
+ remove_dc_offset=True,
+ mel_floor=1.192092955078125e-07,
+ )
+
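+ # `_kaldi_fbank` is assumed to be a base-class helper mirroring the legacy
+ # extractor: torchaudio.compliance.kaldi.fbank when torchaudio is available,
+ # otherwise the numpy spectrogram path configured in `spectrogram_config` above.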
+ def extract_spectrogram(self, audio, **kwargs):
+ return [self._kaldi_fbank(waveform, num_mel_bins=128, window_type="hanning") for waveform in audio]
+
+ def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of):
+ # Always pad/truncate to max_length_frames regardless of caller's padding args
+ return super()._pad_features(features, "max_length", self.max_length_frames, True, pad_to_multiple_of)
+
+ def _postprocess_output(self, output, **kwargs):
+ # Rename to audio_values (AST convention) and apply AudioSet normalization
+ features = output.pop("audio_features")
+ if self.do_normalize:
+ features = (features - self.ast_mean) / (self.ast_std * 2)
+ output["audio_values"] = features
+ return output
+
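+ # Minimal usage sketch (assuming the backend keeps the usual feature-extractor
+ # call convention; `wav` is a 1-D float32 numpy array sampled at 16 kHz):
+ # processor = AudioSpectrogramTransformerAudioProcessor()
+ # out = processor(wav, sampling_rate=16000, return_tensors="np")
+ # out["audio_values"].shape # expected (1, 1024, 128): (batch, max_length_frames, n_mels)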
+
+__all__ = ["AudioSpectrogramTransformerAudioProcessor"]
diff --git a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
index ee69d1d0b991..80faf5663dec 100644
--- a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
+++ b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py
@@ -11,225 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Feature extractor class for Audio Spectrogram Transformer.
-"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_audio_spectrogram_transformer import AudioSpectrogramTransformerAudioProcessor
-import numpy as np
-
-from ...audio_utils import mel_filter_bank, spectrogram, window_function
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, is_speech_available, is_torch_available, logging
-
-
-if is_speech_available():
- import torchaudio.compliance.kaldi as ta_kaldi
-
-if is_torch_available():
- import torch
-
-
-logger = logging.get_logger(__name__)
-
-
-class ASTFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs an Audio Spectrogram Transformer (AST) feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
- otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- num_mel_bins (`int`, *optional*, defaults to 128):
- Number of Mel-frequency bins.
- max_length (`int`, *optional*, defaults to 1024):
- Maximum length to which to pad/truncate the extracted features.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to normalize the log-Mel features using `mean` and `std`.
- mean (`float`, *optional*, defaults to -4.2677393):
- The mean value used to normalize the log-Mel features. Uses the AudioSet mean by default.
- std (`float`, *optional*, defaults to 4.5689974):
- The standard deviation value used to normalize the log-Mel features. Uses the AudioSet standard deviation
- by default.
- return_attention_mask (`bool`, *optional*, defaults to `False`):
- Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`.
- """
-
- model_input_names = ["input_values", "attention_mask"]
-
- def __init__(
- self,
- feature_size=1,
- sampling_rate=16000,
- num_mel_bins=128,
- max_length=1024,
- padding_value=0.0,
- do_normalize=True,
- mean=-4.2677393,
- std=4.5689974,
- return_attention_mask=False,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.num_mel_bins = num_mel_bins
- self.max_length = max_length
- self.do_normalize = do_normalize
- self.mean = mean
- self.std = std
- self.return_attention_mask = return_attention_mask
-
- if not is_speech_available():
- mel_filters = mel_filter_bank(
- num_frequency_bins=257,
- num_mel_filters=self.num_mel_bins,
- min_frequency=20,
- max_frequency=sampling_rate // 2,
- sampling_rate=sampling_rate,
- norm=None,
- mel_scale="kaldi",
- triangularize_in_mel_space=True,
- )
-
- self.mel_filters = mel_filters
- self.window = window_function(400, "hann", periodic=False)
-
- def _extract_fbank_features(
- self,
- waveform: np.ndarray,
- max_length: int,
- ) -> np.ndarray:
- """
- Get mel-filter bank features using TorchAudio if available, otherwise numpy. Note that TorchAudio requires
- 16-bit signed integers as inputs, hence the waveform should not be normalized before feature extraction.
- """
- # waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
- if is_speech_available():
- waveform = torch.from_numpy(waveform).unsqueeze(0)
- fbank = ta_kaldi.fbank(
- waveform,
- sample_frequency=self.sampling_rate,
- window_type="hanning",
- num_mel_bins=self.num_mel_bins,
- )
- else:
- waveform = np.squeeze(waveform)
- fbank = spectrogram(
- waveform,
- self.window,
- frame_length=400,
- hop_length=160,
- fft_length=512,
- power=2.0,
- center=False,
- preemphasis=0.97,
- mel_filters=self.mel_filters,
- log_mel="log",
- mel_floor=1.192092955078125e-07,
- remove_dc_offset=True,
- ).T
-
- fbank = torch.from_numpy(fbank)
-
- n_frames = fbank.shape[0]
- difference = max_length - n_frames
-
- # pad or truncate, depending on difference
- if difference > 0:
- pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
- fbank = pad_module(fbank)
- elif difference < 0:
- fbank = fbank[0:max_length, :]
-
- fbank = fbank.numpy()
-
- return fbank
-
- def normalize(self, input_values: np.ndarray) -> np.ndarray:
- return (input_values - (self.mean)) / (self.std * 2)
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- sampling_rate: int | None = None,
- return_tensors: str | TensorType | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- """
-
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- if is_batched_numpy and len(raw_speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
- is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
- elif not is_batched and not isinstance(raw_speech, np.ndarray):
- raw_speech = np.asarray(raw_speech, dtype=np.float32)
- elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
- raw_speech = raw_speech.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_speech = [raw_speech]
-
- # extract fbank features and pad/truncate to max_length
- features = [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in raw_speech]
-
- # convert into BatchFeature
- padded_inputs = BatchFeature({"input_values": features})
-
- # make sure list is in array format
- input_values = padded_inputs.get("input_values")
- if isinstance(input_values[0], list):
- padded_inputs["input_values"] = [np.asarray(feature, dtype=np.float32) for feature in input_values]
-
- # normalization
- if self.do_normalize:
- padded_inputs["input_values"] = [self.normalize(feature) for feature in input_values]
-
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
+ASTFeatureExtractor = deprecated_feature_extractor(
+ AudioSpectrogramTransformerAudioProcessor, "ASTFeatureExtractor"
+)
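+# `deprecated_feature_extractor` is assumed to return a thin wrapper class that
+# warns on use while delegating to the new audio processor.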
__all__ = ["ASTFeatureExtractor"]
diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py
new file mode 100644
index 000000000000..20525c3bfce3
--- /dev/null
+++ b/src/transformers/models/clap/audio_processing_clap.py
@@ -0,0 +1,129 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...audio_processing_backends import NumpyAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+from ...utils import PaddingStrategy
+
+
+class ClapAudioProcessor(NumpyAudioBackend):
+ sample_rate = 48000
+ force_mono = True
+ max_length = 480000
+ truncation_mode = "rand_trunc" # "fusion" or "rand_trunc"
+
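+ # CLAP historically pairs each truncation mode with a different filter bank:
+ # torchaudio-style HTK filters for "fusion" and librosa-style Slaney filters for
+ # "rand_trunc" (see the legacy extractor's `mel_filters`/`mel_filters_slaney`).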
+ _mel_configs = {
+ "rand_trunc": MelScaleConfig(n_mels=64, f_min=50, f_max=14000, mel_scale="slaney", norm="slaney", frequency_bin_mode="linspace"),
+ "fusion": MelScaleConfig(n_mels=64, f_min=50, f_max=14000, mel_scale="htk", frequency_bin_mode="linspace"),
+ }
+
+ def __init__(self, **kwargs):
+ truncation_mode = kwargs.pop("truncation_mode", self.truncation_mode)
+ self.truncation_mode = truncation_mode
+ self.spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(n_fft=1024, hop_length=480, power=2.0),
+ mel_scale_config=self._mel_configs[truncation_mode],
+ log_mode="dB",
+ computation_dtype="float64",
+ )
+ super().__init__(**kwargs)
+ # rand_trunc: base class truncates via pad() → _truncate_single (random offset)
+ # fusion: no pre-truncation; full mel is extracted then chunked
+ self.truncation = truncation_mode == "rand_trunc"
+
+ def _get_padding_strategies(self, padding=False, max_length=None):
+ # CLAP always pads to max_length, not to the longest in the batch
+ if padding is True and max_length is not None:
+ return PaddingStrategy.MAX_LENGTH
+ return super()._get_padding_strategies(padding=padding, max_length=max_length)
+
+ def pad(self, audio, *args, **kwargs):
+ self._is_longer_flags = []
+ return super().pad(audio, *args, **kwargs)
+
+ def _truncate_single(self, audio_el, max_length):
+ """Random-offset truncation for rand_trunc mode, also tracks which samples were longer."""
+ self._is_longer_flags.append(audio_el.shape[-1] > max_length)
+ if audio_el.shape[-1] > max_length:
+ idx = np.random.randint(0, audio_el.shape[-1] - max_length + 1)
+ return audio_el[..., idx : idx + max_length]
+ return audio_el
+
+ def extract_spectrogram(self, audio, *, spectrogram_config=None, audio_ranges=None, **kwargs):
+ """Extract mel spectrogram and shape output (1 view for rand_trunc, 4 for fusion)."""
+ is_fusion = self.truncation_mode == "fusion"
+ chunk_frames = self.max_length // self.spectrogram_config.stft_config.hop_length + 1
+
+ if isinstance(audio, np.ndarray) and audio.ndim == 2:
+ waveforms = list(audio)
+ elif isinstance(audio, np.ndarray) and audio.ndim == 1:
+ waveforms = [audio]
+ else:
+ waveforms = audio
+
+ mels = []
+ is_longer = []
+ for waveform in waveforms:
+ mel = super().extract_spectrogram(waveform, spectrogram_config=self.spectrogram_config).T # (time, n_mels)
+ total_frames = mel.shape[0]
+
+ if is_fusion and total_frames > chunk_frames:
+ mels.append(self._random_mel_fusion(mel, total_frames, chunk_frames))
+ is_longer.append(True)
+ elif is_fusion:
+ mels.append(np.stack([mel, mel, mel, mel], axis=0))
+ is_longer.append(False)
+ else:
+ mels.append(mel[np.newaxis])
+ is_longer.append(False)
+
+ if is_fusion:
+ self._is_longer_flags = is_longer
+ return mels
+
+ def _random_mel_fusion(self, mel, total_frames, chunk_frames):
+ import torch
+
+ ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
+ if len(ranges[1]) == 0:
+ ranges[1] = [0]
+ if len(ranges[2]) == 0:
+ ranges[2] = [0]
+ idx_front = np.random.choice(ranges[0])
+ idx_middle = np.random.choice(ranges[1])
+ idx_back = np.random.choice(ranges[2])
+
+ mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
+ mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
+ mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
+
+ mel_tensor = torch.tensor(mel[None, None, :])
+ mel_shrink = torch.nn.functional.interpolate(
+ mel_tensor, size=[chunk_frames, 64], mode="bilinear", align_corners=False
+ )
+ mel_shrink = mel_shrink[0][0].numpy()
+ return np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
+
+ def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config):
+ """Return CLAP's is_longer flag instead of a standard attention mask."""
+ is_longer = getattr(self, "_is_longer_flags", None) or [False] * len(audio_ranges)
+ if self.truncation_mode == "fusion" and sum(is_longer) == 0:
+ rand_idx = np.random.randint(0, len(is_longer))
+ is_longer[rand_idx] = True
+ return {"is_longer": [[longer] for longer in is_longer]}
+
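+ # Output layout (per `extract_spectrogram` above): "rand_trunc" yields one view
+ # per sample of shape (1, frames, 64); "fusion" yields four stacked views of
+ # shape (4, chunk_frames, 64): 3 random crops plus a bilinearly downsampled full
+ # mel when the audio exceeds `max_length`, otherwise 4 copies of the padded mel.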
+
+__all__ = ["ClapAudioProcessor"]
diff --git a/src/transformers/models/clap/feature_extraction_clap.py b/src/transformers/models/clap/feature_extraction_clap.py
index 8f0a34d2cf4e..79c3c9353825 100644
--- a/src/transformers/models/clap/feature_extraction_clap.py
+++ b/src/transformers/models/clap/feature_extraction_clap.py
@@ -11,354 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Feature extractor class for CLAP."""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_clap import ClapAudioProcessor
-import copy
-from typing import Any
-
-import numpy as np
-import torch
-
-from ...audio_utils import mel_filter_bank, spectrogram, window_function
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, logging
-from ...utils.import_utils import requires
-
-
-logger = logging.get_logger(__name__)
-
-
-@requires(backends=("torch",))
-class ClapFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a CLAP feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the *Short Time
- Fourier Transform* (STFT) which should match pytorch's `torch.stft` equivalent.
-
- Args:
- feature_size (`int`, *optional*, defaults to 64):
- The feature dimension of the extracted Mel spectrograms. This corresponds to the number of mel filters
- (`n_mels`).
- sampling_rate (`int`, *optional*, defaults to 48000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). This only serves
- to warn users if the audio fed to the feature extractor does not have the same sampling rate.
- hop_length (`int`,*optional*, defaults to 480):
- Length of the overlapping windows for the STFT used to obtain the Mel Spectrogram. The audio will be split
- in smaller `frames` with a step of `hop_length` between each frame.
- max_length_s (`int`, *optional*, defaults to 10):
- The maximum input length of the model in seconds. This is used to pad the audio.
- fft_window_size (`int`, *optional*, defaults to 1024):
- Size of the window (in samples) on which the Fourier transform is applied. This controls the frequency
- resolution of the spectrogram. 1024 means that the Fourier transform is computed on windows of 1024 samples.
- padding_value (`float`, *optional*, defaults to 0.0):
- Padding value used to pad the audio. Should correspond to silences.
- return_attention_mask (`bool`, *optional*, defaults to `False`):
- Whether or not the model should return the attention masks corresponding to the input.
- frequency_min (`float`, *optional*, defaults to 0):
- The lowest frequency of interest. The STFT will not be computed for values below this.
- frequency_max (`float`, *optional*, defaults to 14000):
- The highest frequency of interest. The STFT will not be computed for values above this.
- top_db (`float`, *optional*):
- The highest decibel value used to convert the mel spectrogram to the log scale. For more details see the
- `audio_utils.power_to_db` function
- truncation (`str`, *optional*, defaults to `"fusion"`):
- Truncation pattern for long audio inputs. Two patterns are available:
- - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a
- downsampled version of the entire mel spectrogram.
- If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a copy
- of the original mel obtained from the padded audio.
- - `rand_trunc` will select a random crop of the mel spectrogram.
- padding (`str`, *optional*, defaults to `"repeatpad"`):
- Padding pattern for shorter audio inputs. Three patterns were originally implemented:
- - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
- - `repeat`: the audio is repeated and then cut to fit the `max_length`
- - `pad`: the audio is padded.
- """
-
- model_input_names = ["input_features", "is_longer"]
-
- def __init__(
- self,
- feature_size=64,
- sampling_rate=48_000,
- hop_length=480,
- max_length_s=10,
- fft_window_size=1024,
- padding_value=0.0,
- return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
- frequency_min: float = 0,
- frequency_max: float = 14_000,
- top_db: int | None = None,
- truncation: str = "fusion",
- padding: str = "repeatpad",
- **kwargs,
- ):
- super().__init__(
- feature_size=feature_size,
- sampling_rate=sampling_rate,
- padding_value=padding_value,
- return_attention_mask=return_attention_mask,
- **kwargs,
- )
- self.top_db = top_db
- self.truncation = truncation
- self.padding = padding
- self.fft_window_size = fft_window_size
- self.nb_frequency_bins = (fft_window_size >> 1) + 1
- self.hop_length = hop_length
- self.max_length_s = max_length_s
- self.nb_max_samples = max_length_s * sampling_rate
- self.sampling_rate = sampling_rate
- self.frequency_min = frequency_min
- self.frequency_max = frequency_max
- self.mel_filters = mel_filter_bank(
- num_frequency_bins=self.nb_frequency_bins,
- num_mel_filters=feature_size,
- min_frequency=frequency_min,
- max_frequency=frequency_max,
- sampling_rate=sampling_rate,
- norm=None,
- mel_scale="htk",
- )
- self.mel_filters_slaney = mel_filter_bank(
- num_frequency_bins=self.nb_frequency_bins,
- num_mel_filters=feature_size,
- min_frequency=frequency_min,
- max_frequency=frequency_max,
- sampling_rate=sampling_rate,
- norm="slaney",
- mel_scale="slaney",
- )
-
- def to_dict(self) -> dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
-
- Returns:
- `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, except for the
- mel filter banks, which do not need to be saved or printed as they are too long.
- """
- output = copy.deepcopy(self.__dict__)
- output["feature_extractor_type"] = self.__class__.__name__
- if "mel_filters" in output:
- del output["mel_filters"]
- if "mel_filters_slaney" in output:
- del output["mel_filters_slaney"]
- return output
-
- def _np_extract_fbank_features(self, waveform: np.ndarray, mel_filters: np.ndarray | None = None) -> np.ndarray:
- """
- Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter
- banks are used depending on the truncation pattern:
- - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from
- calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation`
- is set to `"fusion"`.
- - `self.mel_filters_slaney`: they correspond to the default parameters of `librosa`, which uses
- `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original
- implementation when the truncation mode is not `"fusion"`.
- """
- log_mel_spectrogram = spectrogram(
- waveform,
- window_function(self.fft_window_size, "hann"),
- frame_length=self.fft_window_size,
- hop_length=self.hop_length,
- power=2.0,
- mel_filters=mel_filters,
- log_mel="dB",
- )
- return log_mel_spectrogram.T
-
- def _random_mel_fusion(self, mel, total_frames, chunk_frames):
- ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
- if len(ranges[1]) == 0:
- # if the audio is too short, we just use the first chunk
- ranges[1] = [0]
- if len(ranges[2]) == 0:
- # if the audio is too short, we just use the first chunk
- ranges[2] = [0]
- # randomly choose index for each part
- idx_front = np.random.choice(ranges[0])
- idx_middle = np.random.choice(ranges[1])
- idx_back = np.random.choice(ranges[2])
-
- mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
- mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
- mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
-
- mel = torch.tensor(mel[None, None, :])
- mel_shrink = torch.nn.functional.interpolate(
- mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
- )
- mel_shrink = mel_shrink[0][0].numpy()
- mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
- return mel_fusion
-
- def _get_input_mel(self, waveform: np.ndarray, max_length, truncation, padding) -> np.ndarray:
- """
- Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments.
- Four different paths are possible:
- - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram
- will be computed on the entire audio. 3 random crops and a downsampled version of the full mel spectrogram
- are then stacked together. They will later be used for `feature_fusion`.
- - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is
- padded based on `padding`.
- - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded
- based on `padding`, and is repeated `4` times.
- - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel
- spectrogram will be computed on a random crop of the waveform.
-
- """
- if waveform.shape[0] > max_length:
- if truncation == "rand_trunc":
- longer = True
- # random crop to max_length (for compatibility) -> this should be handled by self.pad
- overflow = len(waveform) - max_length
- idx = np.random.randint(0, overflow + 1)
- waveform = waveform[idx : idx + max_length]
- input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
- elif truncation == "fusion":
- mel = self._np_extract_fbank_features(waveform, self.mel_filters)
- chunk_frames = max_length // self.hop_length + 1 # the +1 relates to how the spectrogram is computed
- total_frames = mel.shape[0]
- if chunk_frames == total_frames:
- # there is a corner case where the audio length is larger than max_length but smaller than max_length+hop_length.
- # In this case, we just use the whole audio.
- input_mel = np.stack([mel, mel, mel, mel], axis=0)
- longer = False
- else:
- input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames)
- longer = True
- else:
- raise NotImplementedError(f"data_truncating {truncation} not implemented")
-
- else:
- longer = False
- # `repeat` is the only new possible value for `padding`: the audio is repeated before the usual max_length padding is applied
- if waveform.shape[0] < max_length:
- if padding == "repeat":
- n_repeat = int(max_length / len(waveform))
- waveform = np.tile(waveform, n_repeat + 1)[:max_length]
- if padding == "repeatpad":
- n_repeat = int(max_length / len(waveform))
- waveform = np.tile(waveform, n_repeat)
- waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0)
-
- if truncation == "fusion":
- input_mel = self._np_extract_fbank_features(waveform, self.mel_filters)
- input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0)
- else:
- input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
-
- return input_mel, longer
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- truncation: str | None = None,
- padding: str | None = None,
- max_length: int | None = None,
- sampling_rate: int | None = None,
- return_tensors: str | TensorType | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- truncation (`str`, *optional*):
- Truncation pattern for long audio inputs. Two patterns are available:
- - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
- a downsampled version of the entire mel spectrogram.
- If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a
- copy of the original mel obtained from the padded audio.
- - `rand_trunc` will select a random crop of the mel spectrogram.
- padding (`str`, *optional*):
- Padding pattern for shorter audio inputs. Three patterns were originally implemented:
- - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
- - `repeat`: the audio is repeated and then cut to fit the `max_length`
- - `pad`: the audio is padded.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors and to support the automatic speech
- recognition pipeline.
- """
- truncation = truncation if truncation is not None else self.truncation
- padding = padding if padding else self.padding
-
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
- f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
- f" was sampled with {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- if is_batched_numpy and len(raw_speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
- is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_speech = [np.asarray(speech, dtype=np.float64) for speech in raw_speech]
- elif not is_batched and not isinstance(raw_speech, np.ndarray):
- raw_speech = np.asarray(raw_speech, dtype=np.float64)
- elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
- raw_speech = raw_speech.astype(np.float64)
-
- # always return batch
- if not is_batched:
- raw_speech = [np.asarray(raw_speech)]
-
- # convert to mel spectrogram, truncate and pad if needed.
- padded_inputs = [
- self._get_input_mel(waveform, max_length if max_length else self.nb_max_samples, truncation, padding)
- for waveform in raw_speech
- ]
-
- input_mel = []
- is_longer = []
- for mel, longer in padded_inputs:
- input_mel.append(mel)
- is_longer.append(longer)
-
- if truncation == "fusion" and sum(is_longer) == 0:
- # if no audio is longer than 10s, then randomly select one audio to be longer
- rand_idx = np.random.randint(0, len(input_mel))
- is_longer[rand_idx] = True
-
- if isinstance(input_mel[0], list):
- input_mel = [np.asarray(feature, dtype=np.float64) for feature in input_mel]
-
- # is_longer is a list of bool
- is_longer = [[longer] for longer in is_longer]
-
- input_features = {"input_features": input_mel, "is_longer": is_longer}
- input_features = BatchFeature(input_features)
-
- if return_tensors is not None:
- input_features = input_features.convert_to_tensors(return_tensors)
-
- return input_features
+ClapFeatureExtractor = deprecated_feature_extractor(ClapAudioProcessor, "ClapFeatureExtractor")
__all__ = ["ClapFeatureExtractor"]
diff --git a/src/transformers/models/clvp/audio_processing_clvp.py b/src/transformers/models/clvp/audio_processing_clvp.py
new file mode 100644
index 000000000000..6272c795d1fd
--- /dev/null
+++ b/src/transformers/models/clvp/audio_processing_clvp.py
@@ -0,0 +1,61 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...audio_processing_backends import NumpyAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class ClvpAudioProcessor(NumpyAudioBackend):
+ sample_rate = 22050
+ force_mono = True
+ max_length = 132300 # 6 seconds at 22050 Hz
+ truncation = True
+ mask_level = "audio"
+
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=1024,
+ hop_length=256,
+ window_fn="hann_window",
+ power=2.0,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=80,
+ f_min=0.0,
+ f_max=8000.0,
+ norm="slaney",
+ mel_scale="htk",
+ frequency_bin_mode="linspace",
+ ),
+ log_mode="log",
+ mel_floor=1e-5,
+ computation_dtype="float64",
+ )
+
+ def __init__(self, mel_norms=None, **kwargs):
+ super().__init__(**kwargs)
+ self.mel_norms = mel_norms
+
+ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs):
+ # Compute log and mel_norms division in float64 before casting to float32
+ # to match the legacy feature extractor's precision
+ mel_floor = spectrogram_config.mel_floor
+ features = np.log(np.maximum(mel_floor, features))
+ if self.mel_norms is not None:
+ features = features / np.array(self.mel_norms)[:, None]
+ return features.astype(np.float32)
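+
+ # Minimal usage sketch (assuming the standard call convention and that the
+ # backend keeps the legacy `input_features` key; `wav` is a 1-D numpy array
+ # sampled at 22050 Hz):
+ # processor = ClvpAudioProcessor()
+ # feats = processor(wav, sampling_rate=22050, return_tensors="np")
+ # feats["input_features"].shape # expected (1, 80, frames)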
+
+
+__all__ = ["ClvpAudioProcessor"]
diff --git a/src/transformers/models/clvp/feature_extraction_clvp.py b/src/transformers/models/clvp/feature_extraction_clvp.py
index cc39e6aca677..e5966a9b2f02 100644
--- a/src/transformers/models/clvp/feature_extraction_clvp.py
+++ b/src/transformers/models/clvp/feature_extraction_clvp.py
@@ -11,227 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_clvp import ClvpAudioProcessor
-"""
-Feature extractor class for CLVP
-"""
-
-import numpy as np
-
-from ...audio_utils import mel_filter_bank, spectrogram, window_function
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class ClvpFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a CLVP feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- This class extracts log-mel-spectrogram features from raw speech using a custom numpy implementation of the `Short
- Time Fourier Transform` which should match pytorch's `torch.stft` equivalent.
-
- Args:
- feature_size (`int`, *optional*, defaults to 80):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 22050):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- default_audio_length (`int`, *optional*, defaults to 6):
- The default length of raw audio in seconds. If `max_length` is not set during `__call__` then it will
- automatically be set to default_audio_length * `self.sampling_rate`.
- hop_length (`int`, *optional*, defaults to 256):
- Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
- chunk_length (`int`, *optional*, defaults to 30):
- The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio
- sequences.
- n_fft (`int`, *optional*, defaults to 1024):
- Size of the Fourier transform.
- padding_value (`float`, *optional*, defaults to 0.0):
- Padding value used to pad the audio. Should correspond to silences.
- mel_norms (`list` of length `feature_size`, *optional*):
- If `mel_norms` is provided then it will be used to normalize the log-mel spectrograms along each
- mel-filter.
- return_attention_mask (`bool`, *optional*, defaults to `False`):
- Whether to return the attention mask. If left to the default, it will return the attention mask.
-
- [What are attention masks?](../glossary#attention-mask)
- """
-
- model_input_names = ["input_features", "attention_mask"]
-
- def __init__(
- self,
- feature_size=80,
- sampling_rate=22050,
- default_audio_length=6,
- hop_length=256,
- chunk_length=30,
- n_fft=1024,
- padding_value=0.0,
- mel_norms=None,
- return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
- **kwargs,
- ):
- super().__init__(
- feature_size=feature_size,
- sampling_rate=sampling_rate,
- padding_value=padding_value,
- return_attention_mask=return_attention_mask,
- **kwargs,
- )
- self.n_fft = n_fft
- self.hop_length = hop_length
- self.chunk_length = chunk_length
- self.n_samples = chunk_length * sampling_rate
- self.nb_max_frames = self.n_samples // hop_length
- self.sampling_rate = sampling_rate
- self.default_audio_length = default_audio_length
- self.mel_norms = mel_norms
- self.mel_filters = mel_filter_bank(
- num_frequency_bins=1 + (n_fft // 2),
- num_mel_filters=feature_size,
- min_frequency=0.0,
- max_frequency=8000.0,
- sampling_rate=sampling_rate,
- norm="slaney",
- mel_scale="htk",
- )
-
- def _np_extract_fbank_features(self, waveform: np.ndarray) -> np.ndarray:
- """
- This method first computes the log-mel spectrogram of the provided audio, then applies normalization along
- each mel filter bank if `mel_norms` is provided.
- """
- log_spec = spectrogram(
- waveform,
- window_function(self.n_fft, "hann"),
- frame_length=self.n_fft,
- hop_length=self.hop_length,
- power=2.0,
- mel_filters=self.mel_filters,
- log_mel=None,
- )
-
- log_spec = np.log(np.clip(log_spec, a_min=1e-5, a_max=None))
-
- if self.mel_norms is not None:
- log_spec = log_spec / np.array(self.mel_norms)[:, None]
-
- return log_spec
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- sampling_rate: int | None = None,
- truncation: bool = True,
- pad_to_multiple_of: int | None = None,
- return_tensors: str | TensorType | None = None,
- return_attention_mask: bool | None = True,
- padding: str | None = "max_length",
- max_length: int | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- `ClvpFeatureExtractor` is used to extract various voice-specific properties, such as the pitch and tone of
- the voice, speaking speed, and even speaking defects like a lisp or stuttering, from a voice sample (`raw_speech`).
-
- First the voice is padded or truncated in a way such that it becomes a waveform of `self.default_audio_length`
- seconds long and then the log-mel spectrogram is extracted from it.
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors and to support the automatic speech
- recognition pipeline.
- truncation (`bool`, *optional*, defaults to `True`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_attention_mask (`bool`, *optional*, defaults to `True`):
- Whether to return the attention mask. If left to the default, it will return the attention mask.
-
- [What are attention masks?](../glossary#attention-mask)
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values / vectors.
- max_length (`int`, *optional*):
- The maximum input length of the inputs.
- """
-
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
- f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
- f" was sampled with {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- if is_batched_numpy and len(raw_speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
- is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
- elif not is_batched and not isinstance(raw_speech, np.ndarray):
- raw_speech = np.asarray(raw_speech, dtype=np.float32)
- elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
- raw_speech = raw_speech.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_speech = [np.asarray([raw_speech]).T]
-
- batched_speech = BatchFeature({"input_features": raw_speech})
-
- max_length = self.default_audio_length * self.sampling_rate if max_length is None else max_length
-
- padded_inputs = self.pad(
- batched_speech,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- )
-
- # make sure list is in array format
- input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
-
- input_features = [
- self._np_extract_fbank_features(waveform).astype(np.float32) for waveform in input_features[0]
- ]
-
- if isinstance(input_features[0], list):
- padded_inputs["input_features"] = [np.asarray(feature) for feature in input_features]
- else:
- padded_inputs["input_features"] = input_features
-
- return padded_inputs.convert_to_tensors(return_tensors)
+ClvpFeatureExtractor = deprecated_feature_extractor(ClvpAudioProcessor, "ClvpFeatureExtractor")
__all__ = ["ClvpFeatureExtractor"]
diff --git a/src/transformers/models/dac/audio_processing_dac.py b/src/transformers/models/dac/audio_processing_dac.py
new file mode 100644
index 000000000000..f0a27bd57555
--- /dev/null
+++ b/src/transformers/models/dac/audio_processing_dac.py
@@ -0,0 +1,24 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...audio_processing_backends import NumpyAudioBackend
+
+
+class DacAudioProcessor(NumpyAudioBackend):
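+    r"""
+    Audio processor for DAC: mono 16 kHz waveforms, with an explicit channel dimension added to the output.
+    """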
+ sample_rate = 16000
+ force_mono = True
+ add_channel_dim = True
+
+
+__all__ = ["DacAudioProcessor"]
diff --git a/src/transformers/models/dac/feature_extraction_dac.py b/src/transformers/models/dac/feature_extraction_dac.py
index 7f910f57f09f..f255d22ebba5 100644
--- a/src/transformers/models/dac/feature_extraction_dac.py
+++ b/src/transformers/models/dac/feature_extraction_dac.py
@@ -11,160 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Feature extractor class for DAC"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_dac import DacAudioProcessor
-import numpy as np
-
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class DacFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs an Dac feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used for padding.
- hop_length (`int`, *optional*, defaults to 512):
- Overlap length between successive windows.
- """
-
- model_input_names = ["input_values", "n_quantizers"]
-
- def __init__(
- self,
- feature_size: int = 1,
- sampling_rate: int = 16000,
- padding_value: float = 0.0,
- hop_length: int = 512,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.hop_length = hop_length
-
- def __call__(
- self,
- raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- padding: bool | str | PaddingStrategy | None = None,
- truncation: bool | None = False,
- max_length: int | None = None,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
- `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
- (`feature_size = 2`).
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- truncation (`bool`, *optional*, defaults to `False`):
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- if padding and truncation:
- raise ValueError("Both padding and truncation were set. Make sure you only set one.")
- elif padding is None:
- # by default let's pad the inputs
- padding = True
-
- is_batched = bool(
- isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
- elif not is_batched and not isinstance(raw_audio, np.ndarray):
- raw_audio = np.asarray(raw_audio, dtype=np.float32)
- elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
- raw_audio = raw_audio.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_audio = [np.asarray(raw_audio).T]
-
- # verify inputs are valid
- for idx, example in enumerate(raw_audio):
- if example.ndim > 2:
- raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
- if self.feature_size == 1 and example.ndim != 1:
- raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
- if self.feature_size == 2:
- raise ValueError("Stereo audio isn't supported for now")
-
- input_values = BatchFeature({"input_values": raw_audio})
-
- # normal padding on batch
- padded_inputs = self.pad(
- input_values,
- max_length=max_length,
- truncation=truncation,
- padding=padding,
- return_attention_mask=padding,
- pad_to_multiple_of=self.hop_length,
- )
- if padding:
- padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
- if padding:
- padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :]
-
- input_values = []
- for example in padded_inputs.pop("input_values"):
- if self.feature_size == 1:
- example = example[..., None]
- input_values.append(example.T)
-
- padded_inputs["input_values"] = input_values
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
+DacFeatureExtractor = deprecated_feature_extractor(DacAudioProcessor, "DacFeatureExtractor")
__all__ = ["DacFeatureExtractor"]
diff --git a/src/transformers/models/dia/audio_processing_dia.py b/src/transformers/models/dia/audio_processing_dia.py
new file mode 100644
index 000000000000..e1b7b0301e71
--- /dev/null
+++ b/src/transformers/models/dia/audio_processing_dia.py
@@ -0,0 +1,25 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...audio_processing_backends import NumpyAudioBackend
+
+
+class DiaAudioProcessor(NumpyAudioBackend):
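+    r"""
+    Audio processor for Dia: mono 44.1 kHz waveforms with an added channel dimension, padded to a
+    multiple of 512 samples.
+    """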
+ sample_rate = 44100
+ force_mono = True
+ add_channel_dim = True
+ pad_to_multiple_of = 512
+
+
+__all__ = ["DiaAudioProcessor"]
diff --git a/src/transformers/models/dia/feature_extraction_dia.py b/src/transformers/models/dia/feature_extraction_dia.py
index eda1ead6e014..d358589b4282 100644
--- a/src/transformers/models/dia/feature_extraction_dia.py
+++ b/src/transformers/models/dia/feature_extraction_dia.py
@@ -11,169 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Feature extractor class for Dia"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_dia import DiaAudioProcessor
-import numpy as np
-
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class DiaFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs an Dia feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used for padding.
- hop_length (`int`, *optional*, defaults to 512):
- Overlap length between successive windows.
- """
-
- model_input_names = ["input_values", "n_quantizers"]
-
- def __init__(
- self,
- feature_size: int = 1,
- sampling_rate: int = 16000,
- padding_value: float = 0.0,
- hop_length: int = 512,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.hop_length = hop_length
-
- def __call__(
- self,
- raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- padding: bool | str | PaddingStrategy | None = None,
- truncation: bool | None = False,
- max_length: int | None = None,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
- `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
- (`feature_size = 2`).
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- truncation (`bool`, *optional*, defaults to `False`):
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- if padding and truncation:
- raise ValueError("Both padding and truncation were set. Make sure you only set one.")
- elif padding is None:
- # by default let's pad the inputs
- padding = True
-
- is_batched = bool(
- isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
- elif not is_batched and not isinstance(raw_audio, np.ndarray):
- raw_audio = np.asarray(raw_audio, dtype=np.float32)
- elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
- raw_audio = raw_audio.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_audio = [np.asarray(raw_audio).T]
-
- # convert stereo to mono if necessary, unique to Dia
- for idx, example in enumerate(raw_audio):
- if self.feature_size == 2 and example.ndim == 2:
- raw_audio[idx] = np.mean(example, -1)
-
- # verify inputs are valid
- for idx, example in enumerate(raw_audio):
- if example.ndim > 2:
- raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
- if self.feature_size == 1 and example.ndim != 1:
- raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
- if self.feature_size == 2 and example.ndim != 1: # note the conversion before
- raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
-
- input_values = BatchFeature({"input_values": raw_audio})
-
- # temporarily treat it as if we were mono as we also convert stereo to mono
- original_feature_size = self.feature_size
- self.feature_size = 1
-
- # normal padding on batch
- padded_inputs = self.pad(
- input_values,
- max_length=max_length,
- truncation=truncation,
- padding=padding,
- return_attention_mask=True,
- pad_to_multiple_of=self.hop_length,
- )
- padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
-
- input_values = []
- for example in padded_inputs.pop("input_values"):
- if self.feature_size == 1:
- example = example[..., None]
- input_values.append(example.T)
-
- padded_inputs["input_values"] = input_values
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- # rewrite back to original feature size
- self.feature_size = original_feature_size
-
- return padded_inputs
+DiaFeatureExtractor = deprecated_feature_extractor(DiaAudioProcessor, "DiaFeatureExtractor")
__all__ = ["DiaFeatureExtractor"]
diff --git a/src/transformers/models/encodec/audio_processing_encodec.py b/src/transformers/models/encodec/audio_processing_encodec.py
new file mode 100644
index 000000000000..022a7e145313
--- /dev/null
+++ b/src/transformers/models/encodec/audio_processing_encodec.py
@@ -0,0 +1,24 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...audio_processing_backends import NumpyAudioBackend
+
+
+class EncodecAudioProcessor(NumpyAudioBackend):
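+    r"""
+    Audio processor for EnCodec. Handles mono 24 kHz waveforms and adds a channel dimension.
+    """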
+ sample_rate = 24000
+ force_mono = True
+ add_channel_dim = True
+
+
+__all__ = ["EncodecAudioProcessor"]
diff --git a/src/transformers/models/encodec/feature_extraction_encodec.py b/src/transformers/models/encodec/feature_extraction_encodec.py
index 383936000243..2f1644ac912a 100644
--- a/src/transformers/models/encodec/feature_extraction_encodec.py
+++ b/src/transformers/models/encodec/feature_extraction_encodec.py
@@ -11,195 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Feature extractor class for EnCodec."""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_encodec import EncodecAudioProcessor
-import numpy as np
-
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class EncodecFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs an EnCodec feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Instantiating a feature extractor with the defaults will yield a similar configuration to that of the
- [facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
- sampling_rate (`int`, *optional*, defaults to 24000):
- The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values.
- chunk_length_s (`float`, *optional*):
- If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
- overlap (`float`, *optional*):
- Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
- formulae : `int((1.0 - self.overlap) * self.chunk_length)`.
- """
-
- model_input_names = ["input_values", "padding_mask"]
-
- def __init__(
- self,
- feature_size: int = 1,
- sampling_rate: int = 24000,
- padding_value: float = 0.0,
- chunk_length_s: float | None = None,
- overlap: float | None = None,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.chunk_length_s = chunk_length_s
- self.overlap = overlap
-
- # This is a property because you might want to change the chunk_length_s on the fly
- @property
- def chunk_length(self) -> int | None:
- if self.chunk_length_s is None:
- return None
- else:
- return int(self.chunk_length_s * self.sampling_rate)
-
- # This is a property because you might want to change the chunk_length_s on the fly
- @property
- def chunk_stride(self) -> int | None:
- if self.chunk_length_s is None or self.overlap is None:
- return None
- else:
- return max(1, int((1.0 - self.overlap) * self.chunk_length))
-
- def __call__(
- self,
- raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- padding: bool | str | PaddingStrategy | None = None,
- truncation: bool | None = False,
- max_length: int | None = None,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
- `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
- (`feature_size = 2`).
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- truncation (`bool`, *optional*, defaults to `False`):
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- if padding and truncation:
- raise ValueError("Both padding and truncation were set. Make sure you only set one.")
- elif padding is None:
- # by default let's pad the inputs
- padding = True
-
- is_batched = bool(
- isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
- elif not is_batched and not isinstance(raw_audio, np.ndarray):
- raw_audio = np.asarray(raw_audio, dtype=np.float32)
- elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
- raw_audio = raw_audio.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_audio = [np.asarray(raw_audio).T]
-
- # verify inputs are valid
- for idx, example in enumerate(raw_audio):
- if example.ndim > 2:
- raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
- if self.feature_size == 1 and example.ndim != 1:
- raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
- if self.feature_size == 2 and example.shape[-1] != 2:
- raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
-
- padded_inputs = None
- input_values = BatchFeature({"input_values": raw_audio})
- if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
- if truncation:
- max_length = min(array.shape[0] for array in raw_audio)
- nb_step = int(np.floor(max_length / self.chunk_stride))
- max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
- elif padding:
- max_length = max(array.shape[0] for array in raw_audio)
- nb_step = int(np.ceil(max_length / self.chunk_stride))
- max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
- padding = "max_length"
- else:
- padded_inputs = input_values
-
- # normal padding on batch
- if padded_inputs is None:
- padded_inputs = self.pad(
- input_values,
- max_length=max_length,
- truncation=truncation,
- padding=padding,
- return_attention_mask=padding,
- )
- if padding:
- padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
-
- input_values = []
- for example in padded_inputs.pop("input_values"):
- if self.feature_size == 1:
- example = example[..., None]
- input_values.append(example.T)
-
- padded_inputs["input_values"] = input_values
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
+EncodecFeatureExtractor = deprecated_feature_extractor(EncodecAudioProcessor, "EncodecFeatureExtractor")
__all__ = ["EncodecFeatureExtractor"]
diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py
new file mode 100644
index 000000000000..23f63b8bdb19
--- /dev/null
+++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py
@@ -0,0 +1,139 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...audio_processing_backends import NumpyAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+def _unfold(array, dimension, size, step):
+    """NumPy equivalent of PyTorch's unfold for 1D or 2D arrays along the last dim.
+
+    `dimension` is kept for signature parity with torch and is assumed to be the last axis.
+    """
+ if array.ndim == 1:
+ array = array[np.newaxis, :]
+ batch_size, original_length = array.shape
+ num_frames = (original_length - size) // step + 1
+ if num_frames <= 0:
+ return np.zeros((batch_size, 0, size), dtype=array.dtype)
+ output_shape = (batch_size, num_frames, size)
+ output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])
+ return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
+
+
+class Gemma3nAudioProcessor(NumpyAudioBackend):
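+    r"""
+    Audio processor for Gemma3n, computing USM-style log-mel features: 32 ms frames with a 10 ms hop
+    (win_length 512, hop_length 160 at 16 kHz), HTK-flavored preemphasis, a 1024-point FFT and 128 mel
+    bins, plus optional per-bin mean/stddev normalization. Inputs are truncated to 30 s and padded to a
+    multiple of 128 samples.
+    """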
+ sample_rate = 16000
+ force_mono = True
+ max_length = 480000 # 30 seconds
+ truncation = True
+ pad_to_multiple_of = 128
+ preemphasis_htk_flavor = True
+
+    # n_fft = 1024: frame_length 512 → next power of 2 (512), then ×2 for fft_overdrive
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=1024,
+ win_length=512,
+ hop_length=160,
+ power=1.0,
+ center=False,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=128,
+ f_min=125.0,
+ f_max=7600.0,
+ mel_scale="htk",
+ matmul_order="features_first",
+ ),
+ mel_floor=1e-5,
+ log_mode="log",
+ preemphasis=0.97,
+ computation_dtype="float64",
+ )
+
+ def __init__(self, per_bin_mean=None, per_bin_stddev=None, **kwargs):
+ super().__init__(**kwargs)
+
+ # Pre-compute window in float32 to match the upstream FE exactly
+ win_length = self.spectrogram_config.stft_config.win_length
+ hann_arange = np.arange(win_length, dtype=np.float32)
+ self.window = (0.5 * (1 - np.cos(2 * np.pi * hann_arange / win_length))).astype(np.float32)
+
+ n_mels = self.spectrogram_config.mel_scale_config.n_mels
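+        # Optional per-bin normalization statistics, kept as (1, n_mels) to broadcast over time frames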
+ if per_bin_mean is not None:
+ self.per_bin_mean = np.array(per_bin_mean).reshape(1, n_mels)
+ else:
+ self.per_bin_mean = None
+
+ if per_bin_stddev is not None:
+ self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, n_mels)
+ else:
+ self.per_bin_stddev = None
+
+ def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs):
+ """HTK-style preemphasis on frames extracted with an extra sample."""
+ preemphasis = spectrogram_config.preemphasis
+ if preemphasis is not None and preemphasis > 0.0:
+ if self.preemphasis_htk_flavor:
+ first = frames[..., :1] * (1.0 - preemphasis)
+ rest = frames[..., 1:-1] - preemphasis * frames[..., :-2]
+ return np.concatenate([first, rest], axis=-1)
+ else:
+ return frames[..., 1:] - preemphasis * frames[..., :-1]
+ return frames[..., :-1]
+
+ def _stft(self, audio, *, spectrogram_config, **kwargs):
+ """Unfold-based STFT with extra-sample framing for HTK preemphasis.
+
+ Extracts frames of win_length+1 so that _apply_frame_processing can
+ reduce them to win_length after HTK preemphasis. Returns (batch, time, freq).
+ """
+ stft_cfg = spectrogram_config.stft_config
+
+ frame_size_for_unfold = stft_cfg.win_length + 1
+ frames = _unfold(audio, dimension=-1, size=frame_size_for_unfold, step=stft_cfg.hop_length)
+
+ frames = self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs)
+
+ frames = frames * self.window
+ stft = np.fft.rfft(frames, n=stft_cfg.n_fft, axis=-1)
+ return np.abs(stft)
+
+ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs):
+ """Apply log compression and per-bin normalization."""
+ result = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs)
+
+ if self.per_bin_mean is not None:
+ result = result - self.per_bin_mean
+ if self.per_bin_stddev is not None:
+ result = result / self.per_bin_stddev
+
+ return result.astype(np.float32)
+
+ def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False):
+ """Frame count matching the FE's downsampled attention mask approach.
+
+ The upstream FE computes the mask by slicing the sample-level attention
+ mask every hop_length steps, which yields ceil(audio_length / hop_length)
+ valid frames rather than the unfold-based count.
+ """
+ hop_length = spectrogram_config.stft_config.hop_length
+ if include_center_frame:
+ # For padded length we still use the unfold formula to get total frames
+ frame_size = spectrogram_config.stft_config.win_length + 1
+ return (audio_lengths - frame_size) // hop_length + 1
+ # Match FE: attention_mask[::hop_length] gives this many valid entries
+ return (audio_lengths + hop_length - 1) // hop_length
+
+
+__all__ = ["Gemma3nAudioProcessor"]
diff --git a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py
index e2b24fb1f19f..1b111b76b49d 100644
--- a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py
+++ b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py
@@ -11,323 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_gemma3n import Gemma3nAudioProcessor
-import math
-from collections.abc import Sequence
-
-import numpy as np
-
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-def create_fb_matrix(
- n_freqs: int,
- f_min: float,
- f_max: float,
- n_mels: int,
- sample_rate: int,
- fft_length: int,
- norm: str | None = None,
-) -> np.ndarray:
- r"""Create a frequency bin conversion matrix (NumPy version).
-
- Args:
- n_freqs (int): Number of frequencies to highlight/apply
- f_min (float): Minimum frequency (Hz)
- f_max (float): Maximum frequency (Hz)
- n_mels (int): Number of mel filterbanks
- sample_rate (int): Sample rate of the audio waveform
- fft_length (int): FFT length
- norm (Optional[str]): If 'slaney', divide the triangular mel weights by
- the width of the mel band (area normalization). (Default: ``None``)
-
- Returns:
- np.ndarray: Triangular filter banks (fb matrix) of size (``n_freqs``,
- ``n_mels``)
- meaning number of frequencies to highlight/apply to x the number of
- filterbanks.
- Each column is a filterbank so that assuming there is a matrix A of
- size (..., ``n_freqs``), the applied result would be
- ``A @ create_fb_matrix_numpy(A.shape[-1], ...)``.
- """
-
- if norm is not None and norm != "slaney":
- raise ValueError("norm must be one of None or 'slaney'")
-
- # freq bins
- all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length)
-
- # calculate mel freq bins
- # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
- m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0))
- m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
- m_pts = np.linspace(m_min, m_max, n_mels + 2)
- # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
- f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
- # calculate difference between each mel point and each stft freq point in Hz
- f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
- slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) # (n_freqs, n_mels + 2)
- # create overlapping triangles
- zero = np.zeros(1, dtype=np.float32)
- down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
- up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
- fb = np.maximum(zero, np.minimum(down_slopes, up_slopes))
-
- if norm is not None and norm == "slaney":
- # Slaney-style mel is scaled to be approx constant energy per channel
- enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
- fb *= np.expand_dims(enorm, 0)
-
- return fb
-
-
-def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray:
- """A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim."""
- if array.ndim != 2:
- raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).")
- if dimension != -1 and dimension != array.ndim - 1:
- raise ValueError("This unfold implementation only supports unfolding the last dimension.")
-
- batch_size, original_length = array.shape
- num_frames = (original_length - size) // step + 1
-
- if num_frames <= 0:
- return np.zeros((batch_size, 0, size), dtype=array.dtype)
-
- output_shape = (batch_size, num_frames, size)
- output_strides = (array.strides[0], array.strides[1] * step, array.strides[1])
-
- return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides)
-
-
-class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor):
- """An audio feature extractor Universal Speech Models https://huggingface.co/papers/2303.01037.
-
- Args:
- feature_size (`int`, *optional*, defaults to 128):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- Padding value used to pad the audio. Should correspond to silences.
- return_attention_mask (`bool`, *optional*, defaults to `True`):
- Whether to return the attention mask for the generated MEL spectrograms.
- frame_length_ms (`float`, *optional*, defaults to 32.0):
- The length of a frame in milliseconds.
- hop_length_ms (`float`, *optional*, defaults to 10.0):
- Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
- min_frequency (`float`, *optional*, defaults to 125.0):
- The minimum frequency (in Hz) for the Mel filterbank.
- max_frequency (`float`, *optional*, defaults to 7600.0):
- The maximum frequency (in Hz) for the Mel filterbank.
- preemphasis (`float`, *optional*, defaults to 0.97):
- The preemphasis coefficient.
- preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`):
- Whether to use HTK-style preemphasis.
- fft_overdrive (`bool`, *optional*, defaults to `True`):
- Whether to use FFT overdrive.
- dither (`float`, *optional*, defaults to 0.0):
- Adds dithering. In other words, adds a small Gaussian noise to each frame.
- E.g. use 0.0001 to add dithering with a normal distribution centered
- around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
- The value 0.0 means no dithering.
- Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
- the high log_mel_fbank values for signals with hard-zero sections,
- when VAD cutoff is present in the signal.
- input_scale_factor (`float`, *optional*, defaults to 1.0):
- Scaling factor applied to the input waveform.
- mel_floor (`float`, *optional*, defaults to 1e-05):
- Minimum value for Mel spectrograms to avoid log(0).
- per_bin_mean (`Optional[Sequence[float]]`, *optional*):
- Mean values for per-bin normalization.
- per_bin_stddev (`Optional[Sequence[float]]`, *optional*):
- Standard deviation values for per-bin normalization.
- """
-
- model_input_names = ["input_features", "input_features_mask"]
-
- def __init__(
- self,
- feature_size: int = 128,
- sampling_rate: int = 16_000,
- padding_value: float = 0.0,
- return_attention_mask: bool = True,
- frame_length_ms: float = 32.0,
- hop_length_ms: float = 10.0,
- min_frequency: float = 125.0,
- max_frequency: float = 7600.0,
- preemphasis: float = 0.97,
- preemphasis_htk_flavor: bool = True,
- fft_overdrive: bool = True,
- dither: float = 0.0,
- input_scale_factor: float = 1.0,
- mel_floor: float = 1e-5,
- per_bin_mean: Sequence[float] | None = None,
- per_bin_stddev: Sequence[float] | None = None,
- **kwargs,
- ):
- super().__init__(
- feature_size=feature_size,
- sampling_rate=sampling_rate,
- padding_value=padding_value,
- return_attention_mask=return_attention_mask,
- **kwargs,
- )
-
- self.min_frequency = min_frequency
- self.max_frequency = max_frequency
- self.preemphasis = preemphasis
- self.preemphasis_htk_flavor = preemphasis_htk_flavor
- self.fft_overdrive = fft_overdrive
- self.dither = dither
- self.input_scale_factor = input_scale_factor
- self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0))
- self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0))
- self.mel_floor = np.array(mel_floor, dtype=np.float64)
-
- fft_length = 2 ** math.ceil(math.log2(self.frame_length))
- if self.fft_overdrive:
- fft_length *= 2
- self.fft_length = fft_length
-
- hann_arange = np.arange(self.frame_length, dtype=np.float32)
- window = 0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length))
- self.window = window.astype(np.float32)
-
- self.mel_filters = create_fb_matrix(
- n_freqs=self.fft_length // 2 + 1,
- f_min=min_frequency,
- f_max=max_frequency,
- n_mels=feature_size,
- sample_rate=self.sampling_rate,
- norm=None,
- fft_length=fft_length,
- )
-
- if per_bin_mean is not None:
- self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size)
- else:
- self.per_bin_mean = None
-
- if per_bin_stddev is not None:
- self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size)
- else:
- self.per_bin_stddev = None
-
- def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
- """"""
- if waveform.ndim == 1: # If single waveform, add batch dimension
- waveform = np.expand_dims(waveform, axis=0)
-
- if self.dither > 0.0:
- waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype)
-
- if self.input_scale_factor != 1.0:
- waveform = waveform * self.input_scale_factor
-
- frame_size_for_unfold = self.frame_length + 1
-
- # NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold]
- frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length)
-
- if self.preemphasis > 0.0:
- if self.preemphasis_htk_flavor:
- first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis)
- rest_in_frame = frames_to_process[..., 1:-1] - self.preemphasis * frames_to_process[..., :-2]
- frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
- else:
- frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1]
- else:
- frames = frames_to_process[..., :-1]
-
- frames = frames * self.window # Broadcasting window
- stft = np.fft.rfft(frames, n=self.fft_length, axis=-1)
-
- magnitude_spec = np.abs(stft)
-
- mel_spec = np.matmul(magnitude_spec, self.mel_filters)
- log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor))
-
- if self.per_bin_mean is not None:
- log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting
-
- if self.per_bin_stddev is not None:
- log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting
-
- mel_spectrogram = log_mel_spec.squeeze(0)
- mask = attention_mask[:: self.hop_length].astype(bool)
- # TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why???
- return mel_spectrogram, mask[: mel_spectrogram.shape[0]]
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- padding: bool | str | PaddingStrategy = "longest",
- max_length: int | None = 480_000,
- truncation: bool = True,
- pad_to_multiple_of: int | None = 128,
- return_tensors: str | TensorType | None = None,
- return_attention_mask: bool | None = True,
- **kwargs,
- ) -> BatchFeature:
- """Creates a batch of MEL spectrograms from the provided raw speech.
-
- This implementation uses a different algorithm for windowing and preemphasis compared to the built-in
- `transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this
- carefully when selecting an audio feature extractor, especially with pre-trained models.
-
- Args:
- raw_speech:
- The audio for which MEL spectrograms are created.
- padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`):
- The padding strategy to use for batches of audio with different lengths.
- max_length (`int`, *optional*, defaults to 480000):
- If provided, defines the maximum length of the audio to allow. Audio longer than this will be
- truncated if `truncation=True`.
- truncation (`bool`, *optional*, defaults to `True`):
- Whether or not to truncate audio above `max_length`.
- pad_to_multiple_of (`int`, *optional*, defaults to 128):
- When padding, pad to a multiple of this value. The default value is defined for optimal TPU support.
- return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`):
- The type of tensors to return (e.g., NumPy, or Torch).
- return_attention_mask (`bool`, *optional*, defaults to `True`):
- Whether to return the attention mask for the generated MEL spectrograms.
- """
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence))
- is_batched = is_batched_numpy or is_batched_sequence
-
- # Always return a batch
- if not is_batched:
- raw_speech = [raw_speech]
- raw_speech = [np.asarray([rs]).T for rs in raw_speech]
-
- batched_speech = self.pad(
- BatchFeature({"input_features": raw_speech}),
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- )
-
- prepared_speech = []
- prepared_speech_mask = []
- for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask):
- speech, mask = self._extract_spectrogram(speech.T, mask)
- prepared_speech.append(speech.astype(np.float32))
- prepared_speech_mask.append(mask)
-
- return BatchFeature(
- {"input_features": prepared_speech, "input_features_mask": prepared_speech_mask},
- tensor_type=return_tensors,
- )
+Gemma3nAudioFeatureExtractor = deprecated_feature_extractor(Gemma3nAudioProcessor, "Gemma3nAudioFeatureExtractor")
__all__ = ["Gemma3nAudioFeatureExtractor"]
diff --git a/src/transformers/models/granite_speech/audio_processing_granite_speech.py b/src/transformers/models/granite_speech/audio_processing_granite_speech.py
new file mode 100644
index 000000000000..98915a5afeb9
--- /dev/null
+++ b/src/transformers/models/granite_speech/audio_processing_granite_speech.py
@@ -0,0 +1,85 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+
+from ...audio_processing_backends import TorchAudioBackend
+from ...utils.import_utils import requires_backends
+
+
+class GraniteSpeechAudioProcessor(TorchAudioBackend):
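+    r"""
+    Audio processor for Granite Speech. Computes 80-bin log-mel spectrograms with torchaudio, stacks
+    adjacent frame pairs, and derives per-example projector output lengths to build `audio_features_mask`.
+    """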
+ sample_rate = 16000
+ force_mono = True
+ return_padding_mask = False
+ do_extract_spectrogram = True
+ projector_window_size = 15
+ projector_downsample_rate = 5
+ n_fft = 512
+ win_length = 400
+ hop_length = 160
+ n_mels = 80
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # Fail early with a clear error message when torchaudio is missing, as the legacy extractor did
+        requires_backends(self, ["torchaudio"])
+        import torchaudio
+
+ self.mel_filters_transform = torchaudio.transforms.MelSpectrogram(
+ sample_rate=self.sample_rate,
+ n_fft=self.n_fft,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ )
+
+ def extract_spectrogram(self, audio, **kwargs):
+ # Use torchaudio MelSpectrogram to match upstream FE exactly
+ melspec = self.mel_filters_transform.to(device=audio.device)
+ with torch.no_grad():
+ mel = melspec(audio.float())
+ logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
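+        # Clamp to 8 log10 units (80 dB) below the per-example maximum, then shift and scale the range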
+ mx = logmel.amax(dim=(-2, -1), keepdim=True)
+ logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
+ # Remove last frame if odd
+ if logmel.shape[1] % 2 == 1:
+ logmel = logmel[:, :-1]
+ # Stacking by 2
+ features = logmel.reshape(audio.shape[0], -1, 2 * logmel.shape[-1])
+ return features
+
+ def _postprocess_output(self, output, audio_ranges=None, **kwargs):
+ hop_length = self.hop_length
+
+ # Compute audio_embed_sizes from original audio lengths
+ effective_window_size = self.projector_window_size // self.projector_downsample_rate
+ audio_embed_sizes = []
+ for start, end in audio_ranges:
+ raw_length = end - start
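+            # mel frames -> encoder frames (two mel frames each) -> projector blocks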
+ mel_length = raw_length // hop_length + 1
+ encoder_length = mel_length // 2
+ nblocks = math.ceil(encoder_length / self.projector_window_size)
+ projector_length = nblocks * effective_window_size
+ audio_embed_sizes.append(projector_length)
+
+        # Build the boolean mask over projector output positions (returned by the legacy FE as input_features_mask)
+ input_features_mask = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor(
+ audio_embed_sizes
+ ).view(-1, 1)
+
+ output["audio_embed_sizes"] = audio_embed_sizes
+ output["audio_features_mask"] = input_features_mask
+ return output
+
+
+__all__ = ["GraniteSpeechAudioProcessor"]
diff --git a/src/transformers/models/granite_speech/feature_extraction_granite_speech.py b/src/transformers/models/granite_speech/feature_extraction_granite_speech.py
index cd32d0433bae..15bab8e6466f 100644
--- a/src/transformers/models/granite_speech/feature_extraction_granite_speech.py
+++ b/src/transformers/models/granite_speech/feature_extraction_granite_speech.py
@@ -11,174 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Feature extractor class for Granite Speech."""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_granite_speech import GraniteSpeechAudioProcessor
-import math
-from collections.abc import Sequence
-
-import numpy as np
-
-from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...tokenization_utils_base import AudioInput
-from ...utils import is_torch_available, is_torchaudio_available, logging
-from ...utils.import_utils import requires_backends
-
-
-logger = logging.get_logger(__name__)
-
-if is_torch_available():
- import torch
-
-if is_torchaudio_available():
- import torchaudio
-
-
-class GraniteSpeechFeatureExtractor(FeatureExtractionMixin):
- model_input_names = ["input_features"]
-
- def __init__(
- self,
- sampling_rate: int = 16000,
- n_fft: int = 512,
- win_length: int = 400,
- hop_length: int = 160,
- n_mels: int = 80,
- projector_window_size: int = 15,
- projector_downsample_rate: int = 5,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.sampling_rate = sampling_rate
- self.melspec_kwargs = {
- "sample_rate": sampling_rate,
- "n_fft": n_fft,
- "win_length": win_length,
- "hop_length": hop_length,
- "n_mels": n_mels,
- }
- requires_backends(self, ["torchaudio"])
- self.mel_filters = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
- self.projector_window_size = projector_window_size
- self.projector_downsample_rate = projector_downsample_rate
-
- def __call__(
- self,
- audios: AudioInput,
- device: str | None = "cpu",
- ) -> BatchFeature:
- requires_backends(self, ["torchaudio"])
-
- speech_inputs = {}
- batched_audio, audio_lengths = self._get_audios_and_audio_lengths(audios)
- speech_inputs["input_features"] = self._extract_mel_spectrograms(
- batched_audio,
- device=device,
- )
- audio_embed_sizes = self._get_num_audio_features(audio_lengths)
- speech_inputs["audio_embed_sizes"] = audio_embed_sizes
- # TODO (@alex-jw-brooks): Currently input_features_mask is not
- # a great name, because input_features and input_features_mask
- # have different shapes (before/after the projector).
- #
- # We should align this with other multimodal models, e.g,. llava
- # and qwen2audio and refactor this to ensure input_feature_mask
- # has the same dimensionality as input_features, or compute it in
- # the model based on the audio embedding sizes (since we do not
- # have an attention mask for the audio features to infer padding from).
- speech_inputs["input_features_mask"] = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor(
- audio_embed_sizes
- ).view(-1, 1)
- return BatchFeature(data=speech_inputs)
-
- def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"):
- """
- Compute the Mel features to be passed to the conformer encoder.
- """
- requires_backends(self, ["torchaudio"])
- if device is not None:
- melspec = self.mel_filters.to(device)
- audio = audio.to(device)
- else:
- melspec = self.mel_filters
-
- bsz = audio.shape[0]
- with torch.no_grad():
- # Compute mel features
- mel = melspec(audio.float())
- logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
- mx = logmel.amax(dim=(-2, -1), keepdim=True)
- logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
- # remove last frame if odd
- if logmel.shape[1] % 2 == 1:
- logmel = logmel[:, :-1]
-
- # stacking and skipping by 2
- audio = logmel.reshape(bsz, -1, 2 * logmel.shape[-1])
-
- return audio
-
- def _get_num_audio_features(self, audio_lengths: Sequence[int]) -> Sequence[int]:
- """
- Gets the (variable length) number of features (i.e., projector output) for the sequences
- being considered.
-
- Args:
- audio_lengths (`Sequence[int]`):
- Sequence of one or more raw audio lengths.
- """
- hop_length = self.melspec_kwargs["hop_length"]
- effective_window_size = self.projector_window_size // self.projector_downsample_rate
-
- projector_lengths = []
- for raw_length in audio_lengths:
- # mel sequence length computation
- mel_length = raw_length // hop_length + 1
- # encoder frame takes two mel features
- encoder_length = mel_length // 2
- nblocks = math.ceil(encoder_length / self.projector_window_size)
- # projector output length
- projector_length = nblocks * effective_window_size
- projector_lengths.append(projector_length)
-
- return projector_lengths
-
- def _get_audios_and_audio_lengths(self, audios: AudioInput) -> Sequence["torch.Tensor", Sequence[int]]:
- """
- Coerces audio inputs to torch tensors and extracts audio lengths prior to stacking.
-
- Args:
- audios (`AudioInput`):
- Audio sequence, numpy array, or torch tensor.
- """
- requires_backends(self, ["torch"])
-
- # Coerce to PyTorch tensors if we have numpy arrays, since
- # currently we have a dependency on torch/torchaudio anyway
- if isinstance(audios, np.ndarray):
- audios = torch.from_numpy(audios)
- elif isinstance(audios, Sequence) and isinstance(audios[0], np.ndarray):
- audios = [torch.from_numpy(arr) for arr in audios]
-
- if isinstance(audios, torch.Tensor):
- if audios.ndim == 1:
- audios = audios.unsqueeze(0)
- if not torch.is_floating_point(audios):
- raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1")
-
- if audios.shape[0] > 1:
- logger.warning("Audio samples are already collated; assuming they all have the same length")
- lengths = [audios.shape[-1]] * audios.shape[0]
- return audios, lengths
-
- elif isinstance(audios, Sequence) and isinstance(audios[0], torch.Tensor):
- if not torch.is_floating_point(audios[0]):
- raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1")
- lengths = [audio.shape[-1] for audio in audios]
- audios = [audio.squeeze(0) for audio in audios]
- audios = torch.nn.utils.rnn.pad_sequence(audios, batch_first=True, padding_value=0.0)
- return audios, lengths
-
- raise TypeError("Invalid audio provided. Audio should be a one or more torch tensors or numpy arrays")
+GraniteSpeechFeatureExtractor = deprecated_feature_extractor(
+ GraniteSpeechAudioProcessor, "GraniteSpeechFeatureExtractor"
+)
__all__ = ["GraniteSpeechFeatureExtractor"]
diff --git a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py
new file mode 100644
index 000000000000..a07b213a2c9d
--- /dev/null
+++ b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py
@@ -0,0 +1,43 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...audio_processing_backends import NumpyAudioBackend
+
+
+class KyutaiSpeechToTextAudioProcessor(NumpyAudioBackend):
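+    r"""
+    Audio processor for KyutaiSpeechToText: mono 24 kHz waveforms with an added channel dimension,
+    padded on the left with a 1 s silence prefix and on the right with the audio delay plus one
+    extra second (3.5 s in total).
+    """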
+ sample_rate = 24000
+ force_mono = True
+ add_channel_dim = True
+ audio_delay_seconds = 2.5
+ audio_silence_prefix_seconds = 1.0
+
+ def _postprocess_output(self, output, **kwargs):
+ # Add silence prefix (left) and delay (right) padding
+ pad_left = int(self.audio_silence_prefix_seconds * self.sample_rate)
+ pad_right = int((self.audio_delay_seconds + 1.0) * self.sample_rate)
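+        # pad_right spans the 2.5 s audio delay plus one extra second; padded samples are
+        # zero-valued and marked invalid in audio_values_mask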
+
+ if pad_left > 0 or pad_right > 0:
+ output["audio_values"] = np.pad(
+ output["audio_values"], [(0, 0), (0, 0), (pad_left, pad_right)], mode="constant", constant_values=0.0,
+ )
+ output["audio_values_mask"] = np.pad(
+ output["audio_values_mask"], [(0, 0), (pad_left, pad_right)], mode="constant", constant_values=0,
+ )
+
+ return output
+
+
+__all__ = ["KyutaiSpeechToTextAudioProcessor"]
diff --git a/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py
index b472473a19e5..5abc645f3f8a 100644
--- a/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py
+++ b/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py
@@ -1,10 +1,7 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py.
-# Do NOT edit this file manually as any edits will be overwritten by the generation of
-# the file from the modular. If any change should be done, please apply the change to the
-# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this.
+# This file is now a thin backward-compatibility wrapper. The original was auto-generated from modular_kyutai_speech_to_text.py.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved.
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,218 +14,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_kyutai_speech_to_text import KyutaiSpeechToTextAudioProcessor
-import numpy as np
-
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class KyutaiSpeechToTextFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs an KyutaiSpeechToText feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
- sampling_rate (`int`, *optional*, defaults to 24000):
- The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values.
- chunk_length_s (`float`, *optional*):
- If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded.
- overlap (`float`, *optional*):
- Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following
- formulae : `int((1.0 - self.overlap) * self.chunk_length)`.
- audio_delay_seconds (`float`, *optional*, defaults to 0.0):
- The delay in seconds to add after the audio (right padding).
- audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0):
- The silence prefix in seconds to add before the audio (left padding).
- """
-
- model_input_names = ["input_values", "padding_mask"]
-
- def __init__(
- self,
- feature_size: int = 1,
- sampling_rate: int = 24000,
- padding_value: float = 0.0,
- chunk_length_s: float | None = None,
- overlap: float | None = None,
- audio_delay_seconds: float | None = 0.0,
- audio_silence_prefix_seconds: float | None = 0.0,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.chunk_length_s = chunk_length_s
- self.overlap = overlap
- self.audio_delay_seconds = audio_delay_seconds
- self.audio_silence_prefix_seconds = audio_silence_prefix_seconds
-
- # This is a property because you might want to change the chunk_length_s on the fly
- @property
- def chunk_length(self) -> int | None:
- if self.chunk_length_s is None:
- return None
- else:
- return int(self.chunk_length_s * self.sampling_rate)
-
- # This is a property because you might want to change the chunk_length_s on the fly
- @property
- def chunk_stride(self) -> int | None:
- if self.chunk_length_s is None or self.overlap is None:
- return None
- else:
- return max(1, int((1.0 - self.overlap) * self.chunk_length))
-
- def __call__(
- self,
- raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- padding: bool | str | PaddingStrategy | None = None,
- truncation: bool | None = False,
- max_length: int | None = None,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape
- `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio
- (`feature_size = 2`).
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- truncation (`bool`, *optional*, defaults to `False`):
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- if padding and truncation:
- raise ValueError("Both padding and truncation were set. Make sure you only set one.")
- elif padding is None:
- # by default let's pad the inputs
- padding = True
-
- is_batched = bool(
- isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
- elif not is_batched and not isinstance(raw_audio, np.ndarray):
- raw_audio = np.asarray(raw_audio, dtype=np.float32)
- elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
- raw_audio = raw_audio.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_audio = [np.asarray(raw_audio).T]
-
- # verify inputs are valid
- for idx, example in enumerate(raw_audio):
- if example.ndim > 2:
- raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
- if self.feature_size == 1 and example.ndim != 1:
- raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
- if self.feature_size == 2 and example.shape[-1] != 2:
- raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels")
-
- padded_inputs = None
- input_values = BatchFeature({"input_values": raw_audio})
- if self.chunk_stride is not None and self.chunk_length is not None and max_length is None:
- if truncation:
- max_length = min(array.shape[0] for array in raw_audio)
- nb_step = int(np.floor(max_length / self.chunk_stride))
- max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
- elif padding:
- max_length = max(array.shape[0] for array in raw_audio)
- nb_step = int(np.ceil(max_length / self.chunk_stride))
- max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length
- padding = "max_length"
- else:
- padded_inputs = input_values
-
- # normal padding on batch
- if padded_inputs is None:
- padded_inputs = self.pad(
- input_values,
- max_length=max_length,
- truncation=truncation,
- padding=padding,
- return_attention_mask=padding,
- )
-
- if padding:
- padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
-
- # now let's pad left and right
- pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate)
- pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate)
- padded_inputs["input_values"] = np.pad(
- padded_inputs["input_values"],
- ((0, 0), (pad_left, pad_right)),
- mode="constant",
- constant_values=0.0,
- )
- if padding:
- padded_inputs["padding_mask"] = np.pad(
- padded_inputs["padding_mask"],
- ((0, 0), (pad_left, pad_right)),
- mode="constant",
- constant_values=0,
- )
-
- input_values = []
- for example in padded_inputs.pop("input_values"):
- if self.feature_size == 1:
- example = example[..., None]
- input_values.append(example.T)
-
- padded_inputs["input_values"] = input_values
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
-
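+# Backward-compatible alias: the old feature-extractor name now resolves to the
+# new audio processor via the deprecated_feature_extractor shim.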
+KyutaiSpeechToTextFeatureExtractor = deprecated_feature_extractor(
+ KyutaiSpeechToTextAudioProcessor, "KyutaiSpeechToTextFeatureExtractor"
+)
__all__ = ["KyutaiSpeechToTextFeatureExtractor"]
diff --git a/src/transformers/models/lasr/audio_processing_lasr.py b/src/transformers/models/lasr/audio_processing_lasr.py
new file mode 100644
index 000000000000..a1b581628988
--- /dev/null
+++ b/src/transformers/models/lasr/audio_processing_lasr.py
@@ -0,0 +1,56 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...audio_processing_backends import TorchAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class LasrAudioProcessor(TorchAudioBackend):
+ sample_rate = 16000
+ force_mono = True
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=512,
+ hop_length=160,
+ win_length=400,
+ power=2.0,
+ center=False,
+ periodic=False,
+ left_align_fft=True,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=128,
+ f_min=125.0,
+ f_max=7500.0,
+ mel_scale="kaldi",
+ triangularize_in_mel_space=True,
+ bands_to_zero=1,
+ computation_dtype="float64",
+ matmul_order="features_first",
+ ),
+ log_mode="log",
+ mel_floor=1e-5,
+ computation_dtype="float64",
+ )
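+    # The backend's mel_floor/log handling (clamp to 1e-5, then log) reproduces
+    # the old extractor's torch.log(torch.clamp(mel_spec, min=1e-5)).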
+
+ def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False):
+ stft_cfg = spectrogram_config.stft_config
+ win_length = stft_cfg.win_length or stft_cfg.n_fft
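+        # center=False framing: 1 s at 16 kHz gives (16000 - 400) // 160 + 1 = 98 frames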
+ return (audio_lengths - win_length) // stft_cfg.hop_length + 1
+
+
+__all__ = ["LasrAudioProcessor"]
diff --git a/src/transformers/models/lasr/feature_extraction_lasr.py b/src/transformers/models/lasr/feature_extraction_lasr.py
index 7cf1822ee40d..90b1954ec5f2 100644
--- a/src/transformers/models/lasr/feature_extraction_lasr.py
+++ b/src/transformers/models/lasr/feature_extraction_lasr.py
@@ -11,265 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_lasr import LasrAudioProcessor
-import numpy as np
-import torch
-
-from ...audio_utils import hertz_to_mel
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, logging
-from ...utils.import_utils import requires
-
-
-logger = logging.get_logger(__name__)
-
-
-# TODO: @eustlb, we should be able to remove this and use mel_filter_bank from audio_utils
-def linear_to_mel_weight_matrix(
- num_mel_bins: int,
- num_spectrogram_bins: int,
- sample_rate: float,
- lower_edge_hertz: float,
- upper_edge_hertz: float,
- dtype,
-) -> np.ndarray:
- """NumPy-port of the JAX mel weight matrix logic."""
- # We use float64 for precision, matching the JAX implementation.
- internal_dtype = np.float64
-
- # HTK excludes the spectrogram DC bin.
- bands_to_zero = 1
- nyquist_hertz = sample_rate / 2.0
- linear_frequencies = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins, dtype=internal_dtype)[bands_to_zero:]
- spectrogram_bins_mel = hertz_to_mel(linear_frequencies, mel_scale="kaldi")[:, np.newaxis]
-
- edges = np.linspace(
- hertz_to_mel(lower_edge_hertz, mel_scale="kaldi"),
- hertz_to_mel(upper_edge_hertz, mel_scale="kaldi"),
- num_mel_bins + 2,
- dtype=internal_dtype,
- )
-
- lower_edge_mel, center_mel, upper_edge_mel = (
- edges[:-2][np.newaxis, :],
- edges[1:-1][np.newaxis, :],
- edges[2:][np.newaxis, :],
- )
-
- lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel)
- upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel)
- mel_weights_matrix = np.maximum(0.0, np.minimum(lower_slopes, upper_slopes))
- return np.pad(mel_weights_matrix, [[bands_to_zero, 0], [0, 0]]).astype(dtype)
-
-
-@requires(backends=("torch",))
-class LasrFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a LASR feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
- Fourier Transform` which should match pytorch's `torch.stft` equivalent.
-
- Args:
- feature_size (`int`, *optional*, defaults to 128):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- hop_length (`int`, *optional*, defaults to 160):
- Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
- n_fft (`int`, *optional*, defaults to 512):
- Size of the Fourier transform.
- win_length (`int`, *optional*, defaults to 400):
- The window length for the STFT computation.
- padding_value (`float`, *optional*, defaults to 0.0):
- Padding value used to pad the audio. Should correspond to silences.
- """
-
- model_input_names = ["input_features", "attention_mask"]
-
- def __init__(
- self,
- feature_size=128,
- sampling_rate=16000,
- hop_length=160,
- n_fft=512,
- win_length=400,
- padding_value=0.0,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
-
- self.hop_length = hop_length
- self.n_fft = n_fft
- self.win_length = win_length
- self.mel_filters = torch.from_numpy(
- linear_to_mel_weight_matrix(
- num_mel_bins=feature_size,
- num_spectrogram_bins=n_fft // 2 + 1,
- sample_rate=sampling_rate,
- lower_edge_hertz=125.0,
- upper_edge_hertz=7500.0,
- dtype=np.float64,
- )
- )
-
- def _torch_extract_fbank_features(self, waveform, device="cpu"):
- # spectrogram
- window = torch.hann_window(self.win_length, periodic=False, device=device, dtype=torch.float64)
- waveform = waveform.to(torch.float64)
-
- # TODO: @eustlb, to be standardized
- # here we cannot use directly torch.stft because every fft frame is padded with zeros
- # due to unfold then rfft, while torch.stft unfolds with the number of fft points
- frames = waveform.unfold(-1, self.win_length, self.hop_length)
- stft = torch.fft.rfft(window * frames, n=self.n_fft)
- power_spec = torch.abs(stft) ** 2
-
- # log mel spectrogram
- mel_filters = self.mel_filters.to(device)
- mel_spec = torch.clamp(power_spec @ mel_filters, min=1e-5)
- mel_spec = torch.log(mel_spec)
-
- return mel_spec
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- truncation: bool = False,
- pad_to_multiple_of: int | None = None,
- return_tensors: str | TensorType | None = None,
- return_attention_mask: bool | None = None,
- padding: str | None = "longest",
- max_length: int | None = None,
- sampling_rate: int | None = None,
- do_normalize: bool | None = None,
- device: str | None = "cpu",
- return_token_timestamps: bool | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
- the STFT computation if available, otherwise a slower NumPy based one.
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- truncation (`bool`, *optional*, default to `True`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*, defaults to None):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
-
- For Parakeet models, `attention_mask` should always be passed for batched inference, to avoid subtle
- bugs.
-
-
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
- pipeline.
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values / vectors.
- do_normalize (`bool`, *optional*, defaults to `False`):
- Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
- improve the performance of the model.
- device (`str`, *optional*, defaults to `'cpu'`):
- Specifies the device for computation of the log-mel spectrogram of audio signals in the
- `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
- return_token_timestamps (`bool`, *optional*, defaults to `None`):
- Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred.
-
- Whether or not to return the number of frames of the input raw_speech.
- These num_frames can be used by the model to compute word level timestamps.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
- f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
- f" was sampled with {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- # Convert to torch tensor
- if isinstance(raw_speech, np.ndarray):
- raw_speech = torch.tensor(raw_speech)
- elif isinstance(raw_speech, (list, tuple)):
- if isinstance(raw_speech[0], (list, np.ndarray)):
- raw_speech = [torch.tensor(speech) for speech in raw_speech]
- else: # list[float]
- raw_speech = torch.tensor(raw_speech)
-
- is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
- if is_batched_torch and len(raw_speech.shape) > 2:
- logger.warning(
- f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
- "We will take the mean of the channels to convert to mono."
- )
- raw_speech = raw_speech.mean(-1)
-
- is_batched_sequence = isinstance(raw_speech, (list, tuple))
- if is_batched_sequence:
- for speech in raw_speech:
- if len(speech.shape) > 1:
- logger.warning(
- f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
- "We will take the mean of the channels to convert to mono."
- )
- speech = speech.mean(-1)
-
- if is_batched_torch or is_batched_sequence:
- raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
- else:
- raw_speech = [raw_speech[:, None].to(torch.float32)]
-
- batched_speech = BatchFeature({"input_features": raw_speech})
- padded_inputs = self.pad(
- batched_speech,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- return_tensors="pt",
- )
- input_features = padded_inputs.input_features.squeeze(-1)
- input_features = self._torch_extract_fbank_features(input_features, device)
- data = {
- "input_features": input_features.to(torch.float32),
- }
-
- if return_attention_mask:
- attention_mask = padded_inputs.attention_mask[:, self.win_length - 1 :: self.hop_length]
- data["attention_mask"] = attention_mask.to(torch.bool)
-
- return BatchFeature(data=data, tensor_type=return_tensors)
+LasrFeatureExtractor = deprecated_feature_extractor(LasrAudioProcessor, "LasrFeatureExtractor")
__all__ = ["LasrFeatureExtractor"]
diff --git a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py
new file mode 100644
index 000000000000..1585ffae93d0
--- /dev/null
+++ b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py
@@ -0,0 +1,84 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...audio_processing_backends import TorchAudioBackend
+from ...utils.import_utils import requires
+
+
+class MusicgenMelodyAudioProcessor(TorchAudioBackend):
+ sample_rate = 32000
+ force_mono = True
+ do_extract_spectrogram = True
+ return_padding_mask = False
+ n_fft = 16384
+ hop_length = 4096
+ n_chroma = 12
+ chunk_length = 30
+
+ @requires(backends=("librosa", "torch"))
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ import librosa
+ import torch
+
+ self.chroma_filters = torch.from_numpy(
+ librosa.filters.chroma(sr=self.sample_rate, n_fft=self.n_fft, tuning=0, n_chroma=self.n_chroma)
+ ).float()
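+        # librosa.filters.chroma returns an (n_chroma, 1 + n_fft // 2) matrix,
+        # i.e. (12, 8193) with the class defaults above.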
+
+ def extract_spectrogram(self, audio, **kwargs):
+ import torch
+ import torchaudio
+
+ waveform = audio # Already a batched tensor from _to_batch
+ device = waveform.device
+
+ # Pad if too short for FFT
+ if waveform.shape[-1] < self.n_fft:
+ pad = self.n_fft - waveform.shape[-1]
+ rest = 0 if pad % 2 == 0 else 1
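+            # split the shortfall evenly; any odd leftover sample goes on the right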
+ waveform = torch.nn.functional.pad(waveform, (pad // 2, pad // 2 + rest), "constant", 0)
+
+ # Add channel dim for spectrogram: (batch, 1, length)
+ waveform = waveform.unsqueeze(1)
+
+ # Power spectrogram (normalized)
+ spec_transform = torchaudio.transforms.Spectrogram(
+ n_fft=self.n_fft, win_length=self.n_fft, hop_length=self.hop_length,
+ power=2, center=True, pad=0, normalized=True,
+ ).to(device)
+ spec = spec_transform(waveform).squeeze(1)
+
+ # Chroma features
+ chroma_filters = self.chroma_filters.to(device)
+ raw_chroma = torch.einsum("cf, ...ft->...ct", chroma_filters, spec)
+
+ # Normalize with inf norm
+ norm_chroma = torch.nn.functional.normalize(raw_chroma, p=float("inf"), dim=-2, eps=1e-6)
+
+ # Transpose: (batch, chroma, frames) -> (batch, frames, chroma)
+ norm_chroma = norm_chroma.transpose(1, 2)
+
+ # One-hot encoding: argmax along chroma dim
+ idx = norm_chroma.argmax(-1, keepdim=True)
+ norm_chroma[:] = 0
+ norm_chroma.scatter_(dim=-1, index=idx, value=1)
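+        # e.g. per frame, [0.2, 0.9, 0.1, ...] -> [0.0, 1.0, 0.0, ...]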
+
+ return norm_chroma
+
+
+__all__ = ["MusicgenMelodyAudioProcessor"]
diff --git a/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py b/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py
index 1811fa11e630..c41ea0666292 100644
--- a/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py
@@ -11,324 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Feature extractor class for Musicgen Melody
-"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_musicgen_melody import MusicgenMelodyAudioProcessor
-import copy
-from typing import Any
-
-import numpy as np
-
-from ...audio_utils import chroma_filter_bank
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, is_torch_available, is_torchaudio_available, logging
-from ...utils.import_utils import requires
-
-
-if is_torch_available():
- import torch
-
-if is_torchaudio_available():
- import torchaudio
-
-logger = logging.get_logger(__name__)
-
-
-@requires(backends=("torchaudio",))
-class MusicgenMelodyFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a MusicgenMelody feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- This class extracts chroma features from audio processed by [Demucs](https://github.com/adefossez/demucs/tree/main) or
- directly from raw audio waveform.
-
- Args:
- feature_size (`int`, *optional*, defaults to 12):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 32000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- hop_length (`int`, *optional*, defaults to 4096):
- Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
- chunk_length (`int`, *optional*, defaults to 30):
- The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio
- sequences.
- n_fft (`int`, *optional*, defaults to 16384):
- Size of the Fourier transform.
- num_chroma (`int`, *optional*, defaults to 12):
- Number of chroma bins to use.
- padding_value (`float`, *optional*, defaults to 0.0):
- Padding value used to pad the audio.
- return_attention_mask (`bool`, *optional*, defaults to `False`):
- Whether to return the attention mask. Can be overwritten when calling the feature extractor.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
-
- For Whisper models, `attention_mask` should always be passed for batched inference, to avoid subtle
- bugs.
-
-
- stem_indices (`list[int]`, *optional*, defaults to `[3, 2]`):
- Stem channels to extract if demucs outputs are passed.
- """
-
- model_input_names = ["input_features"]
-
- def __init__(
- self,
- feature_size=12,
- sampling_rate=32000,
- hop_length=4096,
- chunk_length=30,
- n_fft=16384,
- num_chroma=12,
- padding_value=0.0,
- return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
- stem_indices=[3, 2],
- **kwargs,
- ):
- super().__init__(
- feature_size=feature_size,
- sampling_rate=sampling_rate,
- padding_value=padding_value,
- return_attention_mask=return_attention_mask,
- **kwargs,
- )
- self.n_fft = n_fft
- self.hop_length = hop_length
- self.chunk_length = chunk_length
- self.n_samples = chunk_length * sampling_rate
- self.sampling_rate = sampling_rate
- self.chroma_filters = torch.from_numpy(
- chroma_filter_bank(sampling_rate=sampling_rate, num_frequency_bins=n_fft, tuning=0, num_chroma=num_chroma)
- ).float()
- self.spectrogram = torchaudio.transforms.Spectrogram(
- n_fft=n_fft, win_length=n_fft, hop_length=hop_length, power=2, center=True, pad=0, normalized=True
- )
- self.stem_indices = stem_indices
-
- def _torch_extract_fbank_features(self, waveform: torch.Tensor) -> torch.Tensor:
- """
- Compute the chroma spectrogram of the provided audio using the torchaudio spectrogram implementation and the librosa chroma features.
- """
-
- # if wav length is not long enough, pad it
- wav_length = waveform.shape[-1]
- if wav_length < self.n_fft:
- pad = self.n_fft - wav_length
- rest = 0 if pad % 2 == 0 else 1
- waveform = torch.nn.functional.pad(waveform, (pad // 2, pad // 2 + rest), "constant", 0)
-
- # squeeze alongside channel dimension
- spec = self.spectrogram(waveform).squeeze(1)
-
- # sum along the frequency dimension
- raw_chroma = torch.einsum("cf, ...ft->...ct", self.chroma_filters, spec)
-
- # normalise with max value
- norm_chroma = torch.nn.functional.normalize(raw_chroma, p=float("inf"), dim=-2, eps=1e-6)
-
- # transpose time and chroma dimension -> (batch, time, chroma)
- norm_chroma = norm_chroma.transpose(1, 2)
-
- # replace max value alongside chroma dimension with 1 and replace the rest with 0
- idx = norm_chroma.argmax(-1, keepdim=True)
- norm_chroma[:] = 0
- norm_chroma.scatter_(dim=-1, index=idx, value=1)
-
- return norm_chroma
-
- def _extract_stem_indices(self, audio, sampling_rate=None):
- """
- Extracts stems from the output of the [Demucs](https://github.com/adefossez/demucs/tree/main) audio separation model,
- then converts to mono-channel and resample to the feature extractor sampling rate.
-
- Args:
- audio (`torch.Tensor` of shape `(batch_size, num_stems, channel_size, audio_length)`):
- The output of the Demucs model to be processed.
- sampling_rate (`int`, *optional*):
- Demucs sampling rate. If not specified, defaults to `44000`.
- """
- sampling_rate = 44000 if sampling_rate is None else sampling_rate
-
- # extract "vocals" and "others" sources from audio encoder (demucs) output
- # [batch_size, num_stems, channel_size, audio_length]
- wav = audio[:, torch.tensor(self.stem_indices)]
-
- # merge extracted stems to single waveform
- wav = wav.sum(1)
-
- # convert to mono-channel waveform
- wav = wav.mean(dim=1, keepdim=True)
-
- # resample to model sampling rate
- # not equivalent to julius.resample
- if sampling_rate != self.sampling_rate:
- wav = torchaudio.functional.resample(
- wav, sampling_rate, self.sampling_rate, rolloff=0.945, lowpass_filter_width=24
- )
-
- # [batch_size, 1, audio_length] -> [batch_size, audio_length]
- wav = wav.squeeze(1)
-
- return wav
-
- def __call__(
- self,
- audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- truncation: bool = True,
- pad_to_multiple_of: int | None = None,
- return_tensors: str | TensorType | None = None,
- return_attention_mask: bool | None = None,
- padding: str | None = True,
- max_length: int | None = None,
- sampling_rate: int | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- audio (`torch.Tensor`, `np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[torch.Tensor]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a torch tensor, a numpy array, a list of float
- values, a list of numpy arrays, a list of torch tensors, or a list of list of float values.
- If `audio` is the output of Demucs, it has to be a torch tensor of shape `(batch_size, num_stems, channel_size, audio_length)`.
- Otherwise, it must be mono or stereo channel audio.
- truncation (`bool`, *optional*, default to `True`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*, defaults to None):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
- For Musicgen Melody models, audio `attention_mask` is not necessary.
-
-
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- Note that if `audio` is the output of Demucs, `sampling_rate` must be the sampling rate at which Demucs operates.
- """
-
- if sampling_rate is None:
- logger.warning_once(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- if isinstance(audio, torch.Tensor) and len(audio.shape) == 4:
- logger.warning_once(
- "`audio` is a 4-dimensional torch tensor and has thus been recognized as the output of `Demucs`. "
- "If this is not the case, make sure to read Musicgen Melody docstrings and "
- "to correct `audio` to get the right behaviour."
- "Link to the docstrings: https://huggingface.co/docs/transformers/main/en/model_doc/musicgen_melody"
- )
- audio = self._extract_stem_indices(audio, sampling_rate=sampling_rate)
- elif sampling_rate is not None and sampling_rate != self.sampling_rate:
- audio = torchaudio.functional.resample(
- audio, sampling_rate, self.sampling_rate, rolloff=0.945, lowpass_filter_width=24
- )
-
- is_batched = isinstance(audio, (np.ndarray, torch.Tensor)) and len(audio.shape) > 1
- is_batched = is_batched or (
- isinstance(audio, (list, tuple)) and (isinstance(audio[0], (torch.Tensor, np.ndarray, tuple, list)))
- )
-
- if is_batched and not isinstance(audio[0], torch.Tensor):
- audio = [torch.tensor(speech, dtype=torch.float32).unsqueeze(-1) for speech in audio]
- elif is_batched:
- audio = [speech.unsqueeze(-1) for speech in audio]
- elif not is_batched and not isinstance(audio, torch.Tensor):
- audio = torch.tensor(audio, dtype=torch.float32).unsqueeze(-1)
-
- if isinstance(audio[0], torch.Tensor) and audio[0].dtype is torch.float64:
- audio = [speech.to(torch.float32) for speech in audio]
-
- # always return batch
- if not is_batched:
- audio = [audio]
-
- if len(audio[0].shape) == 3:
- logger.warning_once(
- "`audio` has been detected as a batch of stereo signals. Will be convert to mono signals. "
- "If this is an undesired behaviour, make sure to read Musicgen Melody docstrings and "
- "to correct `audio` to get the right behaviour."
- "Link to the docstrings: https://huggingface.co/docs/transformers/main/en/model_doc/musicgen_melody"
- )
- # convert to mono-channel waveform
- audio = [stereo.mean(dim=0) for stereo in audio]
-
- batched_speech = BatchFeature({"input_features": audio})
-
- padded_inputs = self.pad(
- batched_speech,
- padding=padding,
- max_length=max_length if max_length else self.n_samples,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- return_tensors="pt",
- )
-
- input_features = self._torch_extract_fbank_features(padded_inputs["input_features"].squeeze(-1))
-
- padded_inputs["input_features"] = input_features
-
- if return_attention_mask:
- # rescale from raw audio length to spectrogram length
- padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
-
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
-
- def to_dict(self) -> dict[str, Any]:
- """
- Serializes this instance to a Python dictionary. Returns:
- `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- output = copy.deepcopy(self.__dict__)
- output["feature_extractor_type"] = self.__class__.__name__
- if "mel_filters" in output:
- del output["mel_filters"]
- if "window" in output:
- del output["window"]
- if "chroma_filters" in output:
- del output["chroma_filters"]
- if "spectrogram" in output:
- del output["spectrogram"]
- return output
+MusicgenMelodyFeatureExtractor = deprecated_feature_extractor(
+ MusicgenMelodyAudioProcessor, "MusicgenMelodyFeatureExtractor"
+)
__all__ = ["MusicgenMelodyFeatureExtractor"]
diff --git a/src/transformers/models/parakeet/audio_processing_parakeet.py b/src/transformers/models/parakeet/audio_processing_parakeet.py
new file mode 100644
index 000000000000..5df813fabae5
--- /dev/null
+++ b/src/transformers/models/parakeet/audio_processing_parakeet.py
@@ -0,0 +1,158 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...audio_processing_backends import TorchAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class ParakeetAudioProcessor(TorchAudioBackend):
+ sample_rate = 16000
+ force_mono = True
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=512,
+ hop_length=160,
+ win_length=400,
+ window_fn="hann_window",
+ power=2.0,
+ pad_mode="constant",
+ periodic=False,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=80,
+ f_min=0.0,
+ norm="slaney",
+ mel_scale="slaney",
+ ),
+ preemphasis=0.97,
+ log_mode="log",
+ mel_floor=2**-24,
+ )
+
+ def _mel_filter_bank(self, spectrogram_config):
+ """Compute mel filters via numpy for exact numerical match with the feature extractor.
+
+ The FE uses librosa which accumulates into a float32 array per-band.
+ Replicating that truncation pattern is needed for bit-exact results.
+ """
+ import numpy as np
+ import torch
+
+ from ...audio_utils import hertz_to_mel, mel_to_hertz
+
+ stft_cfg = spectrogram_config.stft_config
+ mel_cfg = spectrogram_config.mel_scale_config
+ n_fft = stft_cfg.n_fft
+ n_mels = mel_cfg.n_mels
+ f_min = mel_cfg.f_min
+ f_max = mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2
+
+ mel_min = hertz_to_mel(f_min, mel_scale=mel_cfg.mel_scale)
+ mel_max = hertz_to_mel(f_max, mel_scale=mel_cfg.mel_scale)
+ mel_pts = np.linspace(mel_min, mel_max, n_mels + 2)
+ filter_freqs = mel_to_hertz(mel_pts.copy(), mel_scale=mel_cfg.mel_scale)
+ fft_freqs = np.linspace(0, self.sample_rate / 2, 1 + n_fft // 2)
+
+ fdiff = np.diff(filter_freqs)
+ ramps = np.subtract.outer(filter_freqs, fft_freqs)
+
+ # Accumulate into f32 per-band to match librosa's truncation pattern
+ weights = np.zeros((n_mels, 1 + n_fft // 2), dtype=np.float32)
+ for i in range(n_mels):
+ lower = -ramps[i] / fdiff[i]
+ upper = ramps[i + 2] / fdiff[i + 1]
+ weights[i] = np.maximum(0, np.minimum(lower, upper))
+
+ if mel_cfg.norm == "slaney":
+ enorm = 2.0 / (filter_freqs[2 : n_mels + 2] - filter_freqs[:n_mels])
+ weights *= enorm[:, np.newaxis]
+
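+        # weights is (n_mels, 1 + n_fft // 2) = (80, 257) here; the transpose below
+        # stores the filters as (freq_bins, n_mels), which _apply_mel_scale undoes.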
+ return torch.from_numpy(weights.T).to(torch.float32)
+
+ def _compute_magnitudes(self, stft_out, power, spectrogram_config=None):
+ import torch
+
+ magnitudes = torch.view_as_real(stft_out)
+ magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1))
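+        # mathematically equal to torch.abs(stft_out); the decomposed form mirrors
+        # the original extractor so the numerics match exactly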
+ if power != 1.0:
+ magnitudes = magnitudes.pow(power)
+ return magnitudes
+
+ def _needs_manual_framing(self, spectrogram_config):
+        # Preemphasis is applied at the waveform level in _stft, so no per-frame handling is needed here.
+ return spectrogram_config.remove_dc_offset or spectrogram_config.stft_config.left_align_fft
+
+ def _stft(self, audio, *, spectrogram_config, audio_ranges=None, **kwargs):
+ import torch
+
+        if audio_ranges is not None:
+            audio_lengths = torch.tensor(
+                [end - start for start, end in audio_ranges], device=audio.device
+            )
+        else:
+            audio_lengths = None
+
+ # Waveform-level preemphasis with masking to zero out padding
+ preemphasis = spectrogram_config.preemphasis
+ if preemphasis is not None:
+ audio = torch.cat(
+ [audio[:, :1], audio[:, 1:] - preemphasis * audio[:, :-1]], dim=1
+ )
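+            # i.e. y[0] = x[0], y[n] = x[n] - preemphasis * x[n - 1] for n >= 1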
+ if audio_lengths is not None:
+ timemask = torch.arange(audio.shape[-1], device=audio.device).unsqueeze(0) < audio_lengths.unsqueeze(1)
+ audio = audio.masked_fill(~timemask, 0.0)
+
+ return super()._stft(audio, spectrogram_config=spectrogram_config, **kwargs)
+
+ def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs):
+ import torch
+
+ return torch.matmul(self.mel_filters.T, features)
+
+ def _normalize_magnitude(self, features, *, spectrogram_config, audio_ranges=None, **kwargs):
+ import torch
+
+        # Match the original extractor: log(mel_spec + mel_floor) instead of log(clamp(mel_spec, min=mel_floor))
+ features = torch.log(features + spectrogram_config.mel_floor)
+
+ # (batch, mels, frames) -> (batch, frames, mels)
+ features = features.permute(0, 2, 1)
+
+ # Per-utterance normalization
+ if audio_ranges is not None:
+ stft_cfg = spectrogram_config.stft_config
+ audio_lengths = torch.tensor([end - start for start, end in audio_ranges])
+ features_lengths = torch.floor_divide(
+ audio_lengths + stft_cfg.n_fft // 2 * 2 - stft_cfg.n_fft, stft_cfg.hop_length
+ )
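+            # with even n_fft the offset terms cancel, leaving audio_lengths // hop_length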
+ attention_mask = torch.arange(features.shape[1])[None, :] < features_lengths[:, None]
+ mask = attention_mask.unsqueeze(-1)
+ mel_masked = features * mask
+ mean = mel_masked.sum(dim=1) / features_lengths.unsqueeze(-1)
+ mean = mean.unsqueeze(1)
+ variance = ((mel_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1)
+ std = torch.sqrt(variance).unsqueeze(1)
+ features = (features - mean) / (std + 1e-5)
+ features *= mask
+
+ return features
+
+
+__all__ = ["ParakeetAudioProcessor"]
diff --git a/src/transformers/models/parakeet/feature_extraction_parakeet.py b/src/transformers/models/parakeet/feature_extraction_parakeet.py
index c745d02c9629..92f02cd0a9f4 100644
--- a/src/transformers/models/parakeet/feature_extraction_parakeet.py
+++ b/src/transformers/models/parakeet/feature_extraction_parakeet.py
@@ -11,275 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_parakeet import ParakeetAudioProcessor
-import numpy as np
-import torch
-
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, is_librosa_available, logging
-from ...utils.import_utils import requires
-
-
-if is_librosa_available():
- import librosa
-
-
-EPSILON = 1e-5
-LOG_ZERO_GUARD_VALUE = 2**-24
-
-
-logger = logging.get_logger(__name__)
-
-
-@requires(backends=("torch", "librosa"))
-class ParakeetFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a Parakeet feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
- Fourier Transform` which should match pytorch's `torch.stft` equivalent.
-
- Args:
- feature_size (`int`, *optional*, defaults to 80):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- hop_length (`int`, *optional*, defaults to 160):
- Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
- n_fft (`int`, *optional*, defaults to 512):
- Size of the Fourier transform.
- win_length (`int`, *optional*, defaults to 400):
- The window length for the STFT computation.
- preemphasis (`float`, *optional*, defaults to 0.97):
- A preemphasis filter coefficient. 0.0 means no preemphasis filter.
- padding_value (`float`, *optional*, defaults to 0.0):
- Padding value used to pad the audio. Should correspond to silences.
- """
-
- model_input_names = ["input_features", "attention_mask"]
-
- def __init__(
- self,
- feature_size=80,
- sampling_rate=16000,
- hop_length=160,
- n_fft=512,
- win_length=400,
- preemphasis=0.97,
- padding_value=0.0,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
-
- self.hop_length = hop_length
- self.n_fft = n_fft
- self.win_length = win_length
- self.preemphasis = preemphasis
-
- # TODO: @eustlb, for now we use librosa to compute the mel filters
- # indeed mel_filter_bank uses np.float64 (while librosa uses np.float32), giving numerical differences
- # self.mel_filters = mel_filter_bank(
- # num_frequency_bins=n_fft // 2 + 1,
- # num_mel_filters=feature_size,
- # min_frequency=0.0,
- # max_frequency=sampling_rate / 2,
- # sampling_rate=sampling_rate,
- # norm="slaney",
- # mel_scale="slaney",
- # )
- mel_filters = librosa.filters.mel(
- sr=sampling_rate, n_fft=n_fft, n_mels=feature_size, fmin=0.0, fmax=sampling_rate / 2, norm="slaney"
- )
- self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32)
-
- def _torch_extract_fbank_features(self, waveform, device="cpu"):
- # spectrogram
- window = torch.hann_window(self.win_length, periodic=False, device=device)
- stft = torch.stft(
- waveform,
- self.n_fft,
- hop_length=self.hop_length,
- win_length=self.win_length,
- window=window,
- return_complex=True,
- pad_mode="constant",
- )
- # Let's math original implementation
- # magnitudes = torch.abs(stft) ** 2
- magnitudes = torch.view_as_real(stft)
- magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1))
- magnitudes = magnitudes.pow(2)
-
- # log mel spectrogram
- mel_filters = self.mel_filters.to(device)
- mel_spec = mel_filters @ magnitudes
- mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE)
-
- # (batch_size, num_mel_filters, num_frames) -> (batch_size, num_frames, num_mel_filters)
- mel_spec = mel_spec.permute(0, 2, 1)
-
- return mel_spec
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- truncation: bool = False,
- pad_to_multiple_of: int | None = None,
- return_tensors: str | TensorType | None = None,
- return_attention_mask: bool | None = None,
- padding: str | None = "longest",
- max_length: int | None = None,
- sampling_rate: int | None = None,
- do_normalize: bool | None = None,
- device: str | None = "cpu",
- return_token_timestamps: bool | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
- the STFT computation if available, otherwise a slower NumPy based one.
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- truncation (`bool`, *optional*, default to `True`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*, defaults to None):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
-
- For Parakeet models, `attention_mask` should always be passed for batched inference, to avoid subtle
- bugs.
-
-
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
- pipeline.
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values / vectors.
- do_normalize (`bool`, *optional*, defaults to `False`):
- Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
- improve the performance of the model.
- device (`str`, *optional*, defaults to `'cpu'`):
- Specifies the device for computation of the log-mel spectrogram of audio signals in the
- `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
- return_token_timestamps (`bool`, *optional*, defaults to `None`):
- Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred.
-
- Whether or not to return the number of frames of the input raw_speech.
- These num_frames can be used by the model to compute word level timestamps.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
- f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
- f" was sampled with {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- # Convert to torch tensor
- if isinstance(raw_speech, np.ndarray):
- raw_speech = torch.tensor(raw_speech)
- elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray):
- raw_speech = [torch.tensor(speech) for speech in raw_speech]
-
- is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
- if is_batched_torch and len(raw_speech.shape) > 2:
- logger.warning(
- f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
- "We will take the mean of the channels to convert to mono."
- )
- raw_speech = raw_speech.mean(-1)
-
- is_batched_sequence = isinstance(raw_speech, (list, tuple))
- if is_batched_sequence:
- for speech in raw_speech:
- if len(speech.shape) > 1:
- logger.warning(
- f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
- "We will take the mean of the channels to convert to mono."
- )
- speech = speech.mean(-1)
-
- if is_batched_torch or is_batched_sequence:
- raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
- else:
- raw_speech = [raw_speech[:, None].to(torch.float32)]
-
- audio_lengths = [len(speech) for speech in raw_speech]
- batched_speech = BatchFeature({"input_features": raw_speech, "audio_lengths": audio_lengths})
-
- padded_inputs = self.pad(
- batched_speech,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_tensors="pt",
- )
- input_features = padded_inputs.input_features.squeeze(-1)
-
- # preemphasis
- if self.preemphasis is not None:
- timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze(
- 0
- ) < padded_inputs.audio_lengths.unsqueeze(1)
- input_features = torch.cat(
- [input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1
- )
- input_features = input_features.masked_fill(~timemask, 0.0)
-
- input_features = self._torch_extract_fbank_features(input_features, device)
- features_lengths = torch.floor_divide(
- padded_inputs.audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length
- )
- attention_mask = torch.arange(input_features.shape[1], device=device)[None, :] < features_lengths[:, None]
-
- # normalize mel features, ignoring padding
- mask = attention_mask.unsqueeze(-1)
- input_features_masked = input_features * mask
- mean = input_features_masked.sum(dim=1) / features_lengths.unsqueeze(-1)
- mean = mean.unsqueeze(1)
- variance = ((input_features_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1)
- std = torch.sqrt(variance).unsqueeze(1)
- input_features = (input_features - mean) / (std + EPSILON)
- input_features *= mask
-
- return BatchFeature(
- data={
- "input_features": input_features,
- "attention_mask": attention_mask,
- },
- tensor_type=return_tensors,
- )
+ParakeetFeatureExtractor = deprecated_feature_extractor(ParakeetAudioProcessor, "ParakeetFeatureExtractor")
__all__ = ["ParakeetFeatureExtractor"]
diff --git a/src/transformers/models/pe_audio/audio_processing_pe_audio.py b/src/transformers/models/pe_audio/audio_processing_pe_audio.py
new file mode 100644
index 000000000000..1c8969b28ed2
--- /dev/null
+++ b/src/transformers/models/pe_audio/audio_processing_pe_audio.py
@@ -0,0 +1,25 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...audio_processing_backends import NumpyAudioBackend
+
+
+class PeAudioAudioProcessor(NumpyAudioBackend):
+    sample_rate = 48000  # the deprecated PeAudioFeatureExtractor defaulted to 48 kHz
+ force_mono = True
+
+
+__all__ = ["PeAudioAudioProcessor"]
diff --git a/src/transformers/models/pe_audio/feature_extraction_pe_audio.py b/src/transformers/models/pe_audio/feature_extraction_pe_audio.py
index a7738d3089ac..da1f7d34a86f 100644
--- a/src/transformers/models/pe_audio/feature_extraction_pe_audio.py
+++ b/src/transformers/models/pe_audio/feature_extraction_pe_audio.py
@@ -11,150 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_pe_audio import PeAudioAudioProcessor
-import numpy as np
-
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...processing_utils import load_audio
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class PeAudioFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a PeAudioFeatureExtractor feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features. Use 1 for mono, 2 for stereo.
- sampling_rate (`int`, *optional*, defaults to 48000):
- The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used for padding.
- hop_length (`int`, *optional*, defaults to 1920):
-            Step size between successive windows, in samples.
- """
-
- model_input_names = ["input_values"]
-
- def __init__(
- self,
- feature_size: int = 1,
- sampling_rate: int = 48_000,
- padding_value: float = 0.0,
- hop_length: int = 1920,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.hop_length = hop_length
-
- def _reflect_pad(self, wav):
- if len(wav) % self.hop_length == 0:
- return wav
- p1d = (0, self.hop_length - (len(wav) % self.hop_length))
- return np.pad(wav, p1d, "reflect")
-
- def __call__(
- self,
- raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]] | str | list[str],
- padding: bool | str | PaddingStrategy | None = None,
- truncation: bool | None = False,
- max_length: int | None = None,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- ) -> BatchFeature:
- from_file = False
- if isinstance(raw_audio, str):
- raw_audio = [raw_audio]
-
- if isinstance(raw_audio, (list, tuple)) and isinstance(raw_audio[0], str):
- loaded = []
- for audio_file in raw_audio:
- loaded.append(load_audio(audio_file, self.sampling_rate))
- raw_audio = loaded
- from_file = True
-
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- elif not from_file:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- if padding and truncation:
- raise ValueError("Both padding and truncation were set. Make sure you only set one.")
- elif padding is None:
- # by default let's pad the inputs
- padding = True
-
- is_batched = bool(
- isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio]
- elif not is_batched and not isinstance(raw_audio, np.ndarray):
- raw_audio = np.asarray(raw_audio, dtype=np.float32)
- elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64):
- raw_audio = raw_audio.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_audio = [np.asarray(raw_audio).T]
-
- if isinstance(raw_audio, list):
- raw_audio = [self._reflect_pad(x) for x in raw_audio]
- else:
- raw_audio = self._reflect_pad(raw_audio)
-
- # verify inputs are valid
- for example in raw_audio:
- if example.ndim > 2:
- raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}")
- if self.feature_size == 1 and example.ndim != 1:
- raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels")
- if self.feature_size == 2:
- raise ValueError("Stereo audio isn't supported for now")
-
- input_values = BatchFeature({"input_values": raw_audio})
-
- # normal padding on batch
- padded_inputs = self.pad(
- input_values,
- max_length=max_length,
- truncation=truncation,
- padding=padding,
- return_attention_mask=padding,
- pad_to_multiple_of=self.hop_length,
- )
- if padding:
- padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask")
- if padding:
- padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :]
-
- input_values = []
- for example in padded_inputs.pop("input_values"):
- if self.feature_size == 1:
- example = example[..., None]
- input_values.append(example.T)
-
- padded_inputs["input_values"] = input_values
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
+PeAudioFeatureExtractor = deprecated_feature_extractor(PeAudioAudioProcessor, "PeAudioFeatureExtractor")
__all__ = ["PeAudioFeatureExtractor"]
diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py
new file mode 100644
index 000000000000..a63321c9a346
--- /dev/null
+++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py
@@ -0,0 +1,127 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ...audio_processing_backends import TorchAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank
+
+
+class Phi4MultimodalAudioProcessor(TorchAudioBackend):
+ sample_rate = 16000
+ force_mono = True
+ audio_compression_rate = 8
+ audio_downsample_rate = 1
+ audio_feat_stride = 1
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=512,
+ win_length=400,
+ hop_length=160,
+ window_fn="hamming_window",
+ periodic=False,
+ center=False,
+ power=2.0,
+ window_dtype="float64",
+ ),
+ preemphasis=0.97,
+ mel_scale_config=MelScaleConfig(
+ n_mels=80,
+ f_min=0,
+ f_max=7690,
+ mel_scale="kaldi",
+ triangularize_in_mel_space=True,
+ matmul_order="features_first",
+ ),
+ mel_floor=1.0,
+ log_mode="log",
+ )
+
+ def _mel_filter_bank(self, spectrogram_config):
+ stft_cfg = spectrogram_config.stft_config
+ mel_cfg = spectrogram_config.mel_scale_config
+ mel_filters_np = mel_filter_bank(
+ num_frequency_bins=1 + stft_cfg.n_fft // 2,
+ num_mel_filters=mel_cfg.n_mels,
+ min_frequency=mel_cfg.f_min,
+ max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2,
+ sampling_rate=self.sample_rate,
+ norm=mel_cfg.norm,
+ mel_scale=mel_cfg.mel_scale,
+ triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space,
+ )
+ return torch.from_numpy(mel_filters_np).to(torch.float32)
+
+ def _apply_frame_processing(self, frames, *, spectrogram_config, audio_ranges=None, **kwargs):
+ # Mask frames that overlap the boundary between real audio and padding
+ stft_cfg = spectrogram_config.stft_config
+ win_length = stft_cfg.win_length or stft_cfg.n_fft
+ hop_length = stft_cfg.hop_length or win_length // 2
+ batch_size = frames.shape[0]
+
+ if audio_ranges is not None and batch_size > 1:
+            # keep the masking tensors on the same device as the frames
+            device = frames.device
+            audio_lengths_t = torch.tensor([end - start for start, end in audio_ranges], device=device)
+            to_mask_idxs = torch.arange(batch_size, device=device)[audio_lengths_t != audio_lengths_t.max()]
+ if to_mask_idxs.numel() > 0:
+ frames = frames.clone()
+ down = (audio_lengths_t[to_mask_idxs] - win_length) // hop_length + 1
+ up = audio_lengths_t[to_mask_idxs] // hop_length - 1
+ offset = down.min()
+ max_idx = up.max()
+
+                mask_range = torch.arange(max_idx - offset, device=device).expand(to_mask_idxs.shape[0], -1)
+ mask = ((down - offset).unsqueeze(1) <= mask_range) & (mask_range < (up - offset).unsqueeze(1))
+ mask = mask.unsqueeze(-1).expand(-1, -1, win_length)
+
+ masked_frames = frames[to_mask_idxs, offset:max_idx].masked_fill_(mask, 0)
+ frames[to_mask_idxs, offset:max_idx] = masked_frames
+
+ frames_prev = torch.roll(frames, 1, dims=-1)
+ frames_prev[..., 0] = frames_prev[..., 1]
+ return (frames - spectrogram_config.preemphasis * frames_prev) * 32768
+
+ def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None):
+ frames = frames * window
+ if frame_length < n_fft:
+ frames = torch.nn.functional.pad(frames, (0, n_fft - frame_length))
+ # Cast to complex64 before abs() to match the FE's precision path
+ spec = torch.fft.rfft(frames, n=n_fft).to(torch.complex64)
+ if stft_cfg.normalized:
+ spec = spec / window.pow(2.0).sum().sqrt()
+ return spec.transpose(-2, -1)
+
+ def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False):
+ win_length = spectrogram_config.stft_config.win_length or spectrogram_config.stft_config.n_fft
+ hop_length = spectrogram_config.stft_config.hop_length or win_length // 2
+ return (audio_lengths - win_length) // hop_length + 1
+
+ def _compute_audio_embed_size(self, audio_frames):
+ integer = audio_frames // self.audio_compression_rate
+ remainder = audio_frames % self.audio_compression_rate
+ result = integer + (remainder > 0).to(integer.dtype)
+
+ integer = result // self.audio_downsample_rate
+ remainder = result % self.audio_downsample_rate
+ result = integer + (remainder > 0).to(integer.dtype)
+
+ return result
+
+ def _postprocess_output(self, output, **kwargs):
+ feature_lengths = output["audio_features_mask"].sum(dim=-1)
+ feature_lengths = feature_lengths * self.audio_feat_stride
+ output["audio_embed_sizes"] = self._compute_audio_embed_size(feature_lengths)
+ return output
+
+
+__all__ = ["Phi4MultimodalAudioProcessor"]
diff --git a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py
index 9ce98251e50e..78d4727cbccd 100644
--- a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py
+++ b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py
@@ -11,271 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_phi4_multimodal import Phi4MultimodalAudioProcessor
-"""
-Processor class for Phi4Multimodal
-"""
-
-import numpy as np
-
-from ...audio_utils import AudioInput, mel_filter_bank
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...image_processing_utils import BatchFeature
-from ...utils import TensorType, is_torch_available, logging
-
-
-if is_torch_available():
- import torch
-
-
-logger = logging.get_logger(__name__)
-
-
-class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor):
- model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"]
-
- def __init__(
- self,
- feature_size: int = 80,
- sampling_rate: int = 16000,
- hop_length: int = 160,
- n_fft: int = 512,
- win_length: int = 400,
- preemphasis: float = 0.97,
- padding_value: float = 0.0,
- audio_compression_rate: int = 8,
- audio_downsample_rate: int = 1,
- audio_feat_stride: int = 1,
- mel_min_frequency: float = 0,
- mel_max_frequency: float = 7690,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
-
- self.hop_length = hop_length
- self.n_fft = n_fft
- self.win_length = win_length
- self.preemphasis = preemphasis
- self.padding_value = padding_value
- self.audio_compression_rate = audio_compression_rate
- self.audio_downsample_rate = audio_downsample_rate
- self.audio_feat_stride = audio_feat_stride
-
- self.mel_filters = mel_filter_bank(
- num_frequency_bins=self.n_fft // 2 + 1,
- num_mel_filters=self.feature_size,
- min_frequency=mel_min_frequency,
- max_frequency=mel_max_frequency,
- sampling_rate=self.sampling_rate,
- triangularize_in_mel_space=True,
- mel_scale="kaldi",
- )
-
- def __call__(
- self,
- raw_speech: AudioInput,
- sampling_rate: int | None = None,
- pad_to_multiple_of: int | None = None,
- padding: str | None = "longest",
- max_length: int | None = None,
- truncation: bool = False,
- return_tensors: str | TensorType | None = None,
- return_attention_mask: bool | None = True,
- device: str | None = "cpu",
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several audio sequence(s). Implementation uses PyTorch for
- the STFT computation if available, otherwise a slower NumPy based one.
-
- Args:
- raw_speech (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
- The sequence or batch of sequences to be processed. Each sequence can be a numpy array or PyTorch tensor.
- For batched inputs, sequences can be a list of numpy arrays or PyTorch tensors, or a single numpy array or
- PyTorch tensor with first dimension being the batch size.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- pad_to_multiple_of (`int`, *optional*, defaults to None):
- If set will pad the sequence to a multiple of the provided value.
- padding (`str`, *optional*, defaults to "longest"):
- Padding strategy. Can be "longest" to pad to the longest sequence in the batch, or a specific length.
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length.
- truncation (`bool`, *optional*, defaults to False):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of numpy arrays. Acceptable values are:
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- return_attention_mask (`bool`, *optional*, defaults to `True`):
- Whether to return the extracted audio input features' attention mask.
- device (`str`, *optional*, defaults to "cpu"):
- Specifies the device for computation of the audio features. (e.g., "cpu", "cuda")
-
- Returns:
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
- - **audio_input_features** -- Audio features extracted from the raw audio input, shape (batch_size, max_feature_length, feature_size).
- - **audio_lengths** -- Length of each audio sample in the batch, shape (batch_size,).
- - **audio_attention_mask** -- Attention mask for the audio input, shape (batch_size, max_feature_length).
- If `return_tensors` is not specified, the fields will be PyTorch tensors if PyTorch is available, otherwise NumPy arrays.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
- f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
- f" was sampled with {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- # Convert to torch tensor
- if isinstance(raw_speech, np.ndarray):
- raw_speech = torch.tensor(raw_speech)
- elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray):
- raw_speech = [torch.tensor(speech) for speech in raw_speech]
-
- is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
- if is_batched_torch and len(raw_speech.shape) > 2:
- logger.warning(
- f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
- "We will take the mean of the channels to convert to mono."
- )
- raw_speech = raw_speech.mean(-1)
-
- is_batched_sequence = isinstance(raw_speech, (list, tuple))
- if is_batched_sequence:
- for speech in raw_speech:
- if len(speech.shape) > 1:
- logger.warning(
- f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
- "We will take the mean of the channels to convert to mono."
- )
- speech = speech.mean(-1)
-
- if is_batched_torch or is_batched_sequence:
- raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
- else:
- raw_speech = [raw_speech[:, None].to(torch.float32)]
-
- audio_lengths = [len(speech) for speech in raw_speech]
-
- # convert into correct format for padding
- batched_speech = BatchFeature(data={"audio_input_features": raw_speech, "audio_lengths": audio_lengths})
- padded_inputs = self.pad(
- batched_speech,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_tensors="pt",
- )
- input_features = padded_inputs.audio_input_features.squeeze(-1)
- audio_lengths = padded_inputs.audio_lengths
-
- input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device)
-
- feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1
- feature_lengths = feature_lengths * self.audio_feat_stride
- audio_embed_sizes = self._compute_audio_embed_size(feature_lengths)
-
- feature_attention_mask = (
- torch.arange(0, feature_lengths.max()) if is_torch_available() else np.arange(0, feature_lengths.max())
- )
- feature_attention_mask = (
- feature_attention_mask[None, :] < feature_lengths[:, None] if len(feature_lengths) > 1 else None
- )
-
- data = {
- "audio_input_features": input_features,
- "audio_embed_sizes": audio_embed_sizes,
- }
- if feature_attention_mask is not None and return_attention_mask:
- data["audio_attention_mask"] = feature_attention_mask
-
- return BatchFeature(data=data, tensor_type=return_tensors)
-
-    # TODO: @eustlb, move this to audio_utils in a general spectrogram_batch function that handles torch and numpy
- def _torch_extract_fbank_features(
- self, waveform: "torch.FloatTensor", audio_lengths: "torch.Tensor", device: str = "cpu"
- ) -> "torch.FloatTensor":
- """
- Compute the log mel-scaled spectrogram of batched waveforms using PyTorch's FFT implementation.
-
- Args:
-            waveform (`torch.FloatTensor` of shape `(batch_size, max_audio_length)`):
- The batched waveforms.
- audio_lengths (`torch.Tensor` of shape `(batch_size,)`):
- The lengths of the waveforms along the max_audio_length dimension.
- device (`str`, *optional*, defaults to "cpu"):
- The device to run the computation on. (e.g., "cpu", "cuda")
-
- Returns:
- `torch.FloatTensor` of shape `(batch_size, max_feature_length, feature_size)`:
- The log mel-scaled spectrogram of the batched waveforms.
- """
- fft_window = torch.hamming_window(self.win_length, periodic=False, device=device, dtype=torch.float64)
-
- # batched implementation
- batch_size = waveform.shape[0]
- frames = waveform.unfold(-1, self.win_length, self.hop_length)
-
- # ---
-        # the unbatched (and unpadded) original implementation skips the last few audio values that can't be included in a frame
- # we need to ensure that the corresponding frames for the padded input also mask these values
- if batch_size > 1:
- frames = frames.clone()
- # concerned batch indices
- to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()]
- if to_mask_batch_idxs.numel() > 0:
- batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1
- batch_idxs_up = (audio_lengths[to_mask_batch_idxs] // self.hop_length) - 1
- offset_idx = batch_idxs_down.min()
- max_idx = batch_idxs_up.max()
-
- mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1)
- mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & (
- mask < (batch_idxs_up - offset_idx).unsqueeze(1)
- )
- mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length)
- masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0)
- frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames
- # ---
-
- # apply pre-emphasis first order filter on fft windows
- frames_prev = torch.roll(frames, 1, dims=-1)
- frames_prev[:, :, 0] = frames_prev[:, :, 1]
- frames = (frames - self.preemphasis * frames_prev) * 32768
-
- # apply fft
- S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1)
- S = S.view(frames.shape[0], -1, S.shape[-1])
- S = S.to(torch.complex64)
-
- spec = torch.abs(S)
- spec_power = spec**2
-
- # apply triangular mel filter bank
- mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
- log_spec = torch.clamp(spec_power @ mel_filters, min=1.0)
- log_spec = torch.log(log_spec)
-
- return log_spec
-
- def _compute_audio_embed_size(self, audio_frames):
- integer = audio_frames // self.audio_compression_rate
- remainder = audio_frames % self.audio_compression_rate
- result = integer + (remainder > 0).to(integer.dtype)
-
- integer = result // self.audio_downsample_rate
- remainder = result % self.audio_downsample_rate
- result = integer + (remainder > 0).to(integer.dtype) # qformer compression
-
- return result
+Phi4MultimodalFeatureExtractor = deprecated_feature_extractor(
+ Phi4MultimodalAudioProcessor, "Phi4MultimodalFeatureExtractor"
+)
__all__ = ["Phi4MultimodalFeatureExtractor"]
diff --git a/src/transformers/models/pop2piano/audio_processing_pop2piano.py b/src/transformers/models/pop2piano/audio_processing_pop2piano.py
new file mode 100644
index 000000000000..9cd546b15a59
--- /dev/null
+++ b/src/transformers/models/pop2piano/audio_processing_pop2piano.py
@@ -0,0 +1,34 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTE: Full Pop2Piano feature extraction requires the Essentia library for
+# beat detection (RhythmExtractor2013) and scipy for beat interpolation.
+# This audio processor provides the basic mel spectrogram configuration but
+# does not implement the complete beat-aligned segmentation pipeline.
+
+from ...audio_processing_backends import NumpyAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class Pop2PianoAudioProcessor(NumpyAudioBackend):
+ sample_rate = 22050
+ force_mono = True
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(n_fft=4096, hop_length=1024, power=2.0),
+ mel_scale_config=MelScaleConfig(n_mels=512, f_min=10.0, mel_scale="htk"),
+ log_mode="log10",
+ )
+
+
+__all__ = ["Pop2PianoAudioProcessor"]
diff --git a/src/transformers/models/pop2piano/feature_extraction_pop2piano.py b/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
index 4e770fcb1b71..3ab91ec37d43 100644
--- a/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
+++ b/src/transformers/models/pop2piano/feature_extraction_pop2piano.py
@@ -11,442 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Feature extractor class for Pop2Piano"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_pop2piano import Pop2PianoAudioProcessor
-import warnings
-
-import numpy
-import numpy as np
-
-from ...audio_utils import mel_filter_bank, spectrogram
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import (
- TensorType,
- is_essentia_available,
- is_librosa_available,
- is_scipy_available,
- logging,
- requires_backends,
-)
-from ...utils.import_utils import requires
-
-
-if is_essentia_available():
- import essentia.standard
-
-if is_librosa_available():
- import librosa
-
-if is_scipy_available():
- import scipy
-
-
-logger = logging.get_logger(__name__)
-
-
-@requires(backends=("essentia", "librosa", "scipy", "torch"))
-class Pop2PianoFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a Pop2Piano feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
-    This class extracts rhythm and preprocesses the audio before it is passed to the model. First, the audio is passed
-    to the `RhythmExtractor2013` algorithm, which extracts the beat_times and beat positions and estimates their
-    confidence as well as the tempo in bpm; beat_times is then interpolated to get beatsteps. Later, we calculate
-    extrapolated_beatsteps from it to be used in the tokenizer. Separately, the audio is resampled to
-    self.sampling_rate and preprocessed, and a log-mel spectrogram is then computed from it for our transformer model.
-
- Args:
- sampling_rate (`int`, *optional*, defaults to 22050):
- Target Sampling rate of audio signal. It's the sampling rate that we forward to the model.
- padding_value (`int`, *optional*, defaults to 0):
- Padding value used to pad the audio. Should correspond to silences.
- window_size (`int`, *optional*, defaults to 4096):
- Length of the window in samples to which the Fourier transform is applied.
- hop_length (`int`, *optional*, defaults to 1024):
- Step size between each window of the waveform, in samples.
- min_frequency (`float`, *optional*, defaults to 10.0):
- Lowest frequency that will be used in the log-mel spectrogram.
- feature_size (`int`, *optional*, defaults to 512):
- The feature dimension of the extracted features.
- num_bars (`int`, *optional*, defaults to 2):
- Determines interval between each sequence.
- """
-
- model_input_names = ["input_features", "beatsteps", "extrapolated_beatstep"]
-
- def __init__(
- self,
- sampling_rate: int = 22050,
- padding_value: int = 0,
- window_size: int = 4096,
- hop_length: int = 1024,
- min_frequency: float = 10.0,
- feature_size: int = 512,
- num_bars: int = 2,
- **kwargs,
- ):
- super().__init__(
- feature_size=feature_size,
- sampling_rate=sampling_rate,
- padding_value=padding_value,
- **kwargs,
- )
- self.sampling_rate = sampling_rate
- self.padding_value = padding_value
- self.window_size = window_size
- self.hop_length = hop_length
- self.min_frequency = min_frequency
- self.feature_size = feature_size
- self.num_bars = num_bars
- self.mel_filters = mel_filter_bank(
- num_frequency_bins=(self.window_size // 2) + 1,
- num_mel_filters=self.feature_size,
- min_frequency=self.min_frequency,
- max_frequency=float(self.sampling_rate // 2),
- sampling_rate=self.sampling_rate,
- norm=None,
- mel_scale="htk",
- )
-
- def mel_spectrogram(self, sequence: np.ndarray):
- """
- Generates MelSpectrogram.
-
- Args:
- sequence (`numpy.ndarray`):
- The sequence of which the mel-spectrogram will be computed.
- """
- mel_specs = []
- for seq in sequence:
- window = np.hanning(self.window_size + 1)[:-1]
- mel_specs.append(
- spectrogram(
- waveform=seq,
- window=window,
- frame_length=self.window_size,
- hop_length=self.hop_length,
- power=2.0,
- mel_filters=self.mel_filters,
- )
- )
- mel_specs = np.array(mel_specs)
-
- return mel_specs
-
- def extract_rhythm(self, audio: np.ndarray):
- """
-        This algorithm (`RhythmExtractor2013`) extracts the beat positions and estimates their confidence as well as
-        the tempo in bpm for an audio signal. For more information, please visit
- https://essentia.upf.edu/reference/std_RhythmExtractor2013.html .
-
- Args:
- audio(`numpy.ndarray`):
- raw audio waveform which is passed to the Rhythm Extractor.
- """
- requires_backends(self, ["essentia"])
- essentia_tracker = essentia.standard.RhythmExtractor2013(method="multifeature")
- bpm, beat_times, confidence, estimates, essentia_beat_intervals = essentia_tracker(audio)
-
- return bpm, beat_times, confidence, estimates, essentia_beat_intervals
-
- def interpolate_beat_times(
- self, beat_times: numpy.ndarray, steps_per_beat: numpy.ndarray, n_extend: numpy.ndarray
- ):
- """
- This method takes beat_times and then interpolates that using `scipy.interpolate.interp1d` and the output is
- then used to convert raw audio to log-mel-spectrogram.
-
- Args:
- beat_times (`numpy.ndarray`):
- beat_times is passed into `scipy.interpolate.interp1d` for processing.
- steps_per_beat (`int`):
-                used as a parameter to control the interpolation.
-            n_extend (`int`):
-                used as a parameter to control the interpolation.
- """
-
- requires_backends(self, ["scipy"])
- beat_times_function = scipy.interpolate.interp1d(
- np.arange(beat_times.size),
- beat_times,
- bounds_error=False,
- fill_value="extrapolate",
- )
-
- ext_beats = beat_times_function(
- np.linspace(0, beat_times.size + n_extend - 1, beat_times.size * steps_per_beat + n_extend)
- )
-
- return ext_beats
-
- def preprocess_mel(self, audio: np.ndarray, beatstep: np.ndarray):
- """
- Preprocessing for log-mel-spectrogram
-
- Args:
- audio (`numpy.ndarray` of shape `(audio_length, )` ):
- Raw audio waveform to be processed.
- beatstep (`numpy.ndarray`):
- Interpolated values of the raw audio. If beatstep[0] is greater than 0.0, then it will be shifted by
- the value at beatstep[0].
- """
-
- if audio is not None and len(audio.shape) != 1:
- raise ValueError(
- f"Expected `audio` to be a single channel audio input of shape `(n, )` but found shape {audio.shape}."
- )
- if beatstep[0] > 0.0:
- beatstep = beatstep - beatstep[0]
-
- num_steps = self.num_bars * 4
- num_target_steps = len(beatstep)
- extrapolated_beatstep = self.interpolate_beat_times(
- beat_times=beatstep, steps_per_beat=1, n_extend=(self.num_bars + 1) * 4 + 1
- )
-
- sample_indices = []
- max_feature_length = 0
- for i in range(0, num_target_steps, num_steps):
- start_idx = i
- end_idx = min(i + num_steps, num_target_steps)
- start_sample = int(extrapolated_beatstep[start_idx] * self.sampling_rate)
- end_sample = int(extrapolated_beatstep[end_idx] * self.sampling_rate)
- sample_indices.append((start_sample, end_sample))
- max_feature_length = max(max_feature_length, end_sample - start_sample)
- padded_batch = []
- for start_sample, end_sample in sample_indices:
- feature = audio[start_sample:end_sample]
- padded_feature = np.pad(
- feature,
- ((0, max_feature_length - feature.shape[0]),),
- "constant",
- constant_values=0,
- )
- padded_batch.append(padded_feature)
-
- padded_batch = np.asarray(padded_batch)
- return padded_batch, extrapolated_beatstep
-
- def _pad(self, features: np.ndarray, add_zero_line=True):
- features_shapes = [each_feature.shape for each_feature in features]
- attention_masks, padded_features = [], []
- for i, each_feature in enumerate(features):
- # To pad "input_features".
- if len(each_feature.shape) == 3:
- features_pad_value = max([*zip(*features_shapes)][1]) - features_shapes[i][1]
- attention_mask = np.ones(features_shapes[i][:2], dtype=np.int64)
- feature_padding = ((0, 0), (0, features_pad_value), (0, 0))
- attention_mask_padding = (feature_padding[0], feature_padding[1])
-
- # To pad "beatsteps" and "extrapolated_beatstep".
- else:
- each_feature = each_feature.reshape(1, -1)
- features_pad_value = max([*zip(*features_shapes)][0]) - features_shapes[i][0]
- attention_mask = np.ones(features_shapes[i], dtype=np.int64).reshape(1, -1)
- feature_padding = attention_mask_padding = ((0, 0), (0, features_pad_value))
-
- each_padded_feature = np.pad(each_feature, feature_padding, "constant", constant_values=self.padding_value)
- attention_mask = np.pad(
- attention_mask, attention_mask_padding, "constant", constant_values=self.padding_value
- )
-
- if add_zero_line:
- # if it is batched then we separate each examples using zero array
- zero_array_len = max([*zip(*features_shapes)][1])
-
- # we concatenate the zero array line here
- each_padded_feature = np.concatenate(
- [each_padded_feature, np.zeros([1, zero_array_len, self.feature_size])], axis=0
- )
- attention_mask = np.concatenate(
- [attention_mask, np.zeros([1, zero_array_len], dtype=attention_mask.dtype)], axis=0
- )
-
- padded_features.append(each_padded_feature)
- attention_masks.append(attention_mask)
-
- padded_features = np.concatenate(padded_features, axis=0).astype(np.float32)
- attention_masks = np.concatenate(attention_masks, axis=0).astype(np.int64)
-
- return padded_features, attention_masks
-
- def pad(
- self,
- inputs: BatchFeature,
- is_batched: bool,
- return_attention_mask: bool,
- return_tensors: str | TensorType | None = None,
- ):
- """
- Pads the inputs to same length and returns attention_mask.
-
- Args:
- inputs (`BatchFeature`):
- Processed audio features.
- is_batched (`bool`):
- Whether inputs are batched or not.
- return_attention_mask (`bool`):
- Whether to return attention mask or not.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- If nothing is specified, it will return list of `np.ndarray` arrays.
- Return:
- `BatchFeature` with attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep added
- to it:
- - **attention_mask** numpy.ndarray of shape `(batch_size, max_input_features_seq_length)` --
- Example :
- 1, 1, 1, 0, 0 (audio 1, also here it is padded to max length of 5 that's why there are 2 zeros at
- the end indicating they are padded)
-
- 0, 0, 0, 0, 0 (zero pad to separate audio 1 and 2)
-
- 1, 1, 1, 1, 1 (audio 2)
-
- 0, 0, 0, 0, 0 (zero pad to separate audio 2 and 3)
-
- 1, 1, 1, 1, 1 (audio 3)
- - **attention_mask_beatsteps** numpy.ndarray of shape `(batch_size, max_beatsteps_seq_length)`
- - **attention_mask_extrapolated_beatstep** numpy.ndarray of shape `(batch_size,
- max_extrapolated_beatstep_seq_length)`
- """
-
- processed_features_dict = {}
- for feature_name, feature_value in inputs.items():
- if feature_name == "input_features":
- padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=True)
- processed_features_dict[feature_name] = padded_feature_values
- if return_attention_mask:
- processed_features_dict["attention_mask"] = attention_mask
- else:
- padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=False)
- processed_features_dict[feature_name] = padded_feature_values
- if return_attention_mask:
- processed_features_dict[f"attention_mask_{feature_name}"] = attention_mask
-
- # If we are processing only one example, we should remove the zero array line since we don't need it to
- # separate examples from each other.
- if not is_batched and not return_attention_mask:
- processed_features_dict["input_features"] = processed_features_dict["input_features"][:-1, ...]
-
- outputs = BatchFeature(processed_features_dict, tensor_type=return_tensors)
-
- return outputs
-
- def __call__(
- self,
- audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- sampling_rate: int | list[int],
- steps_per_beat: int = 2,
- resample: bool | None = True,
- return_attention_mask: bool | None = False,
- return_tensors: str | TensorType | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model.
-
- Args:
- audio (`np.ndarray`, `List`):
- The audio or batch of audio to be processed. Each audio can be a numpy array, a list of float values, a
- list of numpy arrays or a list of list of float values.
- sampling_rate (`int`):
- The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- steps_per_beat (`int`, *optional*, defaults to 2):
- This is used in interpolating `beat_times`.
- resample (`bool`, *optional*, defaults to `True`):
- Determines whether to resample the audio to `sampling_rate` or not before processing. Must be True
- during inference.
- return_attention_mask (`bool` *optional*, defaults to `False`):
- Denotes if attention_mask for input_features, beatsteps and extrapolated_beatstep will be given as
- output or not. Automatically set to True for batched inputs.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- If nothing is specified, it will return list of `np.ndarray` arrays.
- """
-
- requires_backends(self, ["librosa"])
- is_batched = isinstance(audio, (list, tuple)) and isinstance(audio[0], (np.ndarray, tuple, list))
- if is_batched:
-            # This enables the user to process files of different sampling rates at the same time
- if not isinstance(sampling_rate, list):
- raise ValueError(
- "Please give sampling_rate of each audio separately when you are passing multiple raw_audios at the same time. "
- f"Received {sampling_rate}, expected [audio_1_sr, ..., audio_n_sr]."
- )
- return_attention_mask = True if return_attention_mask is None else return_attention_mask
- else:
- audio = [audio]
- sampling_rate = [sampling_rate]
- return_attention_mask = False if return_attention_mask is None else return_attention_mask
-
- batch_input_features, batch_beatsteps, batch_ext_beatstep = [], [], []
- for single_raw_audio, single_sampling_rate in zip(audio, sampling_rate):
- bpm, beat_times, confidence, estimates, essentia_beat_intervals = self.extract_rhythm(
- audio=single_raw_audio
- )
- beatsteps = self.interpolate_beat_times(beat_times=beat_times, steps_per_beat=steps_per_beat, n_extend=1)
-
- if self.sampling_rate != single_sampling_rate and self.sampling_rate is not None:
- if resample:
- # Change sampling_rate to self.sampling_rate
- single_raw_audio = librosa.core.resample(
- single_raw_audio,
- orig_sr=single_sampling_rate,
- target_sr=self.sampling_rate,
- res_type="kaiser_best",
- )
- else:
- warnings.warn(
- f"The sampling_rate of the provided audio is different from the target sampling_rate "
- f"of the Feature Extractor, {self.sampling_rate} vs {single_sampling_rate}. "
- f"In these cases it is recommended to use `resample=True` in the `__call__` method to "
- f"get the optimal behaviour."
- )
-
- single_sampling_rate = self.sampling_rate
- start_sample = int(beatsteps[0] * single_sampling_rate)
- end_sample = int(beatsteps[-1] * single_sampling_rate)
-
- input_features, extrapolated_beatstep = self.preprocess_mel(
- single_raw_audio[start_sample:end_sample], beatsteps - beatsteps[0]
- )
-
- mel_specs = self.mel_spectrogram(input_features.astype(np.float32))
-
- # apply np.log to get log mel-spectrograms
- log_mel_specs = np.log(np.clip(mel_specs, a_min=1e-6, a_max=None))
-
- input_features = np.transpose(log_mel_specs, (0, -1, -2))
-
- batch_input_features.append(input_features)
- batch_beatsteps.append(beatsteps)
- batch_ext_beatstep.append(extrapolated_beatstep)
-
- output = BatchFeature(
- {
- "input_features": batch_input_features,
- "beatsteps": batch_beatsteps,
- "extrapolated_beatstep": batch_ext_beatstep,
- }
- )
-
- output = self.pad(
- output,
- is_batched=is_batched,
- return_attention_mask=return_attention_mask,
- return_tensors=return_tensors,
- )
-
- return output
+Pop2PianoFeatureExtractor = deprecated_feature_extractor(Pop2PianoAudioProcessor, "Pop2PianoFeatureExtractor")
__all__ = ["Pop2PianoFeatureExtractor"]
diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
index dca38d2a1d01..49490e8b9bce 100644
--- a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py
@@ -19,7 +19,7 @@
"""Image processor class for Qwen2-VL."""
import math
-from typing import Iterable
+from collections.abc import Iterable
import torch
diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py
new file mode 100644
index 000000000000..de597178b446
--- /dev/null
+++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py
@@ -0,0 +1,94 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...audio_processing_backends import NumpyAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class SeamlessM4tAudioProcessor(NumpyAudioBackend):
+ sample_rate = 16000
+ force_mono = True
+ do_batch_spectrogram = False
+ stride = 2
+ pad_to_multiple_of = 2 # Align feature padding to stride
+
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=512,
+ win_length=400,
+ hop_length=160,
+ window_fn="povey",
+ power=2.0,
+ center=False,
+ periodic=False,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=80,
+ f_min=20.0,
+ f_max=8000.0,
+ mel_scale="kaldi",
+ triangularize_in_mel_space=True,
+ ),
+ log_mode="log",
+ preemphasis=0.97,
+ remove_dc_offset=True,
+ mel_floor=1.192092955078125e-07,
+ computation_dtype="float64",
+ )
+ waveform_scale = 32768.0
+
+ def extract_spectrogram(self, audio, **kwargs):
+ # Per-waveform fbank extraction returning (time, n_mels)
+ features = []
+ for waveform in audio:
+ waveform = np.squeeze(waveform) * self.waveform_scale
+ f = super().extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config)
+ features.append(f[0].T)
+ return features
+
+ def _postprocess_features(self, features, feature_lengths):
+ # Per-utterance mean/variance normalization (before padding)
+ normalized = []
+ for f in features:
+ mean = np.expand_dims(f.mean(axis=0), 0)
+ var = np.expand_dims(f.var(axis=0, ddof=1), 0)
+ normalized.append((f - mean) / np.sqrt(var + 1e-7))
+ return normalized
+
+ def _postprocess_output(self, output, feature_ranges=None, **kwargs):
+ features = output["audio_features"] # (batch, num_frames, num_channels)
+ batch_size, num_frames, num_channels = features.shape
+
+ # Stride concatenation
+ remainder = num_frames % self.stride
+ if remainder != 0:
+ features = features[:, :num_frames - remainder, :]
+ num_frames = num_frames - remainder
+
+ output["audio_features"] = features.reshape(batch_size, num_frames // self.stride, num_channels * self.stride)
+
+ # Adjust mask for stride
+ if "audio_features_mask" in output:
+ mask = output["audio_features_mask"]
+ if remainder != 0:
+ mask = mask[:, :num_frames]
+ indices = np.arange(0, num_frames)
+ output["audio_features_mask"] = mask[:, indices % self.stride == 1]
+
+ return output
+
+
+__all__ = ["SeamlessM4tAudioProcessor"]
diff --git a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py
index 1b18dcc33404..174bc72baa16 100644
--- a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py
+++ b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py
@@ -11,295 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Feature extractor class for SeamlessM4T
-"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_seamless_m4t import SeamlessM4tAudioProcessor
-import numpy as np
-
-from ...utils import is_torch_available
-
-
-if is_torch_available():
- import torch
-
-from ...audio_utils import mel_filter_bank, spectrogram, window_function
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a SeamlessM4T feature extractor.
-
- This feature extractor inherits from [`SequenceFeatureExtractor`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- This class extracts mel-filter bank features from raw speech.
-
- Args:
- feature_size (`int`, *optional*, defaults to 80):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- num_mel_bins (`int`, *optional*, defaults to 80):
- Number of Mel-frequency bins.
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding vectors.
- stride (`int`, *optional*, defaults to 2):
- Stride used to reshape audios from shape (batch_size,num_frames,num_mel_bins) to
- (batch_size,num_frames//stride,num_mel_bins*stride).
- """
-
- model_input_names = ["input_features", "attention_mask"]
-
- def __init__(
- self,
- feature_size=80,
- sampling_rate=16000,
- num_mel_bins=80,
- padding_value=0.0,
- stride=2,
- **kwargs,
- ):
- self.num_mel_bins = num_mel_bins
- self.return_attention_mask = True
- self.stride = stride
-
- mel_filters = mel_filter_bank(
- num_frequency_bins=257,
- num_mel_filters=self.num_mel_bins,
- min_frequency=20,
- max_frequency=sampling_rate // 2,
- sampling_rate=sampling_rate,
- norm=None,
- mel_scale="kaldi",
- triangularize_in_mel_space=True,
- )
-
- self.mel_filters = mel_filters
- self.window = window_function(400, "povey", periodic=False)
-
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
-
- @staticmethod
- # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
- def zero_mean_unit_var_norm(
- input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0
- ) -> list[np.ndarray]:
- """
- Every array in the list is normalized to have zero mean and unit variance
- """
- if attention_mask is not None:
- attention_mask = np.array(attention_mask, np.int32)
- normed_input_values = []
-
- for vector, length in zip(input_values, attention_mask.sum(-1)):
- normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
- if length < normed_slice.shape[0]:
- normed_slice[length:] = padding_value
-
- normed_input_values.append(normed_slice)
- else:
- normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
-
- return normed_input_values
-
- def _extract_fbank_features(
- self,
- waveform: np.ndarray,
- ) -> np.ndarray:
- """
- Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
- and hence the waveform should not be normalized before feature extraction.
- """
- # by default, it extracts the left channel if stereo
- if len(waveform.shape) == 2:
- waveform = waveform[0]
-
- waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed integers
- features = spectrogram(
- waveform,
- self.window,
- frame_length=400,
- hop_length=160,
- fft_length=512,
- power=2.0,
- center=False,
- preemphasis=0.97,
- mel_filters=self.mel_filters,
- log_mel="log",
- mel_floor=1.192092955078125e-07,
- remove_dc_offset=True,
- ).T
- return features
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- padding: bool | str | PaddingStrategy = True,
- pad_to_multiple_of: int | None = 2,
- max_length: int | None = None,
- truncation: bool = False,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- return_attention_mask: bool | None = None,
- do_normalize_per_mel_bins: bool | None = True,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_speech (`np.ndarray`, `torch.Tensor`, `list[float]`, `list[np.ndarray]`, `list[torch.Tensor]`,
- `list[list[float]]`, `list[list[list[float]]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array,
- a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors,
- a list of list of float values or a list of a list of list of float values.
- If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `list[float]`, `raw_speech` is
- considered a single-channel, single-sample sound. In all other cases, the first dimension of
- `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `list[...]`,
- corresponds to the number of samples in the batch, and the number of channels
- (i.e. mono or stereo character) is derived from the other dimensions
- (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches).
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence is provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- pad_to_multiple_of (`int`, *optional*, defaults to 2):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- truncation (`bool`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
-
- For SeamlessM4T models, `attention_mask` should always be passed for batched inference, to avoid subtle
- bugs.
-
-
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- do_normalize_per_mel_bins (`bool`, *optional*, defaults to `True`):
- Whether or not to zero-mean unit-variance normalize the input per mel-channel.
- kwargs (*optional*):
- Remaining dictionary of keyword arguments that will be passed to the tokenizer or the feature
- extractor.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- return_attention_mask = (
- return_attention_mask if return_attention_mask is not None else self.return_attention_mask
- )
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- if is_batched_numpy and len(raw_speech.shape) > 3:
- raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}")
-
- acceptable_types = (
- (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list)
- )
- is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types))
- )
-
- if is_batched:
- raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
- elif not is_batched and not isinstance(raw_speech, np.ndarray):
- raw_speech = np.asarray(raw_speech, dtype=np.float32)
- elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
- raw_speech = raw_speech.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_speech = [raw_speech]
-
- # extract fbank features
- features = [self._extract_fbank_features(waveform) for waveform in raw_speech]
-
- if do_normalize_per_mel_bins:
- # torch defaults to ddof=1, and numpy defaults to ddof=0
- features = [
- (x - np.expand_dims(x.mean(0), 0)) / np.sqrt(np.expand_dims(x.var(0, ddof=1), 0) + 1e-7)
- for x in features
- ]
-
- # convert into correct format for padding
- encoded_inputs = BatchFeature({"input_features": features})
-
- padded_inputs = self.pad(
- encoded_inputs,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=True,
- return_tensors="np",
- )
-
- # SeamlessM4T needs to process extracted features
- input_features = padded_inputs.get("input_features")
- attention_mask = padded_inputs.pop("attention_mask")
-
- batch_size, num_frames, num_channels = input_features.shape
-
- remainder = num_frames % self.stride
- if remainder != 0:
- input_features = input_features[:, : num_frames - remainder, :]
- attention_mask = attention_mask[:, : num_frames - remainder]
-
- input_features = np.reshape(
- input_features, (batch_size, num_frames // self.stride, num_channels * self.stride)
- )
-
- indices = np.arange(0, num_frames - remainder)
- attention_mask = attention_mask[:, indices % self.stride == 1]
-
- padded_inputs["input_features"] = input_features
- if return_attention_mask:
- padded_inputs["attention_mask"] = attention_mask
-
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
+SeamlessM4TFeatureExtractor = deprecated_feature_extractor(SeamlessM4tAudioProcessor, "SeamlessM4TFeatureExtractor")
__all__ = ["SeamlessM4TFeatureExtractor"]
diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py
new file mode 100644
index 000000000000..5f9717738982
--- /dev/null
+++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py
@@ -0,0 +1,88 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...audio_processing_backends import NumpyAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class SpeechToTextAudioProcessor(NumpyAudioBackend):
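+    r"""
+    Audio processor for Speech2Text. Extracts Kaldi-style log-mel filter bank features from raw mono speech
+    sampled at 16 kHz and applies utterance-level cepstral mean and variance normalization (CMVN).
+    """
+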
+ sample_rate = 16000
+ force_mono = True
+ do_batch_spectrogram = False
+
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=512,
+ win_length=400,
+ hop_length=160,
+ window_fn="povey",
+ power=2.0,
+ center=False,
+ periodic=False,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=80,
+ f_min=20.0,
+ f_max=8000.0,
+ mel_scale="kaldi",
+ triangularize_in_mel_space=True,
+ ),
+ log_mode="log",
+ preemphasis=0.97,
+ remove_dc_offset=True,
+ mel_floor=1.192092955078125e-07,
+ )
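+    # Kaldi compliance: waveforms are scaled to the 16-bit signed integer range before fbank extraction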
+ waveform_scale = 32768.0
+
+ def __init__(self, normalize_means=True, normalize_vars=True, **kwargs):
+ super().__init__(**kwargs)
+ self.normalize_means = normalize_means
+ self.normalize_vars = normalize_vars
+
+ def _extract_fbank_features(self, waveform):
+ """Extract log-mel filterbank features for a single waveform."""
+ waveform = waveform * self.waveform_scale
+ return self._kaldi_fbank(waveform, num_mel_bins=80)
+
+ def extract_spectrogram(self, audio, **kwargs):
+ # Per-waveform fbank extraction returning (time, n_mels)
+ return [self._extract_fbank_features(waveform) for waveform in audio]
+
+ @staticmethod
+ def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, padding_value=0.0):
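+        """Utterance-level cepstral mean and variance normalization computed over the unpadded frames only."""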
+ if normalize_means:
+ mean = x[:input_length].mean(axis=0)
+ x = np.subtract(x, mean)
+ if normalize_vars:
+ std = x[:input_length].std(axis=0)
+ x = np.divide(x, std)
+ if input_length < x.shape[0]:
+ x[input_length:] = padding_value
+ return x.astype(np.float32)
+
+ def _postprocess_output(self, output, feature_ranges=None, **kwargs):
+ # Apply utterance CMVN normalization on the padded, stacked features
+ features = output["audio_features"] # (batch, time, n_mels)
+ normalized = []
+ for i, (start, end) in enumerate(feature_ranges):
+ length = end - start
+ normalized.append(
+ self.utterance_cmvn(features[i], length, self.normalize_means, self.normalize_vars, self.padding_value)
+ )
+ output["audio_features"] = np.stack(normalized)
+ return output
+
+
+__all__ = ["SpeechToTextAudioProcessor"]
diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
index 9685e9be0134..584afc35f229 100644
--- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py
@@ -11,301 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Feature extractor class for Speech2Text
-"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_speech_to_text import SpeechToTextAudioProcessor
-import numpy as np
-
-from ...audio_utils import mel_filter_bank, spectrogram, window_function
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, is_speech_available, logging
-
-
-if is_speech_available():
- import torch
- import torchaudio.compliance.kaldi as ta_kaldi
-
-logger = logging.get_logger(__name__)
-
-
-class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a Speech2Text feature extractor.
-
- This feature extractor inherits from [`Speech2TextFeatureExtractor`] which contains most of the main methods. Users
- should refer to this superclass for more information regarding those methods.
-
- This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy
- otherwise, and applies utterance-level cepstral mean and variance normalization to the extracted features.
-
- Args:
- feature_size (`int`, *optional*, defaults to 80):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- num_mel_bins (`int`, *optional*, defaults to 80):
- Number of Mel-frequency bins.
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding vectors.
- dither (`float`, *optional*, defaults to 0.0):
- Adds dithering. In other words, adds a small Gaussian noise to each frame.
- E.g. use 4.0 to add dithering with a normal distribution centered
- around 0.0 with standard deviation 4.0 (assuming [-32k,+32k] range of kaldi waveform).
- The value 0.0 means no dithering.
- Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank
- values for signals with hard-zero sections, when VAD cutoff is present in the signal.
- do_ceptral_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
- normalize_means (`bool`, *optional*, defaults to `True`):
- Whether or not to zero-mean normalize the extracted features.
- normalize_vars (`bool`, *optional*, defaults to `True`):
- Whether or not to unit-variance normalize the extracted features.
- """
-
- model_input_names = ["input_features", "attention_mask"]
-
- def __init__(
- self,
- feature_size=80,
- sampling_rate=16000,
- num_mel_bins=80,
- padding_value=0.0,
- dither=0.0,
- do_ceptral_normalize=True,
- normalize_means=True,
- normalize_vars=True,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.num_mel_bins = num_mel_bins
- self.dither = dither
- self.do_ceptral_normalize = do_ceptral_normalize
- self.normalize_means = normalize_means
- self.normalize_vars = normalize_vars
- self.return_attention_mask = True
-
- if not is_speech_available():
- mel_filters = mel_filter_bank(
- num_frequency_bins=257,
- num_mel_filters=self.num_mel_bins,
- min_frequency=20,
- max_frequency=sampling_rate // 2,
- sampling_rate=sampling_rate,
- norm=None,
- mel_scale="kaldi",
- triangularize_in_mel_space=True,
- )
-
- self.mel_filters = mel_filters
- self.window = window_function(400, "povey", periodic=False)
-
- def _extract_fbank_features(
- self,
- waveform: np.ndarray,
- ) -> np.ndarray:
- """
- Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
- and hence the waveform should not be normalized before feature extraction.
- """
- waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers
- if is_speech_available():
- waveform = torch.from_numpy(waveform).unsqueeze(0)
- features = ta_kaldi.fbank(
- waveform,
- dither=self.dither,
- num_mel_bins=self.num_mel_bins,
- sample_frequency=self.sampling_rate,
- )
- features = features.numpy()
- else:
- waveform = np.squeeze(waveform)
- features = spectrogram(
- waveform,
- self.window,
- frame_length=400,
- hop_length=160,
- fft_length=512,
- power=2.0,
- center=False,
- dither=self.dither,
- preemphasis=0.97,
- mel_filters=self.mel_filters,
- log_mel="log",
- mel_floor=1.192092955078125e-07,
- remove_dc_offset=True,
- ).T
- return features
-
- @staticmethod
- def utterance_cmvn(
- x: np.ndarray,
- input_length: int,
- normalize_means: bool | None = True,
- normalize_vars: bool | None = True,
- padding_value: float = 0.0,
- ) -> np.ndarray:
- # make sure we normalize float32 arrays
- if normalize_means:
- mean = x[:input_length].mean(axis=0)
- x = np.subtract(x, mean)
- if normalize_vars:
- std = x[:input_length].std(axis=0)
- x = np.divide(x, std)
-
- if input_length < x.shape[0]:
- x[input_length:] = padding_value
-
- # make sure array is in float32
- x = x.astype(np.float32)
-
- return x
-
- def normalize(
- self, input_features: list[np.ndarray], attention_mask: np.ndarray | None = None
- ) -> list[np.ndarray]:
- lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features]
- return [
- self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value)
- for x, n in zip(input_features, lengths)
- ]
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- padding: bool | str | PaddingStrategy = False,
- max_length: int | None = None,
- truncation: bool = False,
- pad_to_multiple_of: int | None = None,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- return_attention_mask: bool | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- truncation (`bool`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
-
- For Speech2TextTransformer models, `attention_mask` should always be passed for batched inference, to
- avoid subtle bugs.
-
-
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values / vectors.
- """
-
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- if is_batched_numpy and len(raw_speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
- is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
- elif not is_batched and not isinstance(raw_speech, np.ndarray):
- raw_speech = np.asarray(raw_speech, dtype=np.float32)
- elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
- raw_speech = raw_speech.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_speech = [raw_speech]
-
- # extract fbank features
- features = [self._extract_fbank_features(waveform) for waveform in raw_speech]
-
- # convert into correct format for padding
- encoded_inputs = BatchFeature({"input_features": features})
-
- padded_inputs = self.pad(
- encoded_inputs,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- **kwargs,
- )
-
- # make sure list is in array format
- input_features = padded_inputs.get("input_features")
- if isinstance(input_features[0], list):
- padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
-
- attention_mask = padded_inputs.get("attention_mask")
- if attention_mask is not None:
- padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
-
- # Utterance-level cepstral mean and variance normalization
- if self.do_ceptral_normalize:
- attention_mask = (
- np.array(attention_mask, dtype=np.int32)
- if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
- else None
- )
- padded_inputs["input_features"] = self.normalize(
- padded_inputs["input_features"], attention_mask=attention_mask
- )
-
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
+Speech2TextFeatureExtractor = deprecated_feature_extractor(SpeechToTextAudioProcessor, "Speech2TextFeatureExtractor")
__all__ = ["Speech2TextFeatureExtractor"]
diff --git a/src/transformers/models/speecht5/audio_processing_speecht5.py b/src/transformers/models/speecht5/audio_processing_speecht5.py
new file mode 100644
index 000000000000..4fc4c2226d35
--- /dev/null
+++ b/src/transformers/models/speecht5/audio_processing_speecht5.py
@@ -0,0 +1,23 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...audio_processing_backends import TorchAudioBackend
+
+
+class SpeechT5AudioProcessor(TorchAudioBackend):
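+    r"""
+    Audio processor for SpeechT5. Prepares raw mono speech waveforms sampled at 16 kHz for the model.
+    """
+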
+ sample_rate = 16000
+ force_mono = True
+
+
+__all__ = ["SpeechT5AudioProcessor"]
diff --git a/src/transformers/models/speecht5/feature_extraction_speecht5.py b/src/transformers/models/speecht5/feature_extraction_speecht5.py
index 5b9ca2e1f954..1aece171a6f3 100644
--- a/src/transformers/models/speecht5/feature_extraction_speecht5.py
+++ b/src/transformers/models/speecht5/feature_extraction_speecht5.py
@@ -11,364 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Feature extractor class for SpeechT5."""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_speecht5 import SpeechT5AudioProcessor
-from typing import Any
-
-import numpy as np
-
-from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a SpeechT5 feature extractor.
-
- This class can pre-process a raw speech signal by (optionally) normalizing to zero-mean unit-variance, for use by
- the SpeechT5 speech encoder prenet.
-
- This class can also extract log-mel filter bank features from raw speech, for use by the SpeechT5 speech decoder
- prenet.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values.
- do_normalize (`bool`, *optional*, defaults to `False`):
- Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
- improve the performance for some models.
- num_mel_bins (`int`, *optional*, defaults to 80):
- The number of mel-frequency bins in the extracted spectrogram features.
- hop_length (`int`, *optional*, defaults to 16):
- Number of ms between windows. Otherwise referred to as "shift" in many papers.
- win_length (`int`, *optional*, defaults to 64):
- Number of ms per window.
- win_function (`str`, *optional*, defaults to `"hann_window"`):
- Name for the window function used for windowing, must be accessible via `torch.{win_function}`
- fmin (`float`, *optional*, defaults to 80):
- Minimum mel frequency in Hz.
- fmax (`float`, *optional*, defaults to 7600):
- Maximum mel frequency in Hz.
- mel_floor (`float`, *optional*, defaults to 1e-10):
- Minimum value of mel frequency banks..
- return_attention_mask (`bool`, *optional*, defaults to `True`):
- Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`.
- """
-
- model_input_names = ["input_values", "attention_mask"]
-
- def __init__(
- self,
- feature_size: int = 1,
- sampling_rate: int = 16000,
- padding_value: float = 0.0,
- do_normalize: bool = False,
- num_mel_bins: int = 80,
- hop_length: int = 16,
- win_length: int = 64,
- win_function: str = "hann_window",
- fmin: float = 80,
- fmax: float = 7600,
- mel_floor: float = 1e-10,
- return_attention_mask: bool = True,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.do_normalize = do_normalize
- self.return_attention_mask = return_attention_mask
-
- self.num_mel_bins = num_mel_bins
- self.hop_length = hop_length
- self.win_length = win_length
- self.win_function = win_function
- self.fmin = fmin
- self.fmax = fmax
- self.mel_floor = mel_floor
-
- self.sample_size = win_length * sampling_rate // 1000
- self.sample_stride = hop_length * sampling_rate // 1000
- self.n_fft = optimal_fft_length(self.sample_size)
- self.n_freqs = (self.n_fft // 2) + 1
-
- self.window = window_function(window_length=self.sample_size, name=self.win_function, periodic=True)
-
- self.mel_filters = mel_filter_bank(
- num_frequency_bins=self.n_freqs,
- num_mel_filters=self.num_mel_bins,
- min_frequency=self.fmin,
- max_frequency=self.fmax,
- sampling_rate=self.sampling_rate,
- norm="slaney",
- mel_scale="slaney",
- )
-
- @staticmethod
- # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
- def zero_mean_unit_var_norm(
- input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0
- ) -> list[np.ndarray]:
- """
- Every array in the list is normalized to have zero mean and unit variance
- """
- if attention_mask is not None:
- attention_mask = np.array(attention_mask, np.int32)
- normed_input_values = []
-
- for vector, length in zip(input_values, attention_mask.sum(-1)):
- normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
- if length < normed_slice.shape[0]:
- normed_slice[length:] = padding_value
-
- normed_input_values.append(normed_slice)
- else:
- normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
-
- return normed_input_values
-
- def _extract_mel_features(
- self,
- one_waveform: np.ndarray,
- ) -> np.ndarray:
- """
- Extracts log-mel filterbank features for one waveform array (unbatched).
- """
- log_mel_spec = spectrogram(
- one_waveform,
- window=self.window,
- frame_length=self.sample_size,
- hop_length=self.sample_stride,
- fft_length=self.n_fft,
- mel_filters=self.mel_filters,
- mel_floor=self.mel_floor,
- log_mel="log10",
- )
- return log_mel_spec.T
-
- def __call__(
- self,
- audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]] | None = None,
- audio_target: np.ndarray | list[float] | list[np.ndarray] | list[list[float]] | None = None,
- padding: bool | str | PaddingStrategy = False,
- max_length: int | None = None,
- truncation: bool = False,
- pad_to_multiple_of: int | None = None,
- return_attention_mask: bool | None = None,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Pass in a value for `audio` to extract waveform features. Pass in a value for `audio_target` to extract log-mel
- spectrogram features.
-
- Args:
- audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`, *optional*):
- The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. This outputs waveform features. Must
- be mono channel audio, not stereo, i.e. single float per timestep.
- audio_target (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`, *optional*):
- The sequence or batch of sequences to be processed as targets. Each sequence can be a numpy array, a
- list of float values, a list of numpy arrays or a list of list of float values. This outputs log-mel
- spectrogram features.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- truncation (`bool`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `audio` or `audio_target` input was sampled. It is strongly recommended
- to pass `sampling_rate` at the forward call to prevent silent errors.
- """
- if audio is None and audio_target is None:
- raise ValueError("You must provide either `audio` or `audio_target` values.")
-
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- if audio is not None:
- inputs = self._process_audio(
- audio,
- False,
- padding,
- max_length,
- truncation,
- pad_to_multiple_of,
- return_attention_mask,
- return_tensors,
- **kwargs,
- )
- else:
- inputs = None
-
- if audio_target is not None:
- inputs_target = self._process_audio(
- audio_target,
- True,
- padding,
- max_length,
- truncation,
- pad_to_multiple_of,
- return_attention_mask,
- return_tensors,
- **kwargs,
- )
-
- if inputs is None:
- return inputs_target
- else:
- inputs["labels"] = inputs_target["input_values"]
- decoder_attention_mask = inputs_target.get("attention_mask")
- if decoder_attention_mask is not None:
- inputs["decoder_attention_mask"] = decoder_attention_mask
-
- return inputs
-
- def _process_audio(
- self,
- speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- is_target: bool = False,
- padding: bool | str | PaddingStrategy = False,
- max_length: int | None = None,
- truncation: bool = False,
- pad_to_multiple_of: int | None = None,
- return_attention_mask: bool | None = None,
- return_tensors: str | TensorType | None = None,
- **kwargs,
- ) -> BatchFeature:
- is_batched_numpy = isinstance(speech, np.ndarray) and len(speech.shape) > 1
- if is_batched_numpy and len(speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
- is_batched = is_batched_numpy or (
- isinstance(speech, (list, tuple)) and (isinstance(speech[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- speech = [np.asarray(speech, dtype=np.float32) for speech in speech]
- elif not is_batched and not isinstance(speech, np.ndarray):
- speech = np.asarray(speech, dtype=np.float32)
- elif isinstance(speech, np.ndarray) and speech.dtype is np.dtype(np.float64):
- speech = speech.astype(np.float32)
-
- # always return batch
- if not is_batched:
- speech = [speech]
-
- # needed to make pad() work on spectrogram inputs
- feature_size_hack = self.feature_size
-
- # convert into correct format for padding
- if is_target:
- features = [self._extract_mel_features(waveform) for waveform in speech]
- encoded_inputs = BatchFeature({"input_values": features})
- self.feature_size = self.num_mel_bins
- else:
- encoded_inputs = BatchFeature({"input_values": speech})
-
- padded_inputs = self.pad(
- encoded_inputs,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- **kwargs,
- )
-
- self.feature_size = feature_size_hack
-
- # convert input values to correct format
- input_values = padded_inputs["input_values"]
- if not isinstance(input_values[0], np.ndarray):
- padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
- elif (
- not isinstance(input_values, np.ndarray)
- and isinstance(input_values[0], np.ndarray)
- and input_values[0].dtype is np.dtype(np.float64)
- ):
- padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
- elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64):
- padded_inputs["input_values"] = input_values.astype(np.float32)
-
- # convert attention_mask to correct format
- attention_mask = padded_inputs.get("attention_mask")
- if attention_mask is not None:
- padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
-
- # zero-mean and unit-variance normalization
- if not is_target and self.do_normalize:
- attention_mask = (
- attention_mask
- if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
- else None
- )
- padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
- padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
- )
-
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
-
- def to_dict(self) -> dict[str, Any]:
- output = super().to_dict()
-
- # Don't serialize these as they are derived from the other properties.
- names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"]
- for name in names:
- if name in output:
- del output[name]
-
- return output
+SpeechT5FeatureExtractor = deprecated_feature_extractor(SpeechT5AudioProcessor, "SpeechT5FeatureExtractor")
__all__ = ["SpeechT5FeatureExtractor"]
diff --git a/src/transformers/models/univnet/audio_processing_univnet.py b/src/transformers/models/univnet/audio_processing_univnet.py
new file mode 100644
index 000000000000..6bd6e16b2af3
--- /dev/null
+++ b/src/transformers/models/univnet/audio_processing_univnet.py
@@ -0,0 +1,89 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from ...audio_processing_backends import NumpyAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class UnivNetAudioProcessor(NumpyAudioBackend):
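+    r"""
+    Audio processor for UnivNet. Extracts log-mel spectrogram features from raw mono speech sampled at 24 kHz using
+    an STFT implementation that follows Tacotron 2 and HiFi-GAN, with optional Tacotron 2-style linear normalization.
+    """
+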
+ sample_rate = 24000
+ force_mono = True
+ mask_level = "audio"
+ mel_floor = 1e-9
+ compression_clip_val = 1e-5
+ compression_factor = 1.0
+ do_normalize = False
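+    # Tacotron 2-style linear normalization bounds (original values from the Tacotron 2 implementation)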
+ normalize_min = -11.512925148010254
+ normalize_max = 2.3143386840820312
+ max_length_s = 10
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=1024,
+ hop_length=256,
+ center=False,
+ window_fn="hann",
+ periodic=True,
+ power=1.0,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=100,
+ f_min=0.0,
+ f_max=12000.0,
+ mel_scale="slaney",
+ norm="slaney",
+ ),
+ log_mode="log",
+ mel_floor=1e-5,
+ computation_dtype="float64",
+ )
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.num_max_samples = self.max_length_s * self.sample_rate
+
+ def _stft(self, audio, *, spectrogram_config, **kwargs):
+ # UnivNet uses reflect padding with (n_fft - hop_length) / 2 instead of center padding
+ stft_cfg = spectrogram_config.stft_config
+ pad_amount = int((stft_cfg.n_fft - stft_cfg.hop_length) / 2)
+ if audio.ndim > 1:
+ audio = np.pad(audio, ((0, 0), (pad_amount, pad_amount)), mode="reflect")
+ else:
+ audio = np.pad(audio, (pad_amount, pad_amount), mode="reflect")
+ return super()._stft(audio, spectrogram_config=spectrogram_config, **kwargs)
+
+ def _compute_magnitudes(self, stft_out, power, spectrogram_config=None):
+ # UnivNet adds mel_floor inside the sqrt: sqrt(real² + imag² + mel_floor)
+ return np.sqrt(np.real(stft_out) ** 2 + np.imag(stft_out) ** 2 + self.mel_floor)
+
+ def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs):
+ # UnivNet applies mel filterbank without a floor
+ return np.matmul(self.mel_filters.T, features)
+
+ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs):
+ features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs)
+ if self.do_normalize:
+ features = 2 * ((features - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1
+ return features
+
+ def extract_spectrogram(self, audio, *, spectrogram_config, **kwargs):
+ features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs)
+ # Transpose from (..., n_mels, frames) to (..., frames, n_mels)
+ if isinstance(features, list):
+ return [np.swapaxes(f, -2, -1) for f in features]
+ return np.swapaxes(features, -2, -1)
+
+
+__all__ = ["UnivNetAudioProcessor"]
diff --git a/src/transformers/models/univnet/feature_extraction_univnet.py b/src/transformers/models/univnet/feature_extraction_univnet.py
index 84e9420a0f75..73ae758ee708 100644
--- a/src/transformers/models/univnet/feature_extraction_univnet.py
+++ b/src/transformers/models/univnet/feature_extraction_univnet.py
@@ -11,448 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Feature extractor class for UnivNetModel."""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_univnet import UnivNetAudioProcessor
-from typing import Any
-
-import numpy as np
-
-from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class UnivNetFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a UnivNet feature extractor.
-
- This class extracts log-mel-filter bank features from raw speech using the short time Fourier Transform (STFT). The
- STFT implementation follows that of TacoTron 2 and Hifi-GAN.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 24000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value to pad with when applying the padding strategy defined by the `padding` argument to
- [`UnivNetFeatureExtractor.__call__`]. Should correspond to audio silence. The `pad_end` argument to
- `__call__` will also use this padding value.
- do_normalize (`bool`, *optional*, defaults to `False`):
- Whether to perform Tacotron 2 normalization on the input. Normalizing can help to significantly improve the
- performance for some models.
- num_mel_bins (`int`, *optional*, defaults to 100):
- The number of mel-frequency bins in the extracted spectrogram features. This should match
- `UnivNetModel.config.num_mel_bins`.
- hop_length (`int`, *optional*, defaults to 256):
- The direct number of samples between sliding windows. Otherwise referred to as "shift" in many papers. Note
- that this is different from other audio feature extractors such as [`SpeechT5FeatureExtractor`] which take
- the `hop_length` in ms.
- win_length (`int`, *optional*, defaults to 1024):
- The direct number of samples for each sliding window. Note that this is different from other audio feature
- extractors such as [`SpeechT5FeatureExtractor`] which take the `win_length` in ms.
- win_function (`str`, *optional*, defaults to `"hann_window"`):
- Name for the window function used for windowing, must be accessible via `torch.{win_function}`
- filter_length (`int`, *optional*, defaults to 1024):
- The number of FFT components to use. If `None`, this is determined using
- `transformers.audio_utils.optimal_fft_length`.
- max_length_s (`int`, *optional*, defaults to 10):
- The maximum input length of the model in seconds. This is used to pad the audio.
- fmin (`float`, *optional*, defaults to 0.0):
- Minimum mel frequency in Hz.
- fmax (`float`, *optional*):
- Maximum mel frequency in Hz. If not set, defaults to `sampling_rate / 2`.
- mel_floor (`float`, *optional*, defaults to 1e-09):
- Minimum value of mel frequency banks. Note that the way [`UnivNetFeatureExtractor`] uses `mel_floor` is
- different than in [`transformers.audio_utils.spectrogram`].
- center (`bool`, *optional*, defaults to `False`):
- Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
- `t` will start at time `t * hop_length`.
- compression_factor (`float`, *optional*, defaults to 1.0):
- The multiplicative compression factor for dynamic range compression during spectral normalization.
- compression_clip_val (`float`, *optional*, defaults to 1e-05):
- The clip value applied to the waveform before applying dynamic range compression during spectral
- normalization.
- normalize_min (`float`, *optional*, defaults to -11.512925148010254):
- The min value used for Tacotron 2-style linear normalization. The default is the original value from the
- Tacotron 2 implementation.
- normalize_max (`float`, *optional*, defaults to 2.3143386840820312):
- The max value used for Tacotron 2-style linear normalization. The default is the original value from the
- Tacotron 2 implementation.
- model_in_channels (`int`, *optional*, defaults to 64):
- The number of input channels to the [`UnivNetModel`] model. This should match
- `UnivNetModel.config.model_in_channels`.
- pad_end_length (`int`, *optional*, defaults to 10):
- If padding the end of each waveform, the number of spectrogram frames worth of samples to append. The
- number of appended samples will be `pad_end_length * hop_length`.
- return_attention_mask (`bool`, *optional*, defaults to `True`):
- Whether or not [`~UnivNetFeatureExtractor.__call__`] should return `attention_mask`.
- """
-
- model_input_names = ["input_features", "noise_sequence", "padding_mask"]
-
- def __init__(
- self,
- feature_size: int = 1,
- sampling_rate: int = 24000,
- padding_value: float = 0.0,
- do_normalize: bool = False,
- num_mel_bins: int = 100,
- hop_length: int = 256,
- win_length: int = 1024,
- win_function: str = "hann_window",
- filter_length: int | None = 1024,
- max_length_s: int = 10,
- fmin: float = 0.0,
- fmax: float | None = None,
- mel_floor: float = 1e-9,
- center: bool = False,
- compression_factor: float = 1.0,
- compression_clip_val: float = 1e-5,
- normalize_min: float = -11.512925148010254,
- normalize_max: float = 2.3143386840820312,
- model_in_channels: int = 64,
- pad_end_length: int = 10,
- return_attention_mask=True,
- **kwargs,
- ):
- super().__init__(
- feature_size=feature_size,
- sampling_rate=sampling_rate,
- padding_value=padding_value,
- return_attention_mask=return_attention_mask,
- **kwargs,
- )
-
- self.do_normalize = do_normalize
-
- self.num_mel_bins = num_mel_bins
- self.hop_length = hop_length
- self.win_length = win_length
- self.win_function = win_function
- self.filter_length = filter_length
- self.fmin = fmin
- if fmax is None:
- # Follows the librosa.filters.mel implementation
- fmax = float(sampling_rate) / 2
- self.fmax = fmax
- self.mel_floor = mel_floor
-
- self.max_length_s = max_length_s
- self.num_max_samples = max_length_s * sampling_rate
-
- if self.filter_length is None:
- self.n_fft = optimal_fft_length(self.win_length)
- else:
- self.n_fft = self.filter_length
- self.n_freqs = (self.n_fft // 2) + 1
-
- self.window = window_function(window_length=self.win_length, name=self.win_function, periodic=True)
-
- self.mel_filters = mel_filter_bank(
- num_frequency_bins=self.n_freqs,
- num_mel_filters=self.num_mel_bins,
- min_frequency=self.fmin,
- max_frequency=self.fmax,
- sampling_rate=self.sampling_rate,
- norm="slaney",
- mel_scale="slaney",
- )
-
- self.center = center
- self.compression_factor = compression_factor
- self.compression_clip_val = compression_clip_val
- self.normalize_min = normalize_min
- self.normalize_max = normalize_max
- self.model_in_channels = model_in_channels
- self.pad_end_length = pad_end_length
-
- def normalize(self, spectrogram):
- return 2 * ((spectrogram - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1
-
- def denormalize(self, spectrogram):
- return self.normalize_min + (self.normalize_max - self.normalize_min) * ((spectrogram + 1) / 2)
-
- def mel_spectrogram(self, waveform: np.ndarray) -> np.ndarray:
- """
- Calculates log MEL spectrograms from a batch of waveforms. Note that the input waveform(s) will be padded by
- `int(self.n_fft - self.hop_length) / 2` on both sides using the `reflect` padding mode.
-
- Args:
- waveform (`np.ndarray` of shape `(length,)`):
- The input waveform. This must be a single real-valued, mono waveform.
-
- Returns:
- `numpy.ndarray`: Array containing a log-mel spectrogram of shape `(num_frames, num_mel_bins)`.
- """
- # Do custom padding based on the official MelGAN and Hifi-GAN implementations
- # See https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/utils/stft.py#L84-L86
- waveform = np.pad(
- waveform,
- (int((self.n_fft - self.hop_length) / 2), int((self.n_fft - self.hop_length) / 2)),
- mode="reflect",
- )
-
- # Get the complex spectrogram.
- # Note: waveform must be unbatched currently due to the implementation of spectrogram(...).
- complex_spectrogram = spectrogram(
- waveform,
- window=self.window,
- frame_length=self.n_fft,
- hop_length=self.hop_length,
- fft_length=self.n_fft,
- power=None,
- center=self.center,
- mel_filters=None,
- mel_floor=None,
- )
-
- # Apply the MEL filter bank and MEL floor manually since UnivNet uses a slightly different implementation
- amplitude_spectrogram = np.sqrt(
- np.real(complex_spectrogram) ** 2 + np.imag(complex_spectrogram) ** 2 + self.mel_floor
- )
- mel_spectrogram = np.matmul(self.mel_filters.T, amplitude_spectrogram)
-
- # Perform spectral normalization to get the log mel spectrogram.
- log_mel_spectrogram = np.log(
- np.clip(mel_spectrogram, a_min=self.compression_clip_val, a_max=None) * self.compression_factor
- )
-
- # Return spectrogram with num_mel_bins last
- return log_mel_spectrogram.T
-
- def generate_noise(
- self,
- noise_length: int,
- generator: np.random.Generator | None = None,
- ) -> np.ndarray:
- """
- Generates a random noise sequence of standard Gaussian noise for use in the `noise_sequence` argument of
- [`UnivNetModel.forward`].
-
- Args:
- spectrogram_length (`int`):
- The length (dim 0) of the generated noise.
- model_in_channels (`int`, *optional*, defaults to `None`):
- The number of features (dim 1) of the generated noise. This should correspond to the
- `model_in_channels` of the [`UnivNetGan`] model. If not set, this will default to
- `self.config.model_in_channels`.
- generator (`numpy.random.Generator`, *optional*, defaults to `None`)
- An optional `numpy.random.Generator` random number generator to control noise generation. If not set, a
- new generator with fresh entropy will be created.
-
- Returns:
- `numpy.ndarray`: Array containing random standard Gaussian noise of shape `(noise_length,
- model_in_channels)`.
- """
- if generator is None:
- generator = np.random.default_rng()
-
- noise_shape = (noise_length, self.model_in_channels)
- noise = generator.standard_normal(noise_shape, dtype=np.float32)
-
- return noise
-
- def batch_decode(self, waveforms, waveform_lengths=None) -> list[np.ndarray]:
- r"""
- Removes padding from generated audio after running [`UnivNetModel.forward`]. This returns a ragged list of 1D
- audio waveform arrays and not a single tensor/array because in general the waveforms will have different
- lengths after removing padding.
-
- Args:
- waveforms (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
- The batched output waveforms from the [`UnivNetModel`].
- waveform_lengths (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
- The batched lengths of each waveform before padding.
-
- Returns:
- `list[np.ndarray]`: A ragged list of 1D waveform arrays with padding removed.
- """
- # Collapse the batched waveform tensor to a list of 1D audio waveforms
- waveforms = [waveform.detach().to(device="cpu", copy=True).numpy() for waveform in waveforms]
-
- if waveform_lengths is not None:
- waveforms = [waveform[: waveform_lengths[i]] for i, waveform in enumerate(waveforms)]
-
- return waveforms
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- sampling_rate: int | None = None,
- padding: bool | str | PaddingStrategy = True,
- max_length: int | None = None,
- truncation: bool = True,
- pad_to_multiple_of: int | None = None,
- return_noise: bool = True,
- generator: np.random.Generator | None = None,
- pad_end: bool = False,
- pad_length: int | None = None,
- do_normalize: str | None = None,
- return_attention_mask: bool | None = None,
- return_tensors: str | TensorType | None = None,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
- pipeline.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the input `raw_speech` waveforms (according to the model's padding side and
- padding index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
-
- If `pad_end = True`, that padding will occur before the `padding` strategy is applied.
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- truncation (`bool`, *optional*, defaults to `True`):
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
- pad_to_multiple_of (`int`, *optional*):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_noise (`bool`, *optional*, defaults to `True`):
- Whether to generate and return a noise waveform for use in [`UnivNetModel.forward`].
- generator (`numpy.random.Generator`, *optional*, defaults to `None`):
- An optional `numpy.random.Generator` random number generator to use when generating noise.
- pad_end (`bool`, *optional*, defaults to `False`):
- Whether to pad the end of each waveform with silence. This can help reduce artifacts at the end of the
- generated audio sample; see https://github.com/seungwonpark/melgan/issues/8 for more details. This
- padding will be done before the padding strategy specified in `padding` is performed.
- pad_length (`int`, *optional*, defaults to `None`):
- If padding the end of each waveform, the length of the padding in spectrogram frames. If not set, this
- will default to `self.config.pad_end_length`.
- do_normalize (`bool`, *optional*):
- Whether to perform Tacotron 2 normalization on the input. Normalizing can help to significantly improve
- the performance for some models. If not set, this will default to `self.config.do_normalize`.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.np.array` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- """
- do_normalize = do_normalize if do_normalize is not None else self.do_normalize
-
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
- f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
- f" was sampled with {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- if is_batched_numpy and len(raw_speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
- is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech]
- elif not is_batched and not isinstance(raw_speech, np.ndarray):
- raw_speech = np.asarray(raw_speech, dtype=np.float32)
- elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
- raw_speech = raw_speech.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_speech = [np.asarray(raw_speech, dtype=np.float32)]
-
- # Pad end to reduce artifacts
- if pad_end:
- pad_length = pad_length if pad_length is not None else self.pad_end_length
- raw_speech = [
- np.pad(waveform, (0, pad_length * self.hop_length), constant_values=self.padding_value)
- for waveform in raw_speech
- ]
-
- batched_speech = BatchFeature({"input_features": raw_speech})
-
- padded_inputs = self.pad(
- batched_speech,
- padding=padding,
- max_length=max_length if max_length is not None else self.num_max_samples,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- )
-
- # make sure list is in array format
- # input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
- input_features = padded_inputs.get("input_features")
-
- mel_spectrograms = [self.mel_spectrogram(waveform) for waveform in input_features]
-
- if isinstance(input_features[0], list):
- batched_speech["input_features"] = [np.asarray(mel, dtype=np.float32) for mel in mel_spectrograms]
- else:
- batched_speech["input_features"] = [mel.astype(np.float32) for mel in mel_spectrograms]
-
- # convert attention_mask to correct format
- attention_mask = padded_inputs.get("attention_mask")
- if attention_mask is not None:
- batched_speech["padding_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
-
- if return_noise:
- noise = [
- self.generate_noise(spectrogram.shape[0], generator)
- for spectrogram in batched_speech["input_features"]
- ]
- batched_speech["noise_sequence"] = noise
-
- if do_normalize:
- batched_speech["input_features"] = [
- self.normalize(spectrogram) for spectrogram in batched_speech["input_features"]
- ]
-
- if return_tensors is not None:
- batched_speech = batched_speech.convert_to_tensors(return_tensors)
-
- return batched_speech
-
- def to_dict(self) -> dict[str, Any]:
- output = super().to_dict()
-
- # Don't serialize these as they are derived from the other properties.
- names = ["window", "mel_filters", "n_fft", "n_freqs", "num_max_samples"]
- for name in names:
- if name in output:
- del output[name]
-
- return output
+UnivNetFeatureExtractor = deprecated_feature_extractor(UnivNetAudioProcessor, "UnivNetFeatureExtractor")
__all__ = ["UnivNetFeatureExtractor"]
diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py
new file mode 100644
index 000000000000..866113b39b82
--- /dev/null
+++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py
@@ -0,0 +1,38 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ...audio_processing_backends import TorchAudioBackend
+
+
+class VibevoiceAcousticTokenizerAudioProcessor(TorchAudioBackend):
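+    r"""
+    Audio processor for the VibeVoice acoustic tokenizer. Normalizes each mono 24 kHz waveform to a target dB FS
+    level and rescales it if the peak amplitude exceeds 1.0.
+    """
+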
+ sample_rate = 24000
+ force_mono = True
+ add_channel_dim = True
+
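+    # Loudness normalization target in dB FS and a small epsilon to avoid division by zero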
+ target_dB_FS = -25
+ eps = 1e-6
+
+ def _process_audio(self, audio_el):
+ audio_el = super()._process_audio(audio_el)
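+        # Scale so the RMS level matches target_dB_FS, then rescale if the peak exceeds 1.0 to avoid clipping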
+ rms = torch.sqrt(torch.mean(audio_el**2))
+ audio_el = audio_el * (10 ** (self.target_dB_FS / 20) / (rms + self.eps))
+ max_val = torch.max(torch.abs(audio_el))
+ if max_val > 1.0:
+ audio_el = audio_el / (max_val + self.eps)
+ return audio_el
+
+
+__all__ = ["VibevoiceAcousticTokenizerAudioProcessor"]
diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py
index 6e0c82762283..0f38cd9df814 100644
--- a/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py
+++ b/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py
@@ -11,134 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_vibevoice_acoustic_tokenizer import VibevoiceAcousticTokenizerAudioProcessor
-from ...audio_utils import AudioInput, make_list_of_audio
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, logging
-from ...utils.import_utils import is_torch_available, requires
-
-
-if is_torch_available():
- import torch
-
-logger = logging.get_logger(__name__)
-
-
-@requires(backends=("torch",))
-class VibeVoiceAcousticTokenizerFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a VibeVoiceAcousticTokenizer feature extractor.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The number of channels.
- sampling_rate (`int`, *optional*, defaults to 24000):
- The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used for padding.
- normalize_audio (`bool`, *optional*, defaults to `True`):
- Whether to normalize audio to a target dB FS.
- target_dB_FS (`float`, *optional*, defaults to -25):
- Target dB FS for normalization.
- eps (`float`, *optional*, defaults to 1e-06):
- A small value to avoid division by zero when normalizing.
-
- """
-
- model_input_names = ["input_values", "padding_mask"]
-
- def __init__(
- self,
- feature_size=1,
- sampling_rate=24000,
- padding_value=0.0,
- normalize_audio=True,
- target_dB_FS=-25,
- eps=1e-6,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
-
- self.normalize_audio = normalize_audio
- self.target_dB_FS = target_dB_FS
- self.eps = eps
-
- def __call__(
- self,
- audio: AudioInput,
- sampling_rate: int | None = None,
- padding: bool | str | PaddingStrategy | None = True,
- pad_to_multiple_of: int | None = None,
- return_attention_mask: bool | None = True,
- ) -> BatchFeature:
- """
- Args:
- audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`:
- The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a torch tensor,
- a list of numpy arrays or a list of torch tensors.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- pad_to_multiple_of (`int`, *optional*):
- If set will pad the sequence to a multiple of the provided value.
-
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- # Ensure batch of mono tensors
- audio = make_list_of_audio(audio)
- for idx, example in enumerate(audio):
- example = torch.tensor(example, dtype=torch.float32)
- if example.ndim != 1:
- raise ValueError(f"Audio should be mono, got shape: {example.shape}")
- audio[idx] = example
-
- if self.normalize_audio:
- for idx, example in enumerate(audio):
- rms = torch.sqrt(torch.mean(example**2))
- example *= 10 ** (self.target_dB_FS / 20) / (rms + self.eps)
- max_val = torch.max(torch.abs(example))
- if max_val > 1.0:
- example = example / (max_val + self.eps)
- audio[idx] = example
-
- output_values = BatchFeature({"input_values": audio})
- if padding or pad_to_multiple_of:
- output_values = self.pad(
- output_values,
- padding=padding,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- )
- if return_attention_mask:
- output_values["padding_mask"] = output_values.pop("attention_mask")
-
- # add channel dimension
- output_values["input_values"] = output_values["input_values"][:, None, :]
-
- return output_values
+VibeVoiceAcousticTokenizerFeatureExtractor = deprecated_feature_extractor(
+ VibevoiceAcousticTokenizerAudioProcessor, "VibeVoiceAcousticTokenizerFeatureExtractor"
+)
__all__ = ["VibeVoiceAcousticTokenizerFeatureExtractor"]
diff --git a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py
new file mode 100644
index 000000000000..59ff6ad89176
--- /dev/null
+++ b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py
@@ -0,0 +1,62 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ...audio_processing_backends import TorchAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class VoxtralRealtimeAudioProcessor(TorchAudioBackend):
+ sample_rate = 16000
+ force_mono = True
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=400,
+ hop_length=160,
+ power=2.0,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=128,
+ mel_scale="slaney",
+ norm="slaney",
+ computation_dtype="float64",
+ ),
+ log_mode="log10",
+ skip_last_frame=True,
+ )
+ global_log_mel_max = 1.5
+
+ def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs):
+ mel_filters = self.mel_filters.to(device=features.device)
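+        # (n_mels, num_freq_bins) @ (..., num_freq_bins, num_frames) -> (..., n_mels, num_frames),
+        # clamped from below at `spectrogram_config.mel_floor`.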
+ return torch.clamp(torch.matmul(mel_filters.T, features), min=spectrogram_config.mel_floor)
+
+ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs):
+ features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs)
+
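+        # Same dynamic-range compression as Whisper's log-mel features, except the reference
+        # maximum is a fixed global value when `global_log_mel_max` is set (presumably so the
+        # scaling does not depend on the content of each realtime chunk); otherwise it falls
+        # back to the per-example maximum.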
+ if self.global_log_mel_max is not None:
+ spec_max = torch.tensor(self.global_log_mel_max, device=features.device, dtype=features.dtype)
+ else:
+ spec_max = features.amax(dim=(-2, -1), keepdim=True)
+ features = torch.maximum(features, spec_max - 8.0)
+ features = (features + 4.0) / 4.0
+ return features
+
+ def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False):
+ stft_cfg = spectrogram_config.stft_config
+ win_length = stft_cfg.win_length or stft_cfg.n_fft
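+        # Frame count of a non-centered STFT: floor((num_samples - win_length) / hop_length) + 1.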
+ return (audio_lengths - win_length) // stft_cfg.hop_length + 1
+
+
+__all__ = ["VoxtralRealtimeAudioProcessor"]
diff --git a/src/transformers/models/voxtral_realtime/feature_extraction_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/feature_extraction_voxtral_realtime.py
index 58355f3c0d7c..09e49995be51 100644
--- a/src/transformers/models/voxtral_realtime/feature_extraction_voxtral_realtime.py
+++ b/src/transformers/models/voxtral_realtime/feature_extraction_voxtral_realtime.py
@@ -11,236 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_voxtral_realtime import VoxtralRealtimeAudioProcessor
-import numpy as np
-import torch
-
-from ...audio_utils import mel_filter_bank
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, logging
-from ...utils.import_utils import requires
-
-
-logger = logging.get_logger(__name__)
-
-
-@requires(backends=("torch",))
-class VoxtralRealtimeFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a VOXTRAL_REALTIME feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
- Fourier Transform` which should match pytorch's `torch.stft` equivalent.
-
- Args:
- feature_size (`int`, *optional*, defaults to 128):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- hop_length (`int`, *optional*, defaults to 160):
- Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
- n_fft (`int`, *optional*, defaults to 512):
- Size of the Fourier transform.
- win_length (`int`, *optional*, defaults to 400):
- The window length for the STFT computation.
- padding_value (`float`, *optional*, defaults to 0.0):
- Padding value used to pad the audio. Should correspond to silences.
- """
-
- model_input_names = ["input_features", "attention_mask"]
-
- def __init__(
- self,
- feature_size=128,
- sampling_rate=16000,
- hop_length=160,
- n_fft=400,
- win_length=400,
- padding_value=0.0,
- global_log_mel_max=1.5,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
-
- self.hop_length = hop_length
- self.n_fft = n_fft
- self.win_length = win_length
- self.mel_filters = mel_filter_bank(
- num_frequency_bins=1 + n_fft // 2,
- num_mel_filters=feature_size,
- min_frequency=0.0,
- max_frequency=8000.0,
- sampling_rate=sampling_rate,
- norm="slaney",
- mel_scale="slaney",
- )
- self.global_log_mel_max = global_log_mel_max
-
- def _torch_extract_fbank_features(self, waveform, device: str = "cpu", center: bool = True):
- window = torch.hann_window(self.n_fft, device=device)
- stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True, center=center)
- magnitudes = stft[..., :-1].abs() ** 2
-
- mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
- mel_spec = mel_filters.T @ magnitudes
-
- log_spec = torch.clamp(mel_spec, min=1e-10).log10()
- if self.global_log_mel_max is not None:
- log_spec_max = torch.tensor(
- self.global_log_mel_max,
- device=log_spec.device,
- dtype=log_spec.dtype,
- )
- else:
- log_spec_max = log_spec.max()
-
- log_spec = torch.maximum(log_spec, log_spec_max - 8.0)
- log_spec = (log_spec + 4.0) / 4.0
- if device != "cpu":
- log_spec = log_spec.detach().cpu()
- return log_spec
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- truncation: bool = False,
- pad_to_multiple_of: int | None = None,
- return_tensors: str | TensorType | None = None,
- return_attention_mask: bool | None = None,
- padding: str | None = "longest",
- max_length: int | None = None,
- sampling_rate: int | None = None,
- do_normalize: bool | None = None,
- device: str | None = "cpu",
- return_token_timestamps: bool | None = None,
- center: bool = True,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
- the STFT computation if available, otherwise a slower NumPy based one.
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- truncation (`bool`, *optional*, default to `True`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*, defaults to None):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
-
- For Parakeet models, `attention_mask` should always be passed for batched inference, to avoid subtle
- bugs.
-
-
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'tf'`: Return TensorFlow `tf.constant` objects.
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
- pipeline.
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values / vectors.
- do_normalize (`bool`, *optional*, defaults to `False`):
- Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
- improve the performance of the model.
- device (`str`, *optional*, defaults to `'cpu'`):
- Specifies the device for computation of the log-mel spectrogram of audio signals in the
- `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
- return_token_timestamps (`bool`, *optional*, defaults to `None`):
- Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred.
-
- Whether or not to return the number of frames of the input raw_speech.
- These num_frames can be used by the model to compute word level timestamps.
- center (`bool`, *optional*, defaults to `True`):
- Whether to use centering for the STFT computation.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
- f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
- f" was sampled with {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- # Convert to torch tensor
- if isinstance(raw_speech, np.ndarray):
- raw_speech = torch.tensor(raw_speech)
- elif isinstance(raw_speech, (list, tuple)):
- if isinstance(raw_speech[0], (list, np.ndarray)):
- raw_speech = [torch.tensor(speech) for speech in raw_speech]
- else: # list[float]
- raw_speech = torch.tensor(raw_speech)
-
- is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1
- if is_batched_torch and len(raw_speech.shape) > 2:
- logger.warning(
- f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
- "We will take the mean of the channels to convert to mono."
- )
- raw_speech = raw_speech.mean(-1)
-
- is_batched_sequence = isinstance(raw_speech, (list, tuple))
- if is_batched_sequence:
- for speech in raw_speech:
- if len(speech.shape) > 1:
- logger.warning(
- f"Only mono-channel audio is supported for input to {self.__class__.__name__}. "
- "We will take the mean of the channels to convert to mono."
- )
- speech = speech.mean(-1)
-
- if is_batched_torch or is_batched_sequence:
- raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech]
- else:
- raw_speech = [raw_speech[:, None].to(torch.float32)]
-
- batched_speech = BatchFeature({"input_features": raw_speech})
- padded_inputs = self.pad(
- batched_speech,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- return_tensors="pt",
- )
- input_features = padded_inputs.input_features.squeeze(-1)
- input_features = self._torch_extract_fbank_features(input_features, device, center)
- data = {
- "input_features": input_features.to(torch.float32),
- }
-
- if return_attention_mask:
- attention_mask = padded_inputs.attention_mask[:, self.win_length - 1 :: self.hop_length]
- data["attention_mask"] = attention_mask.to(torch.bool)
-
- return BatchFeature(data=data, tensor_type=return_tensors)
+VoxtralRealtimeFeatureExtractor = deprecated_feature_extractor(
+ VoxtralRealtimeAudioProcessor, "VoxtralRealtimeFeatureExtractor"
+)
__all__ = ["VoxtralRealtimeFeatureExtractor"]
diff --git a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py
new file mode 100644
index 000000000000..66467620f39d
--- /dev/null
+++ b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py
@@ -0,0 +1,34 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ...audio_processing_backends import TorchAudioBackend
+
+
+class Wav2Vec2AudioProcessor(TorchAudioBackend):
+ sample_rate = 16000
+ force_mono = True
+ do_normalize = True
+
+ def _process_audio(self, audio_el):
+ audio_el = super()._process_audio(audio_el)
+
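+        # Per-waveform zero-mean, unit-variance normalization; `correction=0` is the population
+        # variance and 1e-7 guards against division by zero on silent inputs.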
+ if self.do_normalize:
+ audio_el = (audio_el - audio_el.mean()) / torch.sqrt(audio_el.var(correction=0) + 1e-7)
+
+ return audio_el
+
+
+__all__ = ["Wav2Vec2AudioProcessor"]
diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
index dea2f3af5b48..bc4c8fdee07e 100644
--- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py
@@ -11,229 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Feature extractor class for Wav2Vec2
-"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_wav2vec2 import Wav2Vec2AudioProcessor
-import numpy as np
-
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import PaddingStrategy, TensorType, logging
-
-
-logger = logging.get_logger(__name__)
-
-
-class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a Wav2Vec2 feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- Args:
- feature_size (`int`, *optional*, defaults to 1):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- padding_value (`float`, *optional*, defaults to 0.0):
- The value that is used to fill the padding values.
- do_normalize (`bool`, *optional*, defaults to `True`):
- Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
- improve the performance for some models, *e.g.*,
- [wav2vec2-lv60](https://huggingface.co/models?search=lv60).
- return_attention_mask (`bool`, *optional*, defaults to `False`):
- Whether or not [`~Wav2Vec2FeatureExtractor.__call__`] should return `attention_mask`.
-
-
-
- Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as
- [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using
- `attention_mask`. For such models, `input_values` should simply be padded with 0 and no `attention_mask`
- should be passed.
-
- For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as
- [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should be
- passed for batched inference.
-
- """
-
- model_input_names = ["input_values", "attention_mask"]
-
- def __init__(
- self,
- feature_size=1,
- sampling_rate=16000,
- padding_value=0.0,
- return_attention_mask=False,
- do_normalize=True,
- **kwargs,
- ):
- super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
- self.return_attention_mask = return_attention_mask
- self.do_normalize = do_normalize
-
- @staticmethod
- def zero_mean_unit_var_norm(
- input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0
- ) -> list[np.ndarray]:
- """
- Every array in the list is normalized to have zero mean and unit variance
- """
- if attention_mask is not None:
- attention_mask = np.array(attention_mask, np.int32)
- normed_input_values = []
-
- for vector, length in zip(input_values, attention_mask.sum(-1)):
- normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
- if length < normed_slice.shape[0]:
- normed_slice[length:] = padding_value
-
- normed_input_values.append(normed_slice)
- else:
- normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
-
- return normed_input_values
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- padding: bool | str | PaddingStrategy = False,
- max_length: int | None = None,
- truncation: bool = False,
- pad_to_multiple_of: int | None = None,
- return_attention_mask: bool | None = None,
- return_tensors: str | TensorType | None = None,
- sampling_rate: int | None = None,
- **kwargs,
- ) -> BatchFeature:
- """
- Main method to featurize and prepare for the model one or several sequence(s).
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
- index) among:
-
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
- sequence if provided).
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
- acceptable input length for the model if that argument is not provided.
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
- lengths).
- max_length (`int`, *optional*):
- Maximum length of the returned list and optionally padding length (see above).
- truncation (`bool`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
-
- Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as
- [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using
- `attention_mask`. For such models, `input_values` should simply be padded with 0 and no
- `attention_mask` should be passed.
-
- For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as
- [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should
- be passed for batched inference.
-
-
-
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors.
- padding_value (`float`, *optional*, defaults to 0.0):
- """
-
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
- f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
- f" {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- if is_batched_numpy and len(raw_speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
- is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
- )
-
- # always return batch
- if not is_batched:
- raw_speech = [raw_speech]
-
- # convert into correct format for padding
- encoded_inputs = BatchFeature({"input_values": raw_speech})
-
- padded_inputs = self.pad(
- encoded_inputs,
- padding=padding,
- max_length=max_length,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask,
- )
-
- # convert input values to correct format
- input_values = padded_inputs["input_values"]
- if not isinstance(input_values[0], np.ndarray):
- padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values]
- elif (
- not isinstance(input_values, np.ndarray)
- and isinstance(input_values[0], np.ndarray)
- and input_values[0].dtype is np.dtype(np.float64)
- ):
- padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values]
- elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64):
- padded_inputs["input_values"] = input_values.astype(np.float32)
-
- # convert attention_mask to correct format
- attention_mask = padded_inputs.get("attention_mask")
- if attention_mask is not None:
- padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask]
-
- # zero-mean and unit-variance normalization
- if self.do_normalize:
- attention_mask = (
- attention_mask
- if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD
- else None
- )
- padded_inputs["input_values"] = self.zero_mean_unit_var_norm(
- padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
- )
-
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
+Wav2Vec2FeatureExtractor = deprecated_feature_extractor(Wav2Vec2AudioProcessor, "Wav2Vec2FeatureExtractor")
__all__ = ["Wav2Vec2FeatureExtractor"]
diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py
new file mode 100644
index 000000000000..0a7f5bffa9be
--- /dev/null
+++ b/src/transformers/models/whisper/audio_processing_whisper.py
@@ -0,0 +1,57 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from ...audio_processing_backends import TorchAudioBackend
+from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig
+
+
+class WhisperAudioProcessor(TorchAudioBackend):
+ sample_rate = 16000
+ force_mono = True
+ return_padding_mask = False
+ truncation = True
+ max_length = 480000 # 30 seconds at 16000 Hz
+ spectrogram_config = SpectrogramConfig(
+ stft_config=StftConfig(
+ n_fft=400,
+ hop_length=160,
+ power=2.0,
+ ),
+ mel_scale_config=MelScaleConfig(
+ n_mels=80,
+ mel_scale="slaney",
+ norm="slaney",
+ computation_dtype="float64",
+ ),
+ log_mode="log10",
+ skip_last_frame=True,
+ )
+
+ def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs):
+ mel_filters = self.mel_filters.to(device=features.device)
+ return torch.clamp(torch.matmul(mel_filters.T, features), min=spectrogram_config.mel_floor)
+
+ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs):
+ features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs)
+
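+        # Whisper's log-mel post-processing: clamp each example to within 8.0 of its maximum
+        # log10 value, then rescale with (x + 4.0) / 4.0.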
+ max_vals = features.amax(dim=(-2, -1), keepdim=True)
+ features = torch.maximum(features, max_vals - 8.0)
+ features = (features + 4.0) / 4.0
+
+ return features
+
+
+__all__ = ["WhisperAudioProcessor"]
diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py
index 4151a3824dfd..4e4f49df3c2d 100644
--- a/src/transformers/models/whisper/feature_extraction_whisper.py
+++ b/src/transformers/models/whisper/feature_extraction_whisper.py
@@ -11,335 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""
-Feature extractor class for Whisper
-"""
+from ...utils.deprecation import deprecated_feature_extractor
+from .audio_processing_whisper import WhisperAudioProcessor
-import numpy as np
-
-from ... import is_torch_available
-from ...audio_utils import mel_filter_bank, spectrogram, window_function
-from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
-from ...feature_extraction_utils import BatchFeature
-from ...utils import TensorType, logging
-
-
-if is_torch_available():
- import torch
-
-logger = logging.get_logger(__name__)
-
-
-class WhisperFeatureExtractor(SequenceFeatureExtractor):
- r"""
- Constructs a Whisper feature extractor.
-
- This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains
- most of the main methods. Users should refer to this superclass for more information regarding those methods.
-
- This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time
- Fourier Transform` which should match pytorch's `torch.stft` equivalent.
-
- Args:
- feature_size (`int`, *optional*, defaults to 80):
- The feature dimension of the extracted features.
- sampling_rate (`int`, *optional*, defaults to 16000):
- The sampling rate at which the audio files should be digitalized expressed in hertz (Hz).
- hop_length (`int`, *optional*, defaults to 160):
- Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients.
- chunk_length (`int`, *optional*, defaults to 30):
- The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio
- sequences.
- n_fft (`int`, *optional*, defaults to 400):
- Size of the Fourier transform.
- padding_value (`float`, *optional*, defaults to 0.0):
- Padding value used to pad the audio. Should correspond to silences.
- dither (`float`, *optional*, defaults to 0.0):
- Adds dithering. In other words, adds a small Gaussian noise to each frame.
- E.g. use 0.0001 to add dithering with a normal distribution centered
- around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech).
- The value 0.0 means no dithering.
- Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces
- the high log_mel_fbank values for signals with hard-zero sections,
- when VAD cutoff is present in the signal.
- """
-
- model_input_names = ["input_features"]
-
- def __init__(
- self,
- feature_size=80,
- sampling_rate=16000,
- hop_length=160,
- chunk_length=30,
- n_fft=400,
- padding_value=0.0,
- dither=0.0,
- return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask
- **kwargs,
- ):
- super().__init__(
- feature_size=feature_size,
- sampling_rate=sampling_rate,
- padding_value=padding_value,
- return_attention_mask=return_attention_mask,
- **kwargs,
- )
- self.n_fft = n_fft
- self.hop_length = hop_length
- self.chunk_length = chunk_length
- self.n_samples = chunk_length * sampling_rate
- self.nb_max_frames = self.n_samples // hop_length
- self.sampling_rate = sampling_rate
- self.dither = dither
- self.mel_filters = mel_filter_bank(
- num_frequency_bins=1 + n_fft // 2,
- num_mel_filters=feature_size,
- min_frequency=0.0,
- max_frequency=8000.0,
- sampling_rate=sampling_rate,
- norm="slaney",
- mel_scale="slaney",
- )
-
- def _np_extract_fbank_features(self, waveform_batch: np.ndarray, device: str) -> np.ndarray:
- """
- Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
- implementation with 1e-5 tolerance.
- """
- if device != "cpu":
- raise ValueError(
- f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
- "devices requires torch, which is not installed. Either set `device='cpu'`, or "
- "install torch according to the official instructions: https://pytorch.org/get-started/locally/"
- )
- log_spec_batch = []
- for waveform in waveform_batch:
- log_spec = spectrogram(
- waveform,
- window_function(self.n_fft, "hann"),
- frame_length=self.n_fft,
- hop_length=self.hop_length,
- power=2.0,
- dither=self.dither,
- mel_filters=self.mel_filters,
- log_mel="log10",
- )
- log_spec = log_spec[:, :-1]
- log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
- log_spec = (log_spec + 4.0) / 4.0
- log_spec_batch.append(log_spec)
- log_spec_batch = np.array(log_spec_batch)
- return log_spec_batch
-
- def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray:
- """
- Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
- yielding results similar to cpu computing with 1e-5 tolerance.
- """
- waveform = torch.from_numpy(waveform).to(device, torch.float32)
- window = torch.hann_window(self.n_fft, device=device)
-
- # Note: it would be better to dither the chunked waveform,
- # so overlapping signal does not get the same dithering.
- # But, chunking is happening inside pytorch, so it is here.
- if self.dither != 0.0:
- waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
-
- stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
- magnitudes = stft[..., :-1].abs() ** 2
-
- mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
- mel_spec = mel_filters.T @ magnitudes
-
- log_spec = torch.clamp(mel_spec, min=1e-10).log10()
- if waveform.dim() == 2:
- max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
- log_spec = torch.maximum(log_spec, max_val - 8.0)
- else:
- log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
- log_spec = (log_spec + 4.0) / 4.0
- if device != "cpu":
- log_spec = log_spec.detach().cpu()
- return log_spec.numpy()
-
- @staticmethod
- # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
- def zero_mean_unit_var_norm(
- input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0
- ) -> list[np.ndarray]:
- """
- Every array in the list is normalized to have zero mean and unit variance
- """
- if attention_mask is not None:
- attention_mask = np.array(attention_mask, np.int32)
- normed_input_values = []
-
- for vector, length in zip(input_values, attention_mask.sum(-1)):
- normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
- if length < normed_slice.shape[0]:
- normed_slice[length:] = padding_value
-
- normed_input_values.append(normed_slice)
- else:
- normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
-
- return normed_input_values
-
- def __call__(
- self,
- raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
- truncation: bool = True,
- pad_to_multiple_of: int | None = None,
- return_tensors: str | TensorType | None = None,
- return_attention_mask: bool | None = None,
- padding: str | None = "max_length",
- max_length: int | None = None,
- sampling_rate: int | None = None,
- do_normalize: bool | None = None,
- device: str | None = "cpu",
- **kwargs,
- ) -> BatchFeature:
- """Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch
- for the STFT computation if available, otherwise a slower NumPy based one.
-
- Args:
- raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
- The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
- values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
- stereo, i.e. single float per timestep.
- truncation (`bool`, *optional*, default to `True`):
- Activates truncation to cut input sequences longer than *max_length* to *max_length*.
- pad_to_multiple_of (`int`, *optional*, defaults to None):
- If set will pad the sequence to a multiple of the provided value.
-
- This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
- `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
- If set, will return tensors instead of list of python integers. Acceptable values are:
-
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
- - `'np'`: Return Numpy `np.ndarray` objects.
- return_attention_mask (`bool`, *optional*):
- Whether to return the attention mask. If left to the default, will return the attention mask according
- to the specific feature_extractor's default.
-
- [What are attention masks?](../glossary#attention-mask)
-
-
-
- For Whisper models, `attention_mask` should always be passed for batched inference, to avoid subtle
- bugs.
-
-
- padding (`str` or [`~utils.PaddingStrategy`], *optional*, defaults to `'max_length'`):
- Activates and controls padding. Accepts the following values:
-
- - `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence is
- provided).
- - `'max_length'` (default): Pad to a maximum length specified with the argument `max_length` or to the
- maximum acceptable input length for the model if that argument is not provided.
- - `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
- max_length (`int`, *optional*):
- Controls the maximum length to use by one of the truncation/padding parameters.
-
- If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
- is required by one of the truncation/padding parameters. If the model has no specific maximum input
- length (like XLNet) truncation/padding to a maximum length will be deactivated.
- sampling_rate (`int`, *optional*):
- The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
- `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
- pipeline.
- do_normalize (`bool`, *optional*, defaults to `False`):
- Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
- improve the performance of the model.
- device (`str`, *optional*, defaults to `'cpu'`):
- Specifies the device for computation of the log-mel spectrogram of audio signals in the
- `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
- **kwargs: Not supported by WhisperFeatureExtractor.__call__() and ignored.
- """
- if sampling_rate is not None:
- if sampling_rate != self.sampling_rate:
- raise ValueError(
- f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
- f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
- f" was sampled with {self.sampling_rate} and not {sampling_rate}."
- )
- else:
- logger.warning(
- f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
- "Failing to do so can result in silent errors that might be hard to debug."
- )
-
- is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
- if is_batched_numpy and len(raw_speech.shape) > 2:
- raise ValueError(f"Only mono-channel audio is supported for input to {self}")
- is_batched = is_batched_numpy or (
- isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
- )
-
- if is_batched:
- raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
- elif not is_batched and not isinstance(raw_speech, np.ndarray):
- raw_speech = np.asarray(raw_speech, dtype=np.float32)
- elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
- raw_speech = raw_speech.astype(np.float32)
-
- # always return batch
- if not is_batched:
- raw_speech = [np.asarray([raw_speech]).T]
-
- batched_speech = BatchFeature({"input_features": raw_speech})
-
- # convert into correct format for padding
-
- padded_inputs = self.pad(
- batched_speech,
- padding=padding,
- max_length=max_length if max_length else self.n_samples,
- truncation=truncation,
- pad_to_multiple_of=pad_to_multiple_of,
- return_attention_mask=return_attention_mask or do_normalize,
- )
-
- # zero-mean and unit-variance normalization
- if do_normalize:
- padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
- padded_inputs["input_features"],
- attention_mask=padded_inputs["attention_mask"],
- padding_value=self.padding_value,
- )
- padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
-
- # make sure list is in array format
- input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
-
- extract_fbank_features = (
- self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
- )
- input_features = extract_fbank_features(input_features[0], device)
-
- if isinstance(input_features[0], list):
- padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
-
- else:
- padded_inputs["input_features"] = input_features
-
- if return_attention_mask:
- # rescale from sample (48000) to feature (3000)
- rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length]
-
- # The STFT computation produces L//hop_length + 1 frames, but we skip the last frame (see `_torch_extract_fbank_features`).
- # This means we need to trim the rescaled attention mask to match the actual number of frames (L//hop_length) when the input length
- # is not perfectly divisible by the hop length.
- if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0:
- rescaled_attention_mask = rescaled_attention_mask[:, :-1]
- padded_inputs["attention_mask"] = rescaled_attention_mask
-
- if return_tensors is not None:
- padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
-
- return padded_inputs
+WhisperFeatureExtractor = deprecated_feature_extractor(WhisperAudioProcessor, "WhisperFeatureExtractor")
__all__ = ["WhisperFeatureExtractor"]
diff --git a/src/transformers/numpy_mel_spectrogram.py b/src/transformers/numpy_mel_spectrogram.py
new file mode 100644
index 000000000000..cd90215b78a1
--- /dev/null
+++ b/src/transformers/numpy_mel_spectrogram.py
@@ -0,0 +1,413 @@
+"""NumPy implementation of mel spectrogram computation."""
+
+import numpy as np
+import librosa
+
+
+# --- Frequency conversion utilities ---
+
+def hertz_to_mel(freq, mel_scale="htk"):
+ if mel_scale == "htk":
+ return 2595.0 * np.log10(1.0 + (freq / 700.0))
+ elif mel_scale == "kaldi":
+ return 1127.0 * np.log(1.0 + (freq / 700.0))
+ # slaney
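+    # (linear below 1 kHz at 200 / 3 Hz per mel, logarithmic above with 27 mels per factor
+    # of 6.4 in frequency)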
+ min_log_hertz = 1000.0
+ min_log_mel = 15.0
+ logstep = 27.0 / np.log(6.4)
+ mels = 3.0 * freq / 200.0
+ if isinstance(freq, np.ndarray):
+ log_region = freq >= min_log_hertz
+ mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
+ elif freq >= min_log_hertz:
+ mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
+ return mels
+
+
+def mel_to_hertz(mels, mel_scale="htk"):
+ if mel_scale == "htk":
+ return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
+ elif mel_scale == "kaldi":
+ return 700.0 * (np.exp(mels / 1127.0) - 1.0)
+ # slaney
+ min_log_hertz = 1000.0
+ min_log_mel = 15.0
+ logstep = np.log(6.4) / 27.0
+ freq = 200.0 * mels / 3.0
+ if isinstance(mels, np.ndarray):
+ log_region = mels >= min_log_mel
+ freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
+ elif mels >= min_log_mel:
+ freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
+ return freq
+
+
+# --- Filter bank ---
+
+def _create_triangular_filter_bank(fft_freqs, filter_freqs):
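+    """Build triangular filters: filter ``i`` rises from ``filter_freqs[i]``, peaks at
+    ``filter_freqs[i + 1]`` and falls to zero at ``filter_freqs[i + 2]``. Returns an array of
+    shape ``(len(fft_freqs), len(filter_freqs) - 2)``."""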
+ filter_diff = np.diff(filter_freqs)
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
+ return np.maximum(0, np.minimum(down_slopes, up_slopes))
+
+
+def mel_filter_bank(
+ num_frequency_bins,
+ num_mel_filters,
+ min_frequency,
+ max_frequency,
+ sampling_rate,
+ norm=None,
+ mel_scale="htk",
+ triangularize_in_mel_space=False,
+ frequency_bin_mode="rfft",
+):
+ mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
+ mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
+ mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
+ filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
+
+ n_fft = (num_frequency_bins - 1) * 2
+
+ if triangularize_in_mel_space:
+ fft_bin_width = sampling_rate / n_fft
+ fft_freqs = hertz_to_mel(
+ fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale
+ )
+ filter_freqs = mel_freqs
+ elif frequency_bin_mode == "rfft":
+ fft_freqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate)
+ else:
+ fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
+
+ mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
+
+ if norm == "slaney":
+ enorm = 2.0 / (
+ filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]
+ )
+ mel_filters *= np.expand_dims(enorm, 0)
+
+ return mel_filters
+
+
+# --- Window ---
+
+def window_function(window_length, name="hann_window", periodic=True):
+ N = window_length + 1 if periodic else window_length
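+    # A periodic window is a symmetric window of length ``window_length + 1`` with the last
+    # sample dropped, matching e.g. ``torch.hann_window(periodic=True)`` for STFT analysis.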
+ fac = np.linspace(-np.pi, np.pi, N)
+ if name in ("hann", "hann_window"):
+ w = 0.5 + 0.5 * np.cos(fac)
+ elif name in ("hamming", "hamming_window"):
+ w = 0.54 + 0.46 * np.cos(fac)
+ elif name == "boxcar":
+ w = np.ones(N)
+ elif name == "povey":
+ w = (0.5 + 0.5 * np.cos(fac)) ** 0.85
+ else:
+ raise ValueError(f"Unknown window function '{name}'")
+ return w[:window_length] if periodic else w
+
+
+# --- Sub-methods ---
+
+def _prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing):
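+    # Two framing modes: with Kaldi-style per-frame processing (dither / preemphasis / DC
+    # removal) and win_length < n_fft, frames keep win_length samples and the FFT zero-pads
+    # to n_fft; otherwise the window is center-padded to n_fft when it is shorter and frames
+    # span n_fft samples, as torch.stft does.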
+ if needs_manual_framing and win_length < n_fft:
+ frame_length = win_length
+ else:
+ if win_length < n_fft:
+ left_pad = (n_fft - win_length) // 2
+ right_pad = n_fft - win_length - left_pad
+ window = np.pad(window, (left_pad, right_pad))
+ frame_length = n_fft
+ return window, frame_length
+
+
+def _frame_waveform(waveform, frame_length, hop_length, n_fft, center, pad_mode):
+ squeezed = waveform.ndim == 1
+ if squeezed:
+ waveform = waveform[np.newaxis, :]
+ if center:
+ # Use librosa-compatible split-padding to match their STFT exactly
+ # This replicates librosa's optimization to avoid copying the entire signal
+ start_k = int(np.ceil(n_fft // 2 / hop_length))
+ tail_k = (waveform.shape[-1] + n_fft // 2 - n_fft) // hop_length + 1
+
+ if tail_k <= start_k:
+ # Head and tail overlap, use simple full padding
+ waveform = np.pad(waveform, ((0, 0), (frame_length // 2, frame_length // 2)), mode=pad_mode)
+ num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length
+ frame_starts = np.arange(num_frames) * hop_length
+ frame_indices = frame_starts[:, np.newaxis] + np.arange(frame_length)
+ frames = waveform[:, frame_indices] # (batch, num_frames, frame_length)
+ else:
+ # Split padding: handle head and tail separately like librosa
+ # Pre-padding: left pad only
+ padding = [(0, 0) for _ in range(waveform.ndim)]
+ padding[-1] = (frame_length // 2, 0)
+ y_pre = np.pad(
+ waveform[..., : (start_k - 1) * hop_length - n_fft // 2 + n_fft + 1],
+ padding,
+ mode=pad_mode,
+ )
+ y_frames_pre = librosa.util.frame(y_pre, frame_length=frame_length, hop_length=hop_length)
+ y_frames_pre = y_frames_pre[..., :start_k]
+ y_frames_pre = np.moveaxis(y_frames_pre, -2, -1) # (batch, frame_length, num_frames) -> (batch, num_frames, frame_length)
+
+ # Post-padding: right pad only
+ padding[-1] = (0, frame_length // 2)
+ y_post = np.pad(
+ waveform[..., (tail_k) * hop_length - n_fft // 2 :],
+ padding,
+ mode=pad_mode,
+ )
+ y_frames_post = librosa.util.frame(y_post, frame_length=frame_length, hop_length=hop_length)
+ y_frames_post = np.moveaxis(y_frames_post, -2, -1) # (batch, frame_length, num_frames) -> (batch, num_frames, frame_length)
+
+ # Middle: no padding
+ start = start_k * hop_length - n_fft // 2
+ y_frames_middle = librosa.util.frame(
+ waveform[..., start:], frame_length=frame_length, hop_length=hop_length
+ )
+ y_frames_middle = np.moveaxis(y_frames_middle, -2, -1) # (batch, frame_length, num_frames) -> (batch, num_frames, frame_length)
+
+ # Total frames
+ num_frames = y_frames_pre.shape[-2] + y_frames_middle.shape[-2] + y_frames_post.shape[-2]
+
+ # Concatenate frames
+ frames = np.concatenate([y_frames_pre, y_frames_middle, y_frames_post], axis=-2)
+ else:
+ # No centering: no padding
+ num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length
+ frame_starts = np.arange(num_frames) * hop_length
+ frame_indices = frame_starts[:, np.newaxis] + np.arange(frame_length)
+ frames = waveform[:, frame_indices] # (batch, num_frames, frame_length)
+
+ if squeezed:
+ frames = frames.squeeze(0)
+ return frames, num_frames
+
+
+def _apply_frame_processing(frames, *, dither=0.0, preemphasis=None, remove_dc_offset=False):
+ compute_dtype = frames.dtype
+ if dither != 0.0:
+ frames = frames + dither * np.random.randn(*frames.shape).astype(compute_dtype)
+ if remove_dc_offset:
+ frames = frames - frames.mean(axis=-1, keepdims=True)
+ if preemphasis is not None:
+ preemph_src = preemphasis * frames[..., :-1]
+ frames[..., 1:] = frames[..., 1:] - preemph_src
+ frames[..., 0] = frames[..., 0] * (1 - preemphasis)
+ return frames
+
+
+def _windowed_fft(frames, window, fft_length, power, normalized):
+ """Apply window, compute FFT, and return power spectrogram of shape (..., freq, time)."""
+ frames = frames * window
+ spec = np.fft.rfft(frames, n=fft_length, axis=-1).astype(np.complex64)
+ if normalized:
+ spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype)
+ spec = np.abs(spec, dtype=np.float64) ** power
+ return np.moveaxis(spec, -1, -2)
+
+
+def _apply_mel_scale(
+ spectrogram: np.ndarray,
+ mel_filters: np.ndarray,
+ mel_floor: float = 1e-10,
+) -> np.ndarray:
+ """Apply mel filterbank to a spectrogram.
+
+ Args:
+ spectrogram: Power spectrogram of shape (..., freq, time).
+ mel_filters: Mel filterbank of shape (freq, n_mels).
+ mel_floor: Minimum value for clamping.
+
+ Returns:
+ Mel spectrogram of shape (..., n_mels, time).
+ """
+ # (n_mels, freq) @ (..., freq, time) -> (..., n_mels, time)
+ mel_spec = np.matmul(mel_filters.T, spectrogram)
+ return np.maximum(mel_floor, mel_spec)
+
+
+# --- Main function ---
+
+def mel_spectrogram(
+ waveform: np.ndarray,
+ sampling_rate: int,
+ *,
+ n_fft: int = 400,
+ win_length: int | None = None,
+ hop_length: int | None = None,
+ window_fn: str = "hann_window",
+ wkwargs: dict | None = None,
+ power: float = 2.0,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ normalized: bool = False,
+ periodic: bool = True,
+ # mel scale kwargs
+ n_mels: int = 128,
+ f_min: float = 0.0,
+ f_max: float | None = None,
+ mel_scale: str = "htk",
+ norm: str | None = None,
+ triangularize_in_mel_space: bool = False,
+ # kaldi-specific kwargs
+ dither: float = 0.0,
+ preemphasis: float | None = None,
+ remove_dc_offset: bool = False,
+ mel_floor: float = 1e-10,
+) -> np.ndarray:
+ """Compute mel spectrogram using NumPy.
+
+ Args:
+ waveform: Input waveform of shape (..., time).
+ sampling_rate: Sample rate in Hz.
+
+ Returns:
+ Mel spectrogram of shape (..., n_mels, time).
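+
+    Example (an illustrative sketch; the parameter values are assumptions chosen for
+    demonstration, not the defaults of any particular model)::
+
+        waveform = np.random.randn(16000).astype(np.float32)
+        mel = mel_spectrogram(waveform, sampling_rate=16000, n_fft=400, hop_length=160,
+                              n_mels=80, mel_scale="slaney", norm="slaney")
+        # mel has shape (n_mels, num_frames) and is floored at `mel_floor`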
+ """
+ if f_max is None:
+ f_max = sampling_rate / 2.0
+
+ # --- STFT ---
+ if win_length is None:
+ win_length = n_fft
+ if hop_length is None:
+ hop_length = win_length // 2
+ window = window_function(win_length, name=window_fn, periodic=periodic)
+
+ needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset
+ window, frame_length = _prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing)
+
+ is_1d = waveform.ndim == 1
+ if is_1d:
+ waveform = waveform[np.newaxis, :]
+ leading_shape = waveform.shape[:-1]
+ waveform = waveform.reshape(-1, waveform.shape[-1])
+ frames, num_frames = _frame_waveform(waveform, frame_length, hop_length, n_fft, center, pad_mode)
+ compute_dtype = np.result_type(waveform.dtype, window.dtype)
+ frames = frames.astype(compute_dtype, copy=False)
+ frames = _apply_frame_processing(frames, dither=dither, preemphasis=preemphasis, remove_dc_offset=remove_dc_offset)
+ spectrogram = _windowed_fft(frames, window, n_fft, power, normalized)
+
+ num_frequency_bins = n_fft // 2 + 1
+ spectrogram = spectrogram.reshape(*leading_shape, num_frequency_bins, num_frames)
+ if is_1d:
+ spectrogram = spectrogram.squeeze(0)
+
+ num_frequency_bins = spectrogram.shape[-2]
+ mel_fb = mel_filter_bank(
+ num_frequency_bins, n_mels, f_min, f_max, sampling_rate,
+ norm=norm, mel_scale=mel_scale,
+ triangularize_in_mel_space=triangularize_in_mel_space,
+ )
+
+ return _apply_mel_scale(spectrogram, mel_fb, mel_floor=mel_floor)
+
+
+class MelSpectrogram:
+ """Cached mel spectrogram — precomputes window and mel filterbank.
+
+ Same API and exact same results as the functional ``mel_spectrogram``, but
+ avoids recomputing the window and mel filterbank on every call.
+
+ Usage::
+
+ transform = MelSpectrogram(sampling_rate=16000, n_fft=1024, n_mels=80)
+ mel = transform(waveform) # fast repeated calls
+ """
+
+ def __init__(
+ self,
+ sampling_rate: int,
+ *,
+ n_fft: int = 400,
+ win_length: int | None = None,
+ hop_length: int | None = None,
+ window_fn: str = "hann_window",
+ wkwargs: dict | None = None,
+ power: float = 2.0,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ normalized: bool = False,
+ periodic: bool = True,
+ n_mels: int = 128,
+ f_min: float = 0.0,
+ f_max: float | None = None,
+ mel_scale: str = "htk",
+ norm: str | None = None,
+ triangularize_in_mel_space: bool = False,
+ dither: float = 0.0,
+ preemphasis: float | None = None,
+ remove_dc_offset: bool = False,
+ mel_floor: float = 1e-10,
+ ):
+ self.sampling_rate = sampling_rate
+ self.n_fft = n_fft
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+ self.power = power
+ self.center = center
+ self.pad_mode = pad_mode
+ self.normalized = normalized
+ self.periodic = periodic
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max if f_max is not None else sampling_rate / 2.0
+ self.mel_floor = mel_floor
+ self.dither = dither
+ self.preemphasis = preemphasis
+ self.remove_dc_offset = remove_dc_offset
+ self.window_fn = window_fn
+
+ # Precompute window
+ needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset
+ window = window_function(self.win_length, name=window_fn, periodic=periodic)
+ self._window, self._frame_length = _prepare_window_and_framing(
+ window, self.win_length, n_fft, needs_manual_framing,
+ )
+
+ # Precompute mel filterbank
+ num_frequency_bins = n_fft // 2 + 1
+ self._mel_fb = mel_filter_bank(
+ num_frequency_bins, n_mels, self.f_min, self.f_max, sampling_rate,
+ norm=norm, mel_scale=mel_scale,
+ triangularize_in_mel_space=triangularize_in_mel_space,
+ )
+
+ def __call__(self, waveform: np.ndarray) -> np.ndarray:
+ """Compute mel spectrogram.
+
+ Args:
+ waveform: Input of shape (..., time).
+
+ Returns:
+ Mel spectrogram of shape (..., n_mels, time).
+ """
+ is_1d = waveform.ndim == 1
+ if is_1d:
+ waveform = waveform[np.newaxis, :]
+ leading_shape = waveform.shape[:-1]
+ waveform = waveform.reshape(-1, waveform.shape[-1])
+ frames, num_frames = _frame_waveform(
+ waveform, self._frame_length, self.hop_length, self.n_fft, self.center, self.pad_mode,
+ )
+ compute_dtype = np.result_type(waveform.dtype, self._window.dtype)
+ frames = frames.astype(compute_dtype, copy=False)
+ frames = _apply_frame_processing(
+ frames, dither=self.dither, preemphasis=self.preemphasis, remove_dc_offset=self.remove_dc_offset,
+ )
+ spectrogram = _windowed_fft(frames, self._window, self.n_fft, self.power, self.normalized)
+
+ num_frequency_bins = self.n_fft // 2 + 1
+ spectrogram = spectrogram.reshape(*leading_shape, num_frequency_bins, num_frames)
+ if is_1d:
+ spectrogram = spectrogram.squeeze(0)
+
+ return _apply_mel_scale(spectrogram, self._mel_fb, mel_floor=self.mel_floor)
diff --git a/src/transformers/preprocessing_base.py b/src/transformers/preprocessing_base.py
new file mode 100644
index 000000000000..d994f4811e32
--- /dev/null
+++ b/src/transformers/preprocessing_base.py
@@ -0,0 +1,470 @@
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Base mixin for image processors and feature extractors, providing shared
+save/load/serialization logic.
+"""
+
+import copy
+import json
+import os
+from copy import deepcopy
+from typing import Any, TypeVar
+
+import numpy as np
+from huggingface_hub import create_repo
+
+from .dynamic_module_utils import custom_object_save
+from .utils import (
+    PROCESSOR_NAME,
+    PushToHubMixin,
+    is_offline_mode,
+    logging,
+    safe_load_json_file,
+)
+from .utils.hub import cached_file
+
+
+logger = logging.get_logger(__name__)
+
+PreprocessingMixinType = TypeVar("PreprocessingMixinType", bound="PreprocessingMixin")
+
+
+class PreprocessingMixin(PushToHubMixin):
+ """
+ Base mixin providing saving/loading functionality shared by
+ ImageProcessingMixin, AudioProcessingMixin and FeatureExtractionMixin.
+
+ Subclasses must set the following class attributes:
+ _config_name: str — config file name (e.g. IMAGE_PROCESSOR_NAME)
+ _type_key: str — key added in to_dict() (e.g. "image_processor_type")
+ _nested_config_keys: list — keys to check in processor_config.json
+ _auto_class_default: str — default auto class for register_for_auto_class
+ _file_type_label: str — label for user-agent / error messages
+ Optional:
+ _excluded_dict_keys: set — keys to drop from to_dict() output
+ _extra_init_pops: list — extra keys to pop in __init__
+ _config_filename_kwarg: str — kwarg name that can override the config filename
+ _subfolder_default: str — default for the subfolder kwarg
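+
+    A hypothetical subclass might look like (values are illustrative)::
+
+        class MyAudioProcessor(PreprocessingMixin):
+            _config_name = "preprocessor_config.json"
+            _type_key = "audio_processor_type"
+            _nested_config_keys = ["audio_processor"]
+            _auto_class_default = "AutoFeatureExtractor"
+            _file_type_label = "audio processor"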
+ """
+
+ _auto_class = None
+
+ # --- Must be overridden by subclasses ---
+ _config_name: str
+ _type_key: str
+ _nested_config_keys: list[str] = []
+ _auto_class_default: str
+ _file_type_label: str
+
+ # --- Optional overrides ---
+ _excluded_dict_keys: set[str] = set()
+ _extra_init_pops: list[str] = []
+ _config_filename_kwarg: str | None = None
+ _subfolder_default: str | None = ""
+
+ def __init__(self, **kwargs):
+ """Set elements of `kwargs` as attributes."""
+ for key in self._extra_init_pops:
+ kwargs.pop(key, None)
+ # Pop "processor_class", should not be saved in config
+ kwargs.pop("processor_class", None)
+
+ if hasattr(self, "valid_kwargs") and hasattr(self.valid_kwargs, "__annotations__"):
+ self._init_kwargs_from_valid_kwargs(kwargs)
+
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def _init_kwargs_from_valid_kwargs(self, kwargs: dict):
+ """
+ Initialize instance attributes from `valid_kwargs` annotations.
+
+ For each key in `self.valid_kwargs.__annotations__`, pops it from `kwargs`
+ and sets it on the instance (or deep-copies the class default).
+ Also sets `self._valid_kwargs_names`.
+ """
+ for key in self.valid_kwargs.__annotations__:
+ kwarg = kwargs.pop(key, None)
+ if kwarg is not None:
+ setattr(self, key, kwarg)
+ else:
+ setattr(self, key, deepcopy(getattr(self, key, None)))
+ self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys())
+
+ def filter_out_unused_kwargs(self, kwargs: dict) -> dict:
+ """
+ Filter out the unused kwargs from the kwargs dictionary.
+ """
+        # `unused_kwargs` is optional on subclasses; a missing attribute means there is nothing to filter.
+        if getattr(self, "unused_kwargs", None) is None:
+            return kwargs
+
+ for kwarg_name in self.unused_kwargs:
+ if kwarg_name in kwargs:
+ logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.")
+ kwargs.pop(kwarg_name)
+ return kwargs
+
+ @classmethod
+ def from_dict(cls, config_dict: dict[str, Any], **kwargs):
+ """
+ Instantiates a processor from a Python dictionary of parameters.
+
+ Args:
+ config_dict (`dict[str, Any]`):
+ Dictionary that will be used to instantiate the processor object.
+ kwargs (`dict[str, Any]`):
+ Additional parameters from which to initialize the processor object.
+
+ Returns:
+ A processor of type [`~PreprocessingMixin`].
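+
+        Example (with the hypothetical subclass ``MyAudioProcessor`` from the class docstring)::
+
+            config = {"sampling_rate": 16000, "n_mels": 80}
+            processor = MyAudioProcessor.from_dict(config)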
+ """
+ config_dict = config_dict.copy()
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ # Use valid_kwargs pattern when available (image/audio processors)
+ if hasattr(cls, "valid_kwargs") and hasattr(cls.valid_kwargs, "__annotations__"):
+ config_dict.update({k: v for k, v in kwargs.items() if k in cls.valid_kwargs.__annotations__})
+ processor = cls(**config_dict)
+
+ # Apply extra kwargs to instance (BC for remote code)
+ extra_keys = []
+ for key in reversed(list(kwargs.keys())):
+ if hasattr(processor, key) and key not in cls.valid_kwargs.__annotations__:
+ setattr(processor, key, kwargs.pop(key, None))
+ extra_keys.append(key)
+ if extra_keys:
+ logger.warning_once(
+ f"Processor {cls.__name__}: kwargs {extra_keys} were applied for backward compatibility. "
+ f"To avoid this warning, add them to valid_kwargs."
+ )
+ else:
+ processor = cls(**config_dict)
+
+ logger.info(f"Processor {processor}")
+ if return_unused_kwargs:
+ return processor, kwargs
+ else:
+ return processor
+
+ @classmethod
+ def from_pretrained(
+ cls: type[PreprocessingMixinType],
+ pretrained_model_name_or_path: str | os.PathLike,
+ cache_dir: str | os.PathLike | None = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ token: str | bool | None = None,
+ revision: str = "main",
+ **kwargs,
+ ) -> PreprocessingMixinType:
+ r"""
+ Instantiate a processor from a pretrained model name or path.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ This can be either:
+
+ - a string, the *model id* of a pretrained processor hosted inside a model repo on
+ huggingface.co.
+ - a path to a *directory* containing a processor file saved using the
+ [`~PreprocessingMixin.save_pretrained`] method, e.g., `./my_model_directory/`.
+ - a path or url to a saved processor JSON *file*, e.g.,
+ `./my_model_directory/preprocessor_config.json`.
+ cache_dir (`str` or `os.PathLike`, *optional*):
+ Path to a directory in which a downloaded pretrained model processor should be cached if the
+ standard cache should not be used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force to (re-)download the processor files and override the cached versions if
+ they exist.
+ token (`str` or `bool`, *optional*):
+ The token to use as HTTP bearer authorization for remote files.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ If `False`, then this function returns just the final processor object. If `True`, then this
+ functions returns a `Tuple(processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+ consisting of the key/value pairs whose keys are not processor attributes.
+ kwargs (`dict[str, Any]`, *optional*):
+ The values in kwargs of any keys which are processor attributes will be used to override the
+ loaded values.
+
+ Returns:
+ A processor of type [`~PreprocessingMixin`].
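+
+        Example (illustrative identifiers; substitute a real checkpoint and a concrete subclass)::
+
+            processor = MyAudioProcessor.from_pretrained("org/model")
+            processor.save_pretrained("./my_model_directory/")
+            reloaded = MyAudioProcessor.from_pretrained("./my_model_directory/")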
+ """
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs["revision"] = revision
+
+ if token is not None:
+ kwargs["token"] = token
+
+ config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ return cls.from_dict(config_dict, **kwargs)
+
+ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
+ """
+ Save a processor object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~PreprocessingMixin.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the processor JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face model hub after saving it.
+ kwargs (`dict[str, Any]`, *optional*):
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id
+ files_timestamps = self._get_files_timestamps(save_directory)
+
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+ # loaded from the Hub.
+ if self._auto_class is not None:
+ custom_object_save(self, save_directory, config=self)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_file = os.path.join(save_directory, self._config_name)
+
+ self.to_json_file(output_file)
+ logger.info(f"{self._file_type_label} saved in {output_file}")
+
+ if push_to_hub:
+ self._upload_modified_files(
+ save_directory,
+ repo_id,
+ files_timestamps,
+ commit_message=commit_message,
+ token=kwargs.get("token"),
+ )
+
+ return [output_file]
+
+ @classmethod
+ def _get_config_dict(
+ cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ processor using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+
+ Returns:
+            `tuple[dict, dict]`: The dictionaries that will be used to instantiate the processor object.
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", cls._subfolder_default)
+
+ # Allow overriding the config filename via a kwarg (e.g. image_processor_filename)
+ if cls._config_filename_kwarg is not None:
+ config_filename = kwargs.pop(cls._config_filename_kwarg, cls._config_name)
+ else:
+ config_filename = cls._config_name
+
+ from_pipeline = kwargs.pop("_from_pipeline", None)
+ from_auto_class = kwargs.pop("_from_auto", False)
+
+ user_agent = {"file_type": cls._file_type_label, "from_auto_class": from_auto_class}
+ if from_pipeline is not None:
+ user_agent["using_pipeline"] = from_pipeline
+
+ if is_offline_mode() and not local_files_only:
+ logger.info("Offline mode: forcing local_files_only=True")
+ local_files_only = True
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if is_local:
+            config_file = os.path.join(pretrained_model_name_or_path, config_filename)
+ if os.path.isfile(pretrained_model_name_or_path):
+ resolved_config_file = pretrained_model_name_or_path
+ resolved_processor_file = None
+ is_local = True
+ else:
+ config_file = config_filename
+ try:
+ resolved_processor_file = cached_file(
+ pretrained_model_name_or_path,
+ filename=PROCESSOR_NAME,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ filename=config_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ user_agent=user_agent,
+ revision=revision,
+ subfolder=subfolder,
+ _raise_exceptions_for_missing_entries=False,
+ )
+ except OSError:
+            # Re-raise any environment error raised by `cached_file`; it already carries a helpful
+            # error message adapted to the original exception.
+ raise
+ except Exception:
+ # For any other exception, we throw a generic error.
+ raise OSError(
+ f"Can't load {cls._file_type_label} for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {config_filename} file"
+ )
+
+        # Load the config dict. Priority: nested config (if found) -> standalone config.
+        # We download both configs because almost all models have a `processor_config.json`, but
+        # not all of these are nested. We need to check whether it was saved recently in the
+        # nested style or is a legacy, standalone config.
+ config_dict = None
+ if resolved_processor_file is not None:
+ processor_dict = safe_load_json_file(resolved_processor_file)
+ for nested_key in cls._nested_config_keys:
+ if nested_key in processor_dict:
+ config_dict = processor_dict[nested_key]
+ break
+
+ if resolved_config_file is not None and config_dict is None:
+ config_dict = safe_load_json_file(resolved_config_file)
+
+ if config_dict is None:
+ raise OSError(
+ f"Can't load {cls._file_type_label} for '{pretrained_model_name_or_path}'. If you were trying to load"
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+ f" directory containing a {config_filename} file"
+ )
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_config_file}")
+ else:
+ logger.info(
+ f"loading configuration file {config_file} from cache at {resolved_config_file}"
+ )
+
+ return config_dict, kwargs
+
+ def to_dict(self) -> dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `dict[str, Any]`: Dictionary of all the attributes that make up this instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output[self._type_key] = self.__class__.__name__
+ output.pop("_valid_kwargs_names", None)
+ for key in self._excluded_dict_keys:
+ if key in output:
+ del output[key]
+ return output
+
+ @classmethod
+ def from_json_file(cls, json_file: str | os.PathLike):
+ """
+ Instantiates a processor from the path to a JSON file of parameters.
+
+ Args:
+ json_file (`str` or `os.PathLike`):
+ Path to the JSON file containing the parameters.
+
+ Returns:
+ A processor of type [`~PreprocessingMixin`]: The processor object instantiated from that JSON file.
+ """
+ with open(json_file, encoding="utf-8") as reader:
+ text = reader.read()
+ config_dict = json.loads(text)
+ return cls(**config_dict)
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this instance in JSON format.
+ """
+ dictionary = self.to_dict()
+
+ for key, value in dictionary.items():
+ if isinstance(value, np.ndarray):
+ dictionary[key] = value.tolist()
+
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path: str | os.PathLike):
+ """
+ Save this instance to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file in which this instance's parameters will be saved.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ @classmethod
+ def register_for_auto_class(cls, auto_class=None):
+ """
+ Register this class with a given auto class.
+
+ Args:
+ auto_class (`str` or `type`, *optional*):
+ The auto class to register this new processor with. Defaults to the subclass's `_auto_class_default`.
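+
+        Example (with a hypothetical subclass)::
+
+            MyAudioProcessor.register_for_auto_class("AutoFeatureExtractor")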
+ """
+ if auto_class is None:
+ auto_class = cls._auto_class_default
+
+ if not isinstance(auto_class, str):
+ auto_class = auto_class.__name__
+
+ import transformers.models.auto as auto_module
+
+ if not hasattr(auto_module, auto_class):
+ raise ValueError(f"{auto_class} is not a valid auto class.")
+
+ cls._auto_class = auto_class
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index fb1bd18c6239..cbbdd63d8110 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -408,6 +408,7 @@ class AudioKwargs(TypedDict, total=False):
- `'np'`: Return NumPy `np.ndarray` objects.
"""
+ sample_rate: Annotated[int | None, positive_int()]
sampling_rate: Annotated[int | None, positive_int()]
raw_speech: Union["np.ndarray", list[float], list["np.ndarray"], list[list[float]]] | None
padding: Annotated[bool | str | PaddingStrategy | None, padding_validator()]
@@ -416,6 +417,8 @@ class AudioKwargs(TypedDict, total=False):
pad_to_multiple_of: Annotated[int | None, positive_int()]
return_attention_mask: bool | None
return_tensors: Annotated[str | TensorType | None, tensor_type_validator()]
+ do_normalize: bool | None
+ device: str | None
class ProcessingKwargs(TypedDict, total=False):
diff --git a/src/transformers/torch_mel_spectrogram.py b/src/transformers/torch_mel_spectrogram.py
new file mode 100644
index 000000000000..3d48f2b8192a
--- /dev/null
+++ b/src/transformers/torch_mel_spectrogram.py
@@ -0,0 +1,522 @@
+"""PyTorch implementation of mel spectrogram computation."""
+
+import math
+
+import torch
+
+
+# --- Frequency conversion utilities ---
+
+def _hertz_to_mel_scalar(freq: float, mel_scale: str = "htk") -> float:
+ """Convert a single Hz value to mel using Python math (float64)."""
+ if mel_scale == "htk":
+ return 2595.0 * math.log10(1.0 + freq / 700.0)
+ elif mel_scale == "kaldi":
+ return 1127.0 * math.log(1.0 + freq / 700.0)
+ # slaney
+ f_sp = 200.0 / 3
+ min_log_hz = 1000.0
+ min_log_mel = (min_log_hz - 0.0) / f_sp
+ logstep = math.log(6.4) / 27.0
+ if freq >= min_log_hz:
+ return min_log_mel + math.log(freq / min_log_hz) / logstep
+ return (freq - 0.0) / f_sp
+
+
+def hertz_to_mel(freq: torch.Tensor, mel_scale: str = "htk") -> torch.Tensor:
+ if mel_scale == "htk":
+ return 2595.0 * torch.log10(1.0 + freq / 700.0)
+ elif mel_scale == "kaldi":
+ return 1127.0 * torch.log(1.0 + freq / 700.0)
+ # slaney
+ f_sp = 200.0 / 3
+ min_log_hertz = 1000.0
+ min_log_mel = min_log_hertz / f_sp
+ logstep = 27.0 / torch.log(torch.tensor(6.4))
+ mels = freq / f_sp
+ log_region = freq >= min_log_hertz
+ mels[log_region] = min_log_mel + torch.log(freq[log_region] / min_log_hertz) * logstep
+ return mels
+
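+# Sanity check for these scale definitions (illustrative, not part of the API):
+# on the HTK scale, 1000 Hz maps to roughly 1000 mel by construction, since
+# 2595 * log10(1 + 1000 / 700) ≈ 999.99:
+#
+#     hertz_to_mel(torch.tensor([1000.0]))  # ≈ tensor([999.99])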
+
+def mel_to_hertz(mels: torch.Tensor, mel_scale: str = "htk") -> torch.Tensor:
+ if mel_scale == "htk":
+ return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
+ elif mel_scale == "kaldi":
+ return 700.0 * (torch.exp(mels / 1127.0) - 1.0)
+ # slaney
+ f_sp = 200.0 / 3
+ min_log_hz = 1000.0
+ min_log_mel = (min_log_hz - 0.0) / f_sp
+ logstep = math.log(6.4) / 27.0
+ freq = 0.0 + f_sp * mels
+ log_region = mels >= min_log_mel
+ freq[log_region] = min_log_hz * torch.exp(logstep * (mels[log_region] - min_log_mel))
+ return freq
+
+
+def _create_triangular_filter_bank(
+ fft_freqs: torch.Tensor, filter_freqs: torch.Tensor
+) -> torch.Tensor:
+ filter_diff = filter_freqs[1:] - filter_freqs[:-1]
+ slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1)
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
+ return torch.clamp(torch.minimum(down_slopes, up_slopes), min=0)
+
+
+def _kaldi_mel_filter_bank(
+ num_frequency_bins: int,
+ num_mel_filters: int,
+ min_frequency: float,
+ max_frequency: float,
+ sampling_rate: int,
+) -> torch.Tensor:
+ """Compute mel filter bank matching kaldi's exact construction.
+
+ Replicates torchaudio.compliance.kaldi.get_mel_banks exactly:
+ - Uses 1127*ln mel scale (not 2595*log10)
+ - Computes mel points via mel_low + i * delta (not torch.linspace)
+ - Uses n_fft/2 FFT bins (excludes Nyquist), then pads with zero column
+
+ Returns:
+ Tensor of shape (num_frequency_bins, num_mel_filters).
+ """
+ n_fft = (num_frequency_bins - 1) * 2
+ num_fft_bins = n_fft // 2 # kaldi excludes Nyquist bin
+ fft_bin_width = sampling_rate / n_fft
+
+ mel_low = 1127.0 * math.log(1.0 + min_frequency / 700.0)
+ mel_high = 1127.0 * math.log(1.0 + max_frequency / 700.0)
+ mel_delta = (mel_high - mel_low) / (num_mel_filters + 1)
+
+ bin_idx = torch.arange(num_mel_filters).unsqueeze(1)
+ left_mel = mel_low + bin_idx * mel_delta
+ center_mel = mel_low + (bin_idx + 1.0) * mel_delta
+ right_mel = mel_low + (bin_idx + 2.0) * mel_delta
+
+ mel = 1127.0 * (1.0 + fft_bin_width * torch.arange(num_fft_bins) / 700.0).log()
+ mel = mel.unsqueeze(0)
+
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
+ banks = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
+
+ # kaldi pads a zero column for the Nyquist bin
+ banks = torch.nn.functional.pad(banks, (0, 1), mode="constant", value=0)
+
+ return banks.T # (num_frequency_bins, num_mel_filters)
+
+
+def mel_filter_bank_torch(
+ num_frequency_bins: int,
+ num_mel_filters: int,
+ min_frequency: float,
+ max_frequency: float,
+ sampling_rate: int,
+ norm: str | None = None,
+ mel_scale: str = "htk",
+ triangularize_in_mel_space: bool = False,
+ frequency_bin_mode: str = "rfft",
+    computation_dtype: torch.dtype | None = None,
+ bands_to_zero: int = 0,
+) -> torch.Tensor:
+ """Compute mel filter bank as a pure PyTorch tensor.
+
+ Matches torchaudio's melscale_fbanks: mel range endpoints are computed in
+ float64 (Python math), then all tensor work is done in the default dtype
+ (float32).
+
+ Args:
+ computation_dtype: If provided, all intermediate tensor operations are
+ performed in this dtype (e.g. ``torch.float64``), and the result is
+ cast back to the default dtype. This is useful to obtain results
+ that are numerically identical to a NumPy (float64) reference
+ implementation.
+ bands_to_zero: Number of lowest frequency bins to zero out before
+ building the filter bank. The zeroed rows are restored (as zeros)
+ in the output. Set to 1 to exclude the DC bin (HTK / LASR style).
+
+ Returns:
+ Tensor of shape (num_frequency_bins, num_mel_filters).
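+
+    Example (a minimal sketch)::
+
+        fbank = mel_filter_bank_torch(
+            num_frequency_bins=201,  # n_fft = 400 -> 400 // 2 + 1 bins
+            num_mel_filters=80,
+            min_frequency=0.0,
+            max_frequency=8000.0,
+            sampling_rate=16000,
+        )
+        # fbank.shape == (201, 80)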
+ """
+ if triangularize_in_mel_space and bands_to_zero == 0:
+ # Kaldi-exact path: matches torchaudio.compliance.kaldi.get_mel_banks.
+ # Kept for backward compatibility with models that rely on this behaviour
+ # (AST, SeamlessM4T, Speech2Text, etc.).
+ return _kaldi_mel_filter_bank(
+ num_frequency_bins, num_mel_filters, min_frequency, max_frequency, sampling_rate,
+ )
+
+ mel_min = _hertz_to_mel_scalar(min_frequency, mel_scale=mel_scale)
+ mel_max = _hertz_to_mel_scalar(max_frequency, mel_scale=mel_scale)
+
+ n_fft = (num_frequency_bins - 1) * 2
+
+ if triangularize_in_mel_space:
+ # Kaldi-style direct slope computation in mel space.
+ # Uses mel_low + i * delta (not linspace) and direct per-band slopes
+ # to match the exact numerical behaviour of kaldi/HTK filter banks.
+ mel_delta = (mel_max - mel_min) / (num_mel_filters + 1)
+ bin_idx = torch.arange(num_mel_filters, dtype=computation_dtype).unsqueeze(1)
+ left_mel = mel_min + bin_idx * mel_delta
+ center_mel = mel_min + (bin_idx + 1.0) * mel_delta
+ right_mel = mel_min + (bin_idx + 2.0) * mel_delta
+
+ fft_bin_width = sampling_rate / n_fft
+ hz_freqs = fft_bin_width * torch.arange(bands_to_zero, num_frequency_bins, dtype=computation_dtype)
+ mel = hertz_to_mel(hz_freqs, mel_scale=mel_scale).unsqueeze(0)
+
+ up_slope = (mel - left_mel) / (center_mel - left_mel)
+ down_slope = (right_mel - mel) / (right_mel - center_mel)
+ mel_filters = torch.max(torch.zeros(1, dtype=computation_dtype), torch.min(up_slope, down_slope))
+
+ # Transpose to (num_fft_bins, num_mel_filters) and restore zeroed bands
+ mel_filters = mel_filters.T
+ if bands_to_zero > 0:
+ mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, bands_to_zero, 0))
+
+ return mel_filters
+
+ mel_freqs = torch.linspace(mel_min, mel_max, num_mel_filters + 2, dtype=computation_dtype)
+
+ filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
+ if frequency_bin_mode == "rfft":
+ fft_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate)
+ else:
+ fft_freqs = torch.linspace(0, sampling_rate // 2, num_frequency_bins)
+ if computation_dtype is not None:
+ fft_freqs = fft_freqs.to(computation_dtype)
+
+ mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
+
+ if norm == "slaney":
+ enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
+ mel_filters = mel_filters * enorm.unsqueeze(0)
+
+ if bands_to_zero > 0:
+ mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, bands_to_zero, 0))
+
+ return mel_filters
+
+
+def window_function(window_length, name="hann_window", periodic=True, wkwargs=None):
+ """Create a window tensor using torch window functions."""
+ if wkwargs is None:
+ wkwargs = {}
+ if name in ["hann", "hann_window"]:
+ return torch.hann_window(window_length, periodic=periodic, **wkwargs)
+ elif name in ["hamming", "hamming_window"]:
+ return torch.hamming_window(window_length, periodic=periodic, **wkwargs)
+ elif name == "boxcar":
+ return torch.ones(window_length)
+ elif name == "povey":
+ return torch.hann_window(window_length, periodic=periodic, **wkwargs).pow(0.85)
+ else:
+ raise ValueError(f"Unknown window function '{name}'")
+
+
+# --- Sub-methods ---
+
+def _prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing):
+    """Pick the frame length and, for the `torch.stft` path, center-pad the window to `n_fft`.
+
+    With manual (kaldi-style) framing, frames stay at `win_length` and are later
+    zero-padded on the right inside the FFT buffer; otherwise the window is
+    centered in an `n_fft`-long buffer so `torch.stft` can be used directly.
+    """
+ if needs_manual_framing and win_length < n_fft:
+ frame_length = win_length
+ else:
+ if win_length < n_fft:
+ left_pad = (n_fft - win_length) // 2
+ right_pad = n_fft - win_length - left_pad
+ window = torch.nn.functional.pad(window, (left_pad, right_pad))
+ frame_length = n_fft
+ return window, frame_length
+
+
+def _apply_frame_processing(frames, *, dither=0.0, preemphasis=None, remove_dc_offset=False):
+ if dither != 0.0:
+ frames = frames + dither * torch.randn_like(frames)
+ if remove_dc_offset:
+ frames = frames - frames.mean(dim=-1, keepdim=True)
+ if preemphasis is not None:
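+        # Kaldi-style convention: the first sample is preemphasized against itself
+        # (replicate padding), i.e. x[0] - p * x[0] = x[0] * (1 - p).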
+ frames = torch.cat([
+ frames[..., :1] * (1 - preemphasis),
+ frames[..., 1:] - preemphasis * frames[..., :-1],
+ ], dim=-1)
+ return frames
+
+
+def _apply_mel_scale(
+ spectrogram: torch.Tensor,
+ mel_filters: torch.Tensor,
+ mel_floor: float = 1e-10,
+) -> torch.Tensor:
+ """Apply mel filterbank to a spectrogram.
+
+ Args:
+ spectrogram: Power spectrogram of shape (..., freq, time).
+ mel_filters: Mel filterbank of shape (freq, n_mels).
+ mel_floor: Minimum value for clamping.
+
+ Returns:
+ Mel spectrogram of shape (..., n_mels, time).
+ """
+ # (..., time, freq) @ (freq, n_mels) -> (..., time, n_mels) -> (..., n_mels, time)
+ mel_spec = torch.matmul(spectrogram.transpose(-2, -1), mel_filters).transpose(-2, -1)
+ return torch.clamp(mel_spec, min=mel_floor)
+
+
+def _torch_stft(
+    waveform, window, frame_length, hop_length, fft_length,
+    power, normalized, center, pad_mode,
+):
+    """Fast path using torch.stft. Returns a power spectrogram of shape (batch, freq, time)."""
+    stft_out = torch.stft(
+        waveform,
+        n_fft=fft_length,
+        hop_length=hop_length,
+        win_length=frame_length,
+        window=window,
+        center=center,
+        pad_mode=pad_mode,
+        normalized=False,
+        return_complex=True,
+    )
+    if normalized:
+        # Window-energy normalization, matching the manual STFT path below.
+        stft_out = stft_out / window.pow(2.0).sum().sqrt()
+    # Both call sites pass `power`; apply it here so the return value is a real-valued
+    # power spectrogram, consistent with `_manual_stft`.
+    return stft_out.abs() ** power
+
+
+def _manual_stft(
+ waveform, window, frame_length, hop_length, fft_length,
+ num_frequency_bins, power, normalized, center, pad_mode,
+ apply_frame_processing=None,
+):
+ """Manual framing STFT for kaldi-specific features. Returns power spectrogram of shape (batch, freq, time)."""
+ if center:
+ waveform = torch.nn.functional.pad(
+ waveform, (frame_length // 2, frame_length // 2), mode=pad_mode
+ )
+
+ # Extract all frames at once: (batch, num_frames, frame_length)
+ frames = waveform.unfold(-1, frame_length, hop_length)
+
+ if apply_frame_processing is not None:
+ frames = apply_frame_processing(frames)
+
+ frames = frames * window
+
+ # Zero-pad frames to fft_length if frame_length < fft_length (kaldi left-aligns in FFT buffer)
+ if frame_length < fft_length:
+ frames = torch.nn.functional.pad(frames, (0, fft_length - frame_length))
+
+ # Batched FFT: (batch, num_frames, fft_length) -> (batch, num_frames, num_frequency_bins)
+ spec = torch.fft.rfft(frames, n=fft_length)
+
+ if normalized:
+ spec = spec / window.pow(2.0).sum().sqrt()
+
+ spec = spec.abs() ** power
+
+ # (batch, num_frames, freq) -> (batch, freq, num_frames)
+ return spec.transpose(-2, -1)
+
+
+# --- Main function ---
+
+def mel_spectrogram(
+ waveform: torch.Tensor,
+ sampling_rate: int,
+ *,
+ n_fft: int = 400,
+ win_length: int | None = None,
+ hop_length: int | None = None,
+ window_fn: str = "hann_window",
+ wkwargs: dict | None = None,
+ power: float = 2.0,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ normalized: bool = False,
+ periodic: bool = True,
+ # mel scale kwargs
+ n_mels: int = 128,
+ f_min: float = 0.0,
+ f_max: float | None = None,
+ mel_scale: str = "htk",
+ norm: str | None = None,
+ triangularize_in_mel_space: bool = False,
+ # kaldi-specific kwargs
+ dither: float = 0.0,
+ preemphasis: float | None = None,
+ remove_dc_offset: bool = False,
+ mel_floor: float = 1e-10,
+) -> torch.Tensor:
+ """Compute mel spectrogram using PyTorch.
+
+ Args:
+ waveform: Input waveform of shape (..., time).
+ sampling_rate: Sample rate in Hz.
+
+ Returns:
+ Mel spectrogram of shape (..., n_mels, time).
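+
+    Example (a minimal sketch)::
+
+        waveform = torch.randn(16000)
+        mel = mel_spectrogram(waveform, sampling_rate=16000, n_fft=400, n_mels=80)
+        # mel.shape == (80, num_frames)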
+ """
+ if f_max is None:
+ f_max = sampling_rate / 2.0
+
+ # --- STFT ---
+ if win_length is None:
+ win_length = n_fft
+ if hop_length is None:
+ hop_length = win_length // 2
+ device = waveform.device
+ dtype = waveform.dtype
+
+ needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset
+
+ window_wkwargs = {**(wkwargs or {}), "dtype": dtype}
+ window = window_function(win_length, name=window_fn, periodic=periodic, wkwargs=window_wkwargs)
+ window = window.to(device=device)
+ window, frame_length = _prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing)
+
+ is_1d = waveform.ndim == 1
+ if is_1d:
+ waveform = waveform.unsqueeze(0)
+ leading_shape = waveform.shape[:-1]
+ waveform = waveform.reshape(-1, waveform.shape[-1])
+ if needs_manual_framing:
+ frame_proc = lambda f: _apply_frame_processing(
+ f, dither=dither, preemphasis=preemphasis, remove_dc_offset=remove_dc_offset,
+ )
+ spectrogram = _manual_stft(
+ waveform, window, frame_length, hop_length, n_fft,
+ n_fft // 2 + 1, power, normalized, center, pad_mode,
+ apply_frame_processing=frame_proc,
+ )
+ else:
+ spectrogram = _torch_stft(
+ waveform, window, frame_length, hop_length, n_fft,
+ power, normalized, center, pad_mode,
+ )
+
+ spectrogram = spectrogram.reshape(*leading_shape, spectrogram.shape[-2], spectrogram.shape[-1])
+ if is_1d:
+ spectrogram = spectrogram.squeeze(0)
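+    # Downcast to float32 so the result matches the (float32) mel filter bank below.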
+ spectrogram = spectrogram.float()
+
+ num_frequency_bins = spectrogram.shape[-2]
+ mel_filters = mel_filter_bank_torch(
+ num_frequency_bins, n_mels, f_min, f_max, sampling_rate,
+ norm=norm, mel_scale=mel_scale,
+ triangularize_in_mel_space=triangularize_in_mel_space,
+ ).to(spectrogram.device)
+
+ return _apply_mel_scale(spectrogram, mel_filters, mel_floor=mel_floor)
+
+
+class MelSpectrogram(torch.nn.Module):
+ """Cached mel spectrogram transform — precomputes window and mel filterbank.
+
+ Same API and exact same results as the functional ``mel_spectrogram``, but
+ avoids recomputing the window and mel filterbank on every call.
+
+ Usage::
+
+ transform = MelSpectrogram(sampling_rate=16000, n_fft=1024, n_mels=80)
+ transform = transform.cuda() # move buffers to GPU once
+ mel = transform(waveform) # fast repeated calls
+ """
+
+ def __init__(
+ self,
+ sampling_rate: int,
+ *,
+ n_fft: int = 400,
+ win_length: int | None = None,
+ hop_length: int | None = None,
+ window_fn: str = "hann_window",
+ wkwargs: dict | None = None,
+ power: float = 2.0,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ normalized: bool = False,
+ periodic: bool = True,
+ n_mels: int = 128,
+ f_min: float = 0.0,
+ f_max: float | None = None,
+ mel_scale: str = "htk",
+ norm: str | None = None,
+ triangularize_in_mel_space: bool = False,
+ dither: float = 0.0,
+ preemphasis: float | None = None,
+ remove_dc_offset: bool = False,
+ mel_floor: float = 1e-10,
+ ):
+ super().__init__()
+ self.sampling_rate = sampling_rate
+ self.n_fft = n_fft
+ self.win_length = win_length if win_length is not None else n_fft
+ self.hop_length = hop_length if hop_length is not None else self.win_length // 2
+ self.power = power
+ self.center = center
+ self.pad_mode = pad_mode
+ self.normalized = normalized
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max if f_max is not None else sampling_rate / 2.0
+ self.mel_floor = mel_floor
+ self.dither = dither
+ self.preemphasis = preemphasis
+ self.remove_dc_offset = remove_dc_offset
+
+ self._needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset
+
+ # Build window
+ window = window_function(self.win_length, name=window_fn, periodic=periodic, wkwargs=wkwargs)
+ window, self._frame_length = _prepare_window_and_framing(window, self.win_length, n_fft, self._needs_manual_framing)
+ self.register_buffer("window", window)
+
+ # Build mel filterbank
+ num_frequency_bins = n_fft // 2 + 1
+ mel_fb = mel_filter_bank_torch(
+ num_frequency_bins, n_mels, self.f_min, self.f_max, sampling_rate,
+ norm=norm, mel_scale=mel_scale,
+ triangularize_in_mel_space=triangularize_in_mel_space,
+ )
+ self.register_buffer("mel_filters", mel_fb)
+
+ def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+ """Compute mel spectrogram.
+
+ Args:
+ waveform: Input of shape (..., time).
+
+ Returns:
+ Mel spectrogram of shape (..., n_mels, time).
+ """
+ is_1d = waveform.ndim == 1
+ if is_1d:
+ waveform = waveform.unsqueeze(0)
+
+ leading_shape = waveform.shape[:-1]
+ waveform = waveform.reshape(-1, waveform.shape[-1])
+
+ if self._needs_manual_framing:
+ frame_proc = lambda f: _apply_frame_processing(
+ f, dither=self.dither, preemphasis=self.preemphasis, remove_dc_offset=self.remove_dc_offset,
+ )
+ spec = _manual_stft(
+ waveform, self.window, self._frame_length, self.hop_length,
+ self.n_fft, self.n_fft // 2 + 1, self.power, self.normalized,
+ self.center, self.pad_mode,
+ apply_frame_processing=frame_proc,
+ )
+ else:
+ spec = _torch_stft(
+ waveform, self.window, self._frame_length, self.hop_length,
+ self.n_fft, self.power, self.normalized, self.center, self.pad_mode,
+ )
+
+ spec = spec.reshape(*leading_shape, spec.shape[-2], spec.shape[-1])
+ if is_1d:
+ spec = spec.squeeze(0)
+ spec = spec.float()
+
+ return _apply_mel_scale(spec, self.mel_filters, mel_floor=self.mel_floor)
diff --git a/src/transformers/utils/deprecation.py b/src/transformers/utils/deprecation.py
index db0e67325d78..9b44e549df1b 100644
--- a/src/transformers/utils/deprecation.py
+++ b/src/transformers/utils/deprecation.py
@@ -33,6 +33,41 @@ class Action(ExplicitEnum):
RAISE = "raise"
+def deprecated_feature_extractor(audio_processor_class, old_class_name, version="5.5"):
+ """Create a deprecated FeatureExtractor alias for an AudioProcessor.
+
+ Uses dynamic class creation to reduce boilerplate across ~20 models.
+ """
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            f"`{old_class_name}` is deprecated and will be removed in v{version}. "
+            f"Use `{audio_processor_class.__name__}` instead.",
+            FutureWarning,
+            stacklevel=2,
+        )
+        # Resolve against the alias class created below (via closure); using
+        # `super(type(self), self)` would recurse infinitely if the alias is subclassed.
+        super(deprecated_class, self).__init__(*args, **kwargs)
+
+    def __init_subclass__(cls, **kwargs):
+        warnings.warn(
+            f"`{old_class_name}` is deprecated and will be removed in v{version}. "
+            f"Use `{audio_processor_class.__name__}` instead.",
+            FutureWarning,
+        )
+        super(deprecated_class, cls).__init_subclass__(**kwargs)
+
+    deprecated_class = type(
+        old_class_name,
+        (audio_processor_class,),
+        {
+            "__init__": __init__,
+            "__init_subclass__": __init_subclass__,
+            "__module__": audio_processor_class.__module__,
+            "__doc__": f"Deprecated. Use {audio_processor_class.__name__} instead.",
+        },
+    )
+    return deprecated_class
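+
+
+# Hypothetical usage (class names are illustrative, not actual transformers symbols):
+#
+#     Wav2Vec2FeatureExtractor = deprecated_feature_extractor(
+#         Wav2Vec2AudioProcessor, "Wav2Vec2FeatureExtractor"
+#     )
+#     Wav2Vec2FeatureExtractor()  # emits a FutureWarning pointing at the new class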
+
+
def deprecate_kwarg(
old_name: str,
version: str,