From 092c6a4aacb2bb0dca39d4f1dbd7eeae7fe1eed5 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 2 Mar 2026 15:46:13 +0100 Subject: [PATCH 01/28] refacto to introduce ProcessingMixin --- src/transformers/feature_extraction_utils.py | 360 +---------------- src/transformers/image_processing_base.py | 366 +---------------- src/transformers/processing_base.py | 393 +++++++++++++++++++ 3 files changed, 417 insertions(+), 702 deletions(-) create mode 100644 src/transformers/processing_base.py diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index f1b66f752da4..9b49c4cab5c3 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -15,20 +15,15 @@ Feature extraction saving/loading class for common feature extractors. """ -import copy -import json import os from collections import UserDict from typing import TYPE_CHECKING, Any, TypeVar, Union import numpy as np -from huggingface_hub import create_repo, is_offline_mode -from .dynamic_module_utils import custom_object_save +from .processing_base import ProcessingMixin from .utils import ( FEATURE_EXTRACTOR_NAME, - PROCESSOR_NAME, - PushToHubMixin, TensorType, _is_tensor_or_array_like, copy_func, @@ -38,9 +33,7 @@ is_torch_dtype, logging, requires_backends, - safe_load_json_file, ) -from .utils.hub import cached_file if TYPE_CHECKING: @@ -263,170 +256,21 @@ def maybe_to(v): return self -class FeatureExtractionMixin(PushToHubMixin): +class FeatureExtractionMixin(ProcessingMixin): """ This is a feature extraction mixin used to provide saving/loading functionality for sequential and audio feature extractors. """ - _auto_class = None - - def __init__(self, **kwargs): - """Set elements of `kwargs` as attributes.""" - # Pop "processor_class", it should not be saved in feature extractor config - kwargs.pop("processor_class", None) - # Additional attributes without default values - for key, value in kwargs.items(): - try: - setattr(self, key, value) - except AttributeError as err: - logger.error(f"Can't set {key} with value {value} for {self}") - raise err - - @classmethod - def from_pretrained( - cls: type[SpecificFeatureExtractorType], - pretrained_model_name_or_path: str | os.PathLike, - cache_dir: str | os.PathLike | None = None, - force_download: bool = False, - local_files_only: bool = False, - token: str | bool | None = None, - revision: str = "main", - **kwargs, - ) -> SpecificFeatureExtractorType: - r""" - Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a - derived class of [`SequenceFeatureExtractor`]. - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - This can be either: - - - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on - huggingface.co. - - a path to a *directory* containing a feature extractor file saved using the - [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g., - `./my_model_directory/`. - - a path or url to a saved feature extractor JSON *file*, e.g., - `./my_model_directory/preprocessor_config.json`. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model feature extractor should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force to (re-)download the feature extractor files and override the cached versions - if they exist. 
- proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use - the token generated when running `hf auth login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - - - - - To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. - - - - return_unused_kwargs (`bool`, *optional*, defaults to `False`): - If `False`, then this function returns just the final feature extractor object. If `True`, then this - functions returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary - consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of - `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored. - kwargs (`dict[str, Any]`, *optional*): - The values in kwargs of any keys which are feature extractor attributes will be used to override the - loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is - controlled by the `return_unused_kwargs` keyword parameter. - - Returns: - A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]. - - Examples: - - ```python - # We can't instantiate directly the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor* so let's show the examples on a - # derived class: *Wav2Vec2FeatureExtractor* - feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - "facebook/wav2vec2-base-960h" - ) # Download feature_extraction_config from huggingface.co and cache. - feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - "./test/saved_model/" - ) # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')* - feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json") - feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( - "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False - ) - assert feature_extractor.return_attention_mask is False - feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained( - "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True - ) - assert feature_extractor.return_attention_mask is False - assert unused_kwargs == {"foo": False} - ```""" - kwargs["cache_dir"] = cache_dir - kwargs["force_download"] = force_download - kwargs["local_files_only"] = local_files_only - kwargs["revision"] = revision - - if token is not None: - kwargs["token"] = token - - feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs) - - return cls.from_dict(feature_extractor_dict, **kwargs) - - def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs): - """ - Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the - [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method. 
- - Args: - save_directory (`str` or `os.PathLike`): - Directory where the feature extractor JSON file will be saved (will be created if it does not exist). - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the - repository you want to push to with `repo_id` (will default to the name of `save_directory` in your - namespace). - kwargs (`dict[str, Any]`, *optional*): - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. - """ - if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") - - os.makedirs(save_directory, exist_ok=True) - - if push_to_hub: - commit_message = kwargs.pop("commit_message", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id - files_timestamps = self._get_files_timestamps(save_directory) - - # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be - # loaded from the Hub. - if self._auto_class is not None: - custom_object_save(self, save_directory, config=self) - - # If we save using the predefined names, we can load using `from_pretrained` - output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME) - - self.to_json_file(output_feature_extractor_file) - logger.info(f"Feature extractor saved in {output_feature_extractor_file}") - - if push_to_hub: - self._upload_modified_files( - save_directory, - repo_id, - files_timestamps, - commit_message=commit_message, - token=kwargs.get("token"), - ) - - return [output_feature_extractor_file] + _config_name = FEATURE_EXTRACTOR_NAME + _type_key = "feature_extractor_type" + _nested_config_keys = ["feature_extractor", "audio_processor"] + _auto_class_default = "AutoFeatureExtractor" + _file_type_label = "feature extractor" + _excluded_dict_keys = {"mel_filters", "window"} + _extra_init_pops = [] + _config_filename_kwarg = None + _subfolder_default = None @classmethod def get_feature_extractor_dict( @@ -443,104 +287,7 @@ def get_feature_extractor_dict( Returns: `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object. 
""" - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - subfolder = kwargs.pop("subfolder", None) - token = kwargs.pop("token", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - - from_pipeline = kwargs.pop("_from_pipeline", None) - from_auto_class = kwargs.pop("_from_auto", False) - - user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class} - if from_pipeline is not None: - user_agent["using_pipeline"] = from_pipeline - - if is_offline_mode() and not local_files_only: - logger.info("Offline mode: forcing local_files_only=True") - local_files_only = True - - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - is_local = os.path.isdir(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME) - if os.path.isfile(pretrained_model_name_or_path): - resolved_feature_extractor_file = pretrained_model_name_or_path - resolved_processor_file = None - is_local = True - else: - feature_extractor_file = FEATURE_EXTRACTOR_NAME - try: - # Load from local folder or from cache or download from model Hub and cache - resolved_processor_file = cached_file( - pretrained_model_name_or_path, - filename=PROCESSOR_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _raise_exceptions_for_missing_entries=False, - ) - resolved_feature_extractor_file = cached_file( - pretrained_model_name_or_path, - filename=feature_extractor_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _raise_exceptions_for_missing_entries=False, - ) - except OSError: - # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to - # the original exception. - raise - except Exception: - # For any other exception, we throw a generic error. - raise OSError( - f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load" - " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a {FEATURE_EXTRACTOR_NAME} file" - ) - - # Load feature_extractor dict. Priority goes as (nested config if found -> image processor config) - # We are downloading both configs because almost all models have a `processor_config.json` but - # not all of these are nested. 
We need to check if it was saved recebtly as nested or if it is legacy style - feature_extractor_dict = None - if resolved_processor_file is not None: - processor_dict = safe_load_json_file(resolved_processor_file) - if "feature_extractor" in processor_dict or "audio_processor" in processor_dict: - feature_extractor_dict = processor_dict.get("feature_extractor", processor_dict.get("audio_processor")) - - if resolved_feature_extractor_file is not None and feature_extractor_dict is None: - feature_extractor_dict = safe_load_json_file(resolved_feature_extractor_file) - - if feature_extractor_dict is None: - raise OSError( - f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load" - " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a {feature_extractor_file} file" - ) - - if is_local: - logger.info(f"loading configuration file {resolved_feature_extractor_file}") - else: - logger.info( - f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}" - ) - - return feature_extractor_dict, kwargs + return cls._get_config_dict(pretrained_model_name_or_path, **kwargs) @classmethod def from_dict( @@ -581,89 +328,6 @@ def from_dict( else: return feature_extractor - def to_dict(self) -> dict[str, Any]: - """ - Serializes this instance to a Python dictionary. Returns: - `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ - output = copy.deepcopy(self.__dict__) - output["feature_extractor_type"] = self.__class__.__name__ - if "mel_filters" in output: - del output["mel_filters"] - if "window" in output: - del output["window"] - return output - - @classmethod - def from_json_file(cls, json_file: str | os.PathLike) -> "FeatureExtractionMixin": - """ - Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to - a JSON file of parameters. - - Args: - json_file (`str` or `os.PathLike`): - Path to the JSON file containing the parameters. - - Returns: - A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor - object instantiated from that JSON file. - """ - with open(json_file, encoding="utf-8") as reader: - text = reader.read() - feature_extractor_dict = json.loads(text) - return cls(**feature_extractor_dict) - - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string. - - Returns: - `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. - """ - dictionary = self.to_dict() - - for key, value in dictionary.items(): - if isinstance(value, np.ndarray): - dictionary[key] = value.tolist() - - return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" - - def to_json_file(self, json_file_path: str | os.PathLike): - """ - Save this instance to a JSON file. - - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this feature_extractor instance's parameters will be saved. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) - - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - @classmethod - def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"): - """ - Register this class with a given auto class. 
This should only be used for custom feature extractors as the ones - in the library are already mapped with `AutoFeatureExtractor`. - - - - Args: - auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`): - The auto class to register this new feature extractor with. - """ - if not isinstance(auto_class, str): - auto_class = auto_class.__name__ - - import transformers.models.auto as auto_module - - if not hasattr(auto_module, auto_class): - raise ValueError(f"{auto_class} is not a valid auto class.") - - cls._auto_class = auto_class - FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub) if FeatureExtractionMixin.push_to_hub.__doc__ is not None: diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py index 72db8fcc9bec..215f4d8baa2c 100644 --- a/src/transformers/image_processing_base.py +++ b/src/transformers/image_processing_base.py @@ -12,26 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -import json import os from typing import Any, TypeVar -import numpy as np -from huggingface_hub import create_repo, is_offline_mode - -from .dynamic_module_utils import custom_object_save from .feature_extraction_utils import BatchFeature as BaseBatchFeature from .image_utils import is_valid_image, load_image +from .processing_base import ProcessingMixin from .utils import ( IMAGE_PROCESSOR_NAME, - PROCESSOR_NAME, - PushToHubMixin, copy_func, logging, - safe_load_json_file, ) -from .utils.hub import cached_file ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin") @@ -58,175 +49,21 @@ class BatchFeature(BaseBatchFeature): # TODO: (Amy) - factor out the common parts of this and the feature extractor -class ImageProcessingMixin(PushToHubMixin): +class ImageProcessingMixin(ProcessingMixin): """ This is an image processor mixin used to provide saving/loading functionality for sequential and image feature extractors. """ - _auto_class = None - - def __init__(self, **kwargs): - """Set elements of `kwargs` as attributes.""" - # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use - # `XXXImageProcessor`, this attribute and its value are misleading. - kwargs.pop("feature_extractor_type", None) - # Pop "processor_class", should not be saved with image processing config anymore - kwargs.pop("processor_class", None) - # Additional attributes without default values - for key, value in kwargs.items(): - try: - setattr(self, key, value) - except AttributeError as err: - logger.error(f"Can't set {key} with value {value} for {self}") - raise err - - @classmethod - def from_pretrained( - cls: type[ImageProcessorType], - pretrained_model_name_or_path: str | os.PathLike, - cache_dir: str | os.PathLike | None = None, - force_download: bool = False, - local_files_only: bool = False, - token: str | bool | None = None, - revision: str = "main", - **kwargs, - ) -> ImageProcessorType: - r""" - Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor. - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - This can be either: - - - a string, the *model id* of a pretrained image_processor hosted inside a model repo on - huggingface.co. - - a path to a *directory* containing a image processor file saved using the - [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g., - `./my_model_directory/`. 
- - a path or url to a saved image processor JSON *file*, e.g., - `./my_model_directory/preprocessor_config.json`. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model image processor should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force to (re-)download the image processor files and override the cached versions if - they exist. - proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - token (`str` or `bool`, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use - the token generated when running `hf auth login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - - - - - To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. - - - - return_unused_kwargs (`bool`, *optional*, defaults to `False`): - If `False`, then this function returns just the final image processor object. If `True`, then this - functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary - consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of - `kwargs` which has not been used to update `image_processor` and is otherwise ignored. - subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can - specify the folder name here. - kwargs (`dict[str, Any]`, *optional*): - The values in kwargs of any keys which are image processor attributes will be used to override the - loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is - controlled by the `return_unused_kwargs` keyword parameter. - - Returns: - A image processor of type [`~image_processing_utils.ImageProcessingMixin`]. - - Examples: - - ```python - # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a - # derived class: *CLIPImageProcessor* - image_processor = CLIPImageProcessor.from_pretrained( - "openai/clip-vit-base-patch32" - ) # Download image_processing_config from huggingface.co and cache. - image_processor = CLIPImageProcessor.from_pretrained( - "./test/saved_model/" - ) # E.g. 
image processor (or model) was saved using *save_pretrained('./test/saved_model/')* - image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json") - image_processor = CLIPImageProcessor.from_pretrained( - "openai/clip-vit-base-patch32", do_normalize=False, foo=False - ) - assert image_processor.do_normalize is False - image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained( - "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True - ) - assert image_processor.do_normalize is False - assert unused_kwargs == {"foo": False} - ```""" - kwargs["cache_dir"] = cache_dir - kwargs["force_download"] = force_download - kwargs["local_files_only"] = local_files_only - kwargs["revision"] = revision - - if token is not None: - kwargs["token"] = token - - image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs) - - return cls.from_dict(image_processor_dict, **kwargs) - - def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs): - """ - Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the - [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method. - - Args: - save_directory (`str` or `os.PathLike`): - Directory where the image processor JSON file will be saved (will be created if it does not exist). - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the - repository you want to push to with `repo_id` (will default to the name of `save_directory` in your - namespace). - kwargs (`dict[str, Any]`, *optional*): - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. - """ - if os.path.isfile(save_directory): - raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") - - os.makedirs(save_directory, exist_ok=True) - - if push_to_hub: - commit_message = kwargs.pop("commit_message", None) - repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) - repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id - files_timestamps = self._get_files_timestamps(save_directory) - - # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be - # loaded from the Hub. 
- if self._auto_class is not None: - custom_object_save(self, save_directory, config=self) - - # If we save using the predefined names, we can load using `from_pretrained` - output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME) - - self.to_json_file(output_image_processor_file) - logger.info(f"Image processor saved in {output_image_processor_file}") - - if push_to_hub: - self._upload_modified_files( - save_directory, - repo_id, - files_timestamps, - commit_message=commit_message, - token=kwargs.get("token"), - ) - - return [output_image_processor_file] + _config_name = IMAGE_PROCESSOR_NAME + _type_key = "image_processor_type" + _nested_config_keys = ["image_processor"] + _auto_class_default = "AutoImageProcessor" + _file_type_label = "image processor" + _excluded_dict_keys = set() + _extra_init_pops = ["feature_extractor_type"] + _config_filename_kwarg = "image_processor_filename" + _subfolder_default = "" @classmethod def get_image_processor_dict( @@ -248,104 +85,7 @@ def get_image_processor_dict( Returns: `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object. """ - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - token = kwargs.pop("token", None) - local_files_only = kwargs.pop("local_files_only", False) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", "") - image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME) - - from_pipeline = kwargs.pop("_from_pipeline", None) - from_auto_class = kwargs.pop("_from_auto", False) - - user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class} - if from_pipeline is not None: - user_agent["using_pipeline"] = from_pipeline - - if is_offline_mode() and not local_files_only: - logger.info("Offline mode: forcing local_files_only=True") - local_files_only = True - - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - is_local = os.path.isdir(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): - image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename) - if os.path.isfile(pretrained_model_name_or_path): - resolved_image_processor_file = pretrained_model_name_or_path - resolved_processor_file = None - is_local = True - else: - image_processor_file = image_processor_filename - try: - resolved_processor_file = cached_file( - pretrained_model_name_or_path, - filename=PROCESSOR_NAME, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _raise_exceptions_for_missing_entries=False, - ) - resolved_image_processor_file = cached_file( - pretrained_model_name_or_path, - filename=image_processor_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - user_agent=user_agent, - revision=revision, - subfolder=subfolder, - _raise_exceptions_for_missing_entries=False, - ) - except OSError: - # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to - # the original exception. - raise - except Exception: - # For any other exception, we throw a generic error. - raise OSError( - f"Can't load image processor for '{pretrained_model_name_or_path}'. 
If you were trying to load" - " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a {image_processor_filename} file" - ) - - # Load image_processor dict. Priority goes as (nested config if found -> image processor config) - # We are downloading both configs because almost all models have a `processor_config.json` but - # not all of these are nested. We need to check if it was saved recebtly as nested or if it is legacy style - image_processor_dict = None - if resolved_processor_file is not None: - processor_dict = safe_load_json_file(resolved_processor_file) - if "image_processor" in processor_dict: - image_processor_dict = processor_dict["image_processor"] - - if resolved_image_processor_file is not None and image_processor_dict is None: - image_processor_dict = safe_load_json_file(resolved_image_processor_file) - - if image_processor_dict is None: - raise OSError( - f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load" - " it from 'https://huggingface.co/models', make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a {image_processor_filename} file" - ) - - if is_local: - logger.info(f"loading configuration file {resolved_image_processor_file}") - else: - logger.info( - f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}" - ) - - return image_processor_dict, kwargs + return cls._get_config_dict(pretrained_model_name_or_path, **kwargs) @classmethod def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): @@ -388,88 +128,6 @@ def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): else: return image_processor - def to_dict(self) -> dict[str, Any]: - """ - Serializes this instance to a Python dictionary. - - Returns: - `dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance. - """ - output = copy.deepcopy(self.__dict__) - output["image_processor_type"] = self.__class__.__name__ - - return output - - @classmethod - def from_json_file(cls, json_file: str | os.PathLike): - """ - Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON - file of parameters. - - Args: - json_file (`str` or `os.PathLike`): - Path to the JSON file containing the parameters. - - Returns: - A image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object - instantiated from that JSON file. - """ - with open(json_file, encoding="utf-8") as reader: - text = reader.read() - image_processor_dict = json.loads(text) - return cls(**image_processor_dict) - - def to_json_string(self) -> str: - """ - Serializes this instance to a JSON string. - - Returns: - `str`: String containing all the attributes that make up this feature_extractor instance in JSON format. - """ - dictionary = self.to_dict() - - for key, value in dictionary.items(): - if isinstance(value, np.ndarray): - dictionary[key] = value.tolist() - - return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" - - def to_json_file(self, json_file_path: str | os.PathLike): - """ - Save this instance to a JSON file. 
- - Args: - json_file_path (`str` or `os.PathLike`): - Path to the JSON file in which this image_processor instance's parameters will be saved. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string()) - - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - @classmethod - def register_for_auto_class(cls, auto_class="AutoImageProcessor"): - """ - Register this class with a given auto class. This should only be used for custom image processors as the ones - in the library are already mapped with `AutoImageProcessor `. - - - - Args: - auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor "`): - The auto class to register this new image processor with. - """ - if not isinstance(auto_class, str): - auto_class = auto_class.__name__ - - import transformers.models.auto as auto_module - - if not hasattr(auto_module, auto_class): - raise ValueError(f"{auto_class} is not a valid auto class.") - - cls._auto_class = auto_class - def fetch_images(self, image_url_or_urls: str | list[str] | list[list[str]]): """ Convert a single or a list of urls into the corresponding `PIL.Image` objects. diff --git a/src/transformers/processing_base.py b/src/transformers/processing_base.py new file mode 100644 index 000000000000..ff9a7158ff56 --- /dev/null +++ b/src/transformers/processing_base.py @@ -0,0 +1,393 @@ +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base mixin for image processors and feature extractors, providing shared +save/load/serialization logic. +""" + +import copy +import json +import os +from typing import Any, TypeVar + +import numpy as np +from huggingface_hub import create_repo, is_offline_mode + +from .dynamic_module_utils import custom_object_save +from .utils import ( + PROCESSOR_NAME, + PushToHubMixin, + logging, + safe_load_json_file, +) +from .utils.hub import cached_file + + +logger = logging.get_logger(__name__) + +ProcessingMixinType = TypeVar("ProcessingMixinType", bound="ProcessingMixin") + + +class ProcessingMixin(PushToHubMixin): + """ + Base mixin providing saving/loading functionality shared by + ImageProcessingMixin and FeatureExtractionMixin. + + Subclasses must set the following class attributes: + _config_name: str — config file name (e.g. IMAGE_PROCESSOR_NAME) + _type_key: str — key added in to_dict() (e.g. 
"image_processor_type") + _nested_config_keys: list — keys to check in processor_config.json + _auto_class_default: str — default auto class for register_for_auto_class + _file_type_label: str — label for user-agent / error messages + Optional: + _excluded_dict_keys: set — keys to drop from to_dict() output + _extra_init_pops: list — extra keys to pop in __init__ + _config_filename_kwarg: str — kwarg name that can override the config filename + _subfolder_default: str — default for the subfolder kwarg + """ + + _auto_class = None + + # --- Must be overridden by subclasses --- + _config_name: str + _type_key: str + _nested_config_keys: list[str] = [] + _auto_class_default: str + _file_type_label: str + + # --- Optional overrides --- + _excluded_dict_keys: set[str] = set() + _extra_init_pops: list[str] = [] + _config_filename_kwarg: str | None = None + _subfolder_default: str | None = "" + + def __init__(self, **kwargs): + """Set elements of `kwargs` as attributes.""" + for key in self._extra_init_pops: + kwargs.pop(key, None) + # Pop "processor_class", should not be saved in config + kwargs.pop("processor_class", None) + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + @classmethod + def from_pretrained( + cls: type[ProcessingMixinType], + pretrained_model_name_or_path: str | os.PathLike, + cache_dir: str | os.PathLike | None = None, + force_download: bool = False, + local_files_only: bool = False, + token: str | bool | None = None, + revision: str = "main", + **kwargs, + ) -> ProcessingMixinType: + r""" + Instantiate a processor from a pretrained model name or path. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained processor hosted inside a model repo on + huggingface.co. + - a path to a *directory* containing a processor file saved using the + [`~ProcessingMixin.save_pretrained`] method, e.g., `./my_model_directory/`. + - a path or url to a saved processor JSON *file*, e.g., + `./my_model_directory/preprocessor_config.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model processor should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the processor files and override the cached versions if + they exist. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final processor object. If `True`, then this + functions returns a `Tuple(processor, unused_kwargs)` where *unused_kwargs* is a dictionary + consisting of the key/value pairs whose keys are not processor attributes. + kwargs (`dict[str, Any]`, *optional*): + The values in kwargs of any keys which are processor attributes will be used to override the + loaded values. + + Returns: + A processor of type [`~ProcessingMixin`]. 
+ """ + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs["revision"] = revision + + if token is not None: + kwargs["token"] = token + + config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + + return cls.from_dict(config_dict, **kwargs) + + def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs): + """ + Save a processor object to the directory `save_directory`, so that it can be re-loaded using the + [`~ProcessingMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the processor JSON file will be saved (will be created if it does not exist). + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. + kwargs (`dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = create_repo(repo_id, exist_ok=True, **kwargs).repo_id + files_timestamps = self._get_files_timestamps(save_directory) + + # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self) + + # If we save using the predefined names, we can load using `from_pretrained` + output_file = os.path.join(save_directory, self._config_name) + + self.to_json_file(output_file) + logger.info(f"{self._file_type_label} saved in {output_file}") + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + + return [output_file] + + @classmethod + def _get_config_dict( + cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + processor using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object. + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", cls._subfolder_default) + + # Allow overriding the config filename via a kwarg (e.g. 
image_processor_filename)
+        if cls._config_filename_kwarg is not None:
+            config_filename = kwargs.pop(cls._config_filename_kwarg, cls._config_name)
+        else:
+            config_filename = cls._config_name
+
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {"file_type": cls._file_type_label, "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            config_file = os.path.join(pretrained_model_name_or_path, config_filename)
+        if os.path.isfile(pretrained_model_name_or_path):
+            resolved_config_file = pretrained_model_name_or_path
+            resolved_processor_file = None
+            is_local = True
+        else:
+            config_file = config_filename
+            try:
+                resolved_processor_file = cached_file(
+                    pretrained_model_name_or_path,
+                    filename=PROCESSOR_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    filename=config_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
+            except OSError:
+                # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
+                # the original exception.
+                raise
+            except Exception:
+                # For any other exception, we throw a generic error.
+                raise OSError(
+                    f"Can't load {cls._file_type_label} for '{pretrained_model_name_or_path}'. If you were trying to load"
+                    " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                    f" directory containing a {config_filename} file"
+                )
+
+        # Load config dict. Priority goes as (nested config if found -> standalone config)
+        # We are downloading both configs because almost all models have a `processor_config.json` but
+        # not all of these are nested. We need to check whether it was saved recently as nested or is in the legacy style
+        config_dict = None
+        if resolved_processor_file is not None:
+            processor_dict = safe_load_json_file(resolved_processor_file)
+            for nested_key in cls._nested_config_keys:
+                if nested_key in processor_dict:
+                    config_dict = processor_dict[nested_key]
+                    break
+
+        if resolved_config_file is not None and config_dict is None:
+            config_dict = safe_load_json_file(resolved_config_file)
+
+        if config_dict is None:
+            raise OSError(
+                f"Can't load {cls._file_type_label} for '{pretrained_model_name_or_path}'. If you were trying to load"
+                " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                f" same name. 
Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a {config_filename} file" + ) + + if is_local: + logger.info(f"loading configuration file {resolved_config_file}") + else: + logger.info( + f"loading configuration file {config_file} from cache at {resolved_config_file}" + ) + + return config_dict, kwargs + + def to_dict(self) -> dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `dict[str, Any]`: Dictionary of all the attributes that make up this instance. + """ + output = copy.deepcopy(self.__dict__) + output[self._type_key] = self.__class__.__name__ + for key in self._excluded_dict_keys: + if key in output: + del output[key] + return output + + @classmethod + def from_json_file(cls, json_file: str | os.PathLike): + """ + Instantiates a processor from the path to a JSON file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + A processor of type [`~ProcessingMixin`]: The processor object instantiated from that JSON file. + """ + with open(json_file, encoding="utf-8") as reader: + text = reader.read() + config_dict = json.loads(text) + return cls(**config_dict) + + def to_json_string(self) -> str: + """ + Serializes this instance to a JSON string. + + Returns: + `str`: String containing all the attributes that make up this instance in JSON format. + """ + dictionary = self.to_dict() + + for key, value in dictionary.items(): + if isinstance(value, np.ndarray): + dictionary[key] = value.tolist() + + return json.dumps(dictionary, indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path: str | os.PathLike): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string()) + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + @classmethod + def register_for_auto_class(cls, auto_class=None): + """ + Register this class with a given auto class. + + Args: + auto_class (`str` or `type`, *optional*): + The auto class to register this new processor with. Defaults to the subclass's `_auto_class_default`. 
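+
+        This should only be used for custom processors, as the ones shipped with the library are already mapped
+        to their auto classes. Example (illustrative; `CustomFeatureExtractor` is a placeholder for user code):
+
+        ```python
+        class CustomFeatureExtractor(FeatureExtractionMixin):
+            ...
+
+        # Uses the subclass default, e.g. "AutoFeatureExtractor" for feature extractors
+        CustomFeatureExtractor.register_for_auto_class()
+        # Or pass the auto class explicitly
+        CustomFeatureExtractor.register_for_auto_class("AutoFeatureExtractor")
+        ```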
+ """ + if auto_class is None: + auto_class = cls._auto_class_default + + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class From 69357b82f4782aa7b1b1ca1d24cbfc879bf0fed6 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 2 Mar 2026 18:54:07 +0100 Subject: [PATCH 02/28] draft --- src/transformers/audio_processing_base.py | 107 +++++++++ src/transformers/audio_processing_utils.py | 224 ++++++++++++++++++ src/transformers/feature_extraction_utils.py | 4 +- src/transformers/image_processing_base.py | 45 +--- .../wav2vec2/audio_processing_wav2vec2.py | 102 ++++++++ .../whisper/audio_processing_whisper.py | 161 +++++++++++++ ...ocessing_base.py => preprocessing_base.py} | 91 ++++++- src/transformers/processing_utils.py | 3 + 8 files changed, 683 insertions(+), 54 deletions(-) create mode 100644 src/transformers/audio_processing_base.py create mode 100644 src/transformers/audio_processing_utils.py create mode 100644 src/transformers/models/wav2vec2/audio_processing_wav2vec2.py create mode 100644 src/transformers/models/whisper/audio_processing_whisper.py rename src/transformers/{processing_base.py => preprocessing_base.py} (81%) diff --git a/src/transformers/audio_processing_base.py b/src/transformers/audio_processing_base.py new file mode 100644 index 000000000000..2d4b06a68678 --- /dev/null +++ b/src/transformers/audio_processing_base.py @@ -0,0 +1,107 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any, TypeVar + +from .audio_utils import is_valid_audio, load_audio +from .feature_extraction_utils import BatchFeature as BaseBatchFeature +from .preprocessing_base import PreprocessingMixin +from .utils import ( + FEATURE_EXTRACTOR_NAME, + copy_func, + logging, +) + + +AudioProcessorType = TypeVar("AudioProcessorType", bound="AudioProcessingMixin") + + +logger = logging.get_logger(__name__) + + +class BatchFeature(BaseBatchFeature): + r""" + Holds the output of the audio processor specific `__call__` methods. + + This class is derived from a python dictionary and can be used as a dictionary. + + Args: + data (`dict`): + Dictionary of lists/arrays/tensors returned by the __call__ method ('input_values', 'input_features', etc.). + tensor_type (`Union[None, str, TensorType]`, *optional*): + You can give a tensor_type here to convert the lists of integers in PyTorch/Numpy Tensors at + initialization. + """ + + +class AudioProcessingMixin(PreprocessingMixin): + """ + This is an audio processor mixin used to provide saving/loading functionality for audio processors. 
+ """ + + _config_name = FEATURE_EXTRACTOR_NAME + _type_key = "audio_processor_type" + _nested_config_keys = ["audio_processor", "feature_extractor"] + _auto_class_default = "AutoFeatureExtractor" + _file_type_label = "audio processor" + _excluded_dict_keys = {"mel_filters", "window"} + _extra_init_pops = ["feature_extractor_type"] + _config_filename_kwarg = "audio_processor_filename" + _subfolder_default = "" + + @classmethod + def get_audio_processor_dict( + cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs + ) -> tuple[dict[str, Any], dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating an + audio processor of type [`~audio_processing_base.AudioProcessingMixin`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. + audio_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`): + The name of the file in the model directory to use for the audio processor config. + + Returns: + `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the audio processor object. + """ + return cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + + def fetch_audio(self, audio_url_or_urls: str | list[str] | list[list[str]]): + """ + Convert a single or a list of urls into the corresponding `np.ndarray` objects. + + If a single url is passed, the return value will be a single object. If a list is passed a list of objects is + returned. + """ + if isinstance(audio_url_or_urls, list): + return [self.fetch_audio(x) for x in audio_url_or_urls] + elif isinstance(audio_url_or_urls, str): + return load_audio(audio_url_or_urls) + elif is_valid_audio(audio_url_or_urls): + return audio_url_or_urls + else: + raise TypeError(f"only a single or a list of entries is supported but got type={type(audio_url_or_urls)}") + + +AudioProcessingMixin.push_to_hub = copy_func(AudioProcessingMixin.push_to_hub) +if AudioProcessingMixin.push_to_hub.__doc__ is not None: + AudioProcessingMixin.push_to_hub.__doc__ = AudioProcessingMixin.push_to_hub.__doc__.format( + object="audio processor", object_class="AutoFeatureExtractor", object_files="audio processor file" + ) diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py new file mode 100644 index 000000000000..6095097526c1 --- /dev/null +++ b/src/transformers/audio_processing_utils.py @@ -0,0 +1,224 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import lru_cache +from typing import Optional, Union, Unpack + +import numpy as np +from huggingface_hub.dataclasses import validate_typed_dict + +from .audio_processing_base import AudioProcessingMixin +from .audio_utils import AudioInput, make_list_of_audio +from .feature_extraction_utils import BatchFeature +from .image_utils import validate_kwargs +from .processing_utils import AudioKwargs +from .utils import TensorType, logging +from .utils.import_utils import is_torch_available, requires + + +if is_torch_available(): + import torch + import torch.nn.functional as F + + +logger = logging.get_logger(__name__) + + +@requires(backends=("torch",)) +class BaseAudioProcessor(AudioProcessingMixin): + model_input_names = ["audio"] + valid_kwargs = AudioKwargs + unused_kwargs = None + padding = True + padding_side = "right" + padding_value = 0.0 + + def __init__( + self, + sample_rate: int, + force_mono: bool, + **kwargs, + ): + self.sample_rate = sample_rate + self.force_mono = force_mono + + super().__init__(**kwargs) + + kwargs = self.filter_out_unused_kwargs(kwargs) + self._init_kwargs_from_valid_kwargs(kwargs) + + def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: + return self.preprocess(audio, *args, **kwargs) + + def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: + # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same + validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names) + + # Perform type validation on received kwargs + validate_typed_dict(self.valid_kwargs, kwargs) + + # Set default kwargs from self. This ensures that if a kwarg is not provided + # by the user, it gets its default value from the instance, or is set to None. + for kwarg_name in self._valid_kwargs_names: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Update kwargs that need further processing before being validated + kwargs = self._further_process_kwargs(**kwargs) + + # Validate kwargs + self._validate_preprocess_kwargs(**kwargs) + + return self._preprocess_audio_like_inputs(audio, *args, **kwargs) + + def _further_process_kwargs( + self, + **kwargs, + ) -> dict: + return kwargs + + def _validate_preprocess_kwargs( + self, + sample_rate: Optional[int] = None, + max_length: Optional[int] = None, + truncation: Optional[bool] = None, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ): + """ + Validate the kwargs for the preprocess method. + """ + validate_preprocess_arguments( + sample_rate=sample_rate, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + ) + + def _preprocess_audio_like_inputs( + self, + audio: AudioInput, + *args, + sample_rate: Optional[int] = None, + **kwargs: Unpack[AudioKwargs], + ) -> BatchFeature: + audio = self._prepare_audio_like_inputs(audio=audio, sample_rate=sample_rate) + return self._preprocess(audio, *args, **kwargs) + + def _prepare_audio_like_inputs(self, audio: AudioInput, sample_rate: Optional[int] = None) -> list["torch.Tensor"]: + if not (isinstance(audio, str) or (isinstance(audio, (list, tuple)) and all(isinstance(el, str) for el in audio))): + # NOTE: we want to force the user to either: + # 1. pass the sample rate when provided audio is array-type, to avoid silent errors that might be hard to debug + # 2. 
pass url-type audio inputs, that we can load in the correct sample rate directly + if sample_rate is not None: + if sample_rate != self.sample_rate: + raise ValueError( + f"The model corresponding to this audio processor: {self.__class__.__name__} was trained using a" + f" sample rate of {self.sample_rate}. Please make sure that the provided `audio` input" + f" was sampled with {self.sample_rate} and not {sample_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sample_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + elif isinstance(audio, str): + audio = [audio] + + audio = make_list_of_audio(audio) + + if self.force_mono: + # TODO: audio proc, to change + audio = [a.mean(axis=1) if a.ndim > 1 else a for a in audio] + + audio = [torch.from_numpy(audio_el) if isinstance(audio_el, np.ndarray) else audio_el for audio_el in audio] + + return audio + + def _preprocess( + self, + audio: list["torch.Tensor"], + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + **kwargs, + ) -> BatchFeature: + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + is_batched = len(audio) > 1 + + if truncation and max_length is None: + raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.") + + if is_batched and not truncation and max_length is not None and max(audio_el.shape[-1] for audio_el in audio) > max_length: + logger.warning( + f"Truncation is set to False but `max_length` is set to {max_length} with the longest audio being " + f"{max(audio_el.shape[-1] for audio_el in audio)}. We will set truncation to True." + ) + truncation = True + + if truncation: + audio = [audio_el[..., :max_length] for audio_el in audio] + + if max_length is None: + max_length = max(audio_el.shape[-1] for audio_el in audio) + + if padding: + audio = [self.pad(audio_el, max_length) for audio_el in audio] + + audio = torch.stack(audio, dim=0) if return_tensors else audio + return BatchFeature(data={"audio": audio}, tensor_type=return_tensors) + + def pad(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": + current_length = audio.shape[-1] + if current_length >= max_length: + return audio + + if self.padding_value is None: + raise ValueError( + "Asking to pad but the audio processor does not have a padding value. Please select a value to use" + " as `padding_value`. For example: `audio_processor.padding_value = 0.0`." + ) + + if self.padding_side == "right": + pad_args = (0, max_length - current_length) + elif self.padding_side == "left": + pad_args = (max_length - current_length, 0) + else: + raise ValueError(f"Invalid padding side: {self.padding_side}") + + return F.pad(audio, pad_args, "constant", self.padding_value) + + def to_dict(self): + return super().to_dict() + + +@lru_cache(maxsize=10) +def validate_preprocess_arguments( + sample_rate: Optional[int] = None, + max_length: Optional[int] = None, + truncation: Optional[bool] = None, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, +): + """ + Checks validity of typically used arguments in a `BaseAudioProcessor` `preprocess` method. + Raises `ValueError` if arguments incompatibility is caught. 
+ """ + pass diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py index 9b49c4cab5c3..b30a056f7794 100644 --- a/src/transformers/feature_extraction_utils.py +++ b/src/transformers/feature_extraction_utils.py @@ -21,7 +21,7 @@ import numpy as np -from .processing_base import ProcessingMixin +from .preprocessing_base import PreprocessingMixin from .utils import ( FEATURE_EXTRACTOR_NAME, TensorType, @@ -256,7 +256,7 @@ def maybe_to(v): return self -class FeatureExtractionMixin(ProcessingMixin): +class FeatureExtractionMixin(PreprocessingMixin): """ This is a feature extraction mixin used to provide saving/loading functionality for sequential and audio feature extractors. diff --git a/src/transformers/image_processing_base.py b/src/transformers/image_processing_base.py index 215f4d8baa2c..79d2f7bf2aec 100644 --- a/src/transformers/image_processing_base.py +++ b/src/transformers/image_processing_base.py @@ -17,7 +17,7 @@ from .feature_extraction_utils import BatchFeature as BaseBatchFeature from .image_utils import is_valid_image, load_image -from .processing_base import ProcessingMixin +from .preprocessing_base import PreprocessingMixin from .utils import ( IMAGE_PROCESSOR_NAME, copy_func, @@ -49,7 +49,7 @@ class BatchFeature(BaseBatchFeature): # TODO: (Amy) - factor out the common parts of this and the feature extractor -class ImageProcessingMixin(ProcessingMixin): +class ImageProcessingMixin(PreprocessingMixin): """ This is an image processor mixin used to provide saving/loading functionality for sequential and image feature extractors. @@ -87,47 +87,6 @@ def get_image_processor_dict( """ return cls._get_config_dict(pretrained_model_name_or_path, **kwargs) - @classmethod - def from_dict(cls, image_processor_dict: dict[str, Any], **kwargs): - """ - Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters. - - Args: - image_processor_dict (`dict[str, Any]`): - Dictionary that will be used to instantiate the image processor object. Such a dictionary can be - retrieved from a pretrained checkpoint by leveraging the - [`~image_processing_utils.ImageProcessingMixin.to_dict`] method. - kwargs (`dict[str, Any]`): - Additional parameters from which to initialize the image processor object. - - Returns: - [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those - parameters. - """ - image_processor_dict = image_processor_dict.copy() - return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) - image_processor_dict.update({k: v for k, v in kwargs.items() if k in cls.valid_kwargs.__annotations__}) - image_processor = cls(**image_processor_dict) - - # Apply extra kwargs to instance (BC for remote code, e.g. phi4_multimodal) - extra_keys = [] - for key in reversed(list(kwargs.keys())): - if hasattr(image_processor, key) and key not in cls.valid_kwargs.__annotations__: - setattr(image_processor, key, kwargs.pop(key, None)) - extra_keys.append(key) - if extra_keys: - logger.warning_once( - f"Image processor {cls.__name__}: kwargs {extra_keys} were applied for backward compatibility. " - f"To avoid this warning, add them to valid_kwargs: create a custom TypedDict extending " - f"ImagesKwargs with these keys and set it as the `valid_kwargs` class attribute." 
- ) - - logger.info(f"Image processor {image_processor}") - if return_unused_kwargs: - return image_processor, kwargs - else: - return image_processor - def fetch_images(self, image_url_or_urls: str | list[str] | list[list[str]]): """ Convert a single or a list of urls into the corresponding `PIL.Image` objects. diff --git a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py new file mode 100644 index 000000000000..bcbcdd3f3f63 --- /dev/null +++ b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py @@ -0,0 +1,102 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Audio processor class for Wav2Vec2 +""" + +from typing import Optional, Union + +import torch + +from ...audio_processing_utils import BaseAudioProcessor +from ...audio_utils import AudioInput +from ...feature_extraction_utils import BatchFeature +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class Wav2Vec2AudioProcessor(BaseAudioProcessor): + r""" + Constructs a Wav2Vec2 audio processor. + + This audio processor inherits from [`~audio_processing_utils.BaseAudioProcessor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + Args: + sample_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + do_normalize (`bool`, *optional*, defaults to `True`): + Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly + improve the performance for some models, *e.g.*, + [wav2vec2-lv60](https://huggingface.co/models?search=lv60). 
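        Example (a usage sketch based on the code in this patch; the random waveform below is illustrative):

        ```python
        import numpy as np

        audio_processor = Wav2Vec2AudioProcessor(sample_rate=16000, do_normalize=True)
        raw_audio = np.random.randn(16000).astype(np.float32)  # 1 second of mono audio at 16 kHz
        inputs = audio_processor(raw_audio, sample_rate=16000, return_tensors="pt")
        # inputs["input_values"] has shape (1, 16000) and is zero-mean / unit-variance normalized
        ```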
+ """ + + model_input_names = ["input_values", "attention_mask"] + + def __init__( + self, + sample_rate: int = 16000, + do_normalize: bool = True, + force_mono: bool = True, + **kwargs, + ): + super().__init__( + sample_rate=sample_rate, + force_mono=force_mono, + **kwargs, + ) + self.do_normalize = do_normalize + + def _preprocess( + self, + audio: list[torch.Tensor], + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + do_normalize: Optional[bool] = None, + **kwargs, + ) -> BatchFeature: + if do_normalize is None: + do_normalize = self.do_normalize + + # Truncation and padding via base class + result = super()._preprocess( + audio, + padding=padding, + max_length=max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=None, # we handle conversion after normalization + ) + + audio_tensors = result["audio"] + + if do_normalize: + audio_tensors = [self._zero_mean_unit_var_norm(t) for t in audio_tensors] + + input_values = torch.stack(audio_tensors, dim=0) if return_tensors else audio_tensors + return BatchFeature(data={"input_values": input_values}, tensor_type=return_tensors) + + @staticmethod + def _zero_mean_unit_var_norm(tensor: torch.Tensor) -> torch.Tensor: + """Zero-mean unit-variance normalize a tensor.""" + return (tensor - tensor.mean()) / torch.sqrt(tensor.var() + 1e-7) + + +__all__ = ["Wav2Vec2AudioProcessor"] diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py new file mode 100644 index 000000000000..1edfbc3d14b2 --- /dev/null +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -0,0 +1,161 @@ +# coding=utf-8 +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Audio processor class for Whisper +""" + +from typing import Optional, Union + +import numpy as np +import torch + +from ...audio_processing_utils import BaseAudioProcessor +from ...audio_utils import AudioInput, mel_filter_bank +from ...feature_extraction_utils import BatchFeature +from ...utils import TensorType, logging + + +logger = logging.get_logger(__name__) + + +class WhisperAudioProcessor(BaseAudioProcessor): + r""" + Constructs a Whisper audio processor. + + This audio processor inherits from [`~audio_processing_utils.BaseAudioProcessor`] which contains + most of the main methods. Users should refer to this superclass for more information regarding those methods. + + This class extracts mel-filter bank features from raw speech using PyTorch's `torch.stft`. + + Args: + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted features (number of mel bins). + sample_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). + hop_length (`int`, *optional*, defaults to 160): + Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. 
+ chunk_length (`int`, *optional*, defaults to 30): + The maximum number of seconds of audio used to trim and pad sequences. + n_fft (`int`, *optional*, defaults to 400): + Size of the Fourier transform. + dither (`float`, *optional*, defaults to 0.0): + Adds dithering (small Gaussian noise) to each frame. Use 0.0 for no dithering. + """ + + model_input_names = ["input_features"] + + def __init__( + self, + feature_size: int = 80, + sample_rate: int = 16000, + hop_length: int = 160, + chunk_length: int = 30, + n_fft: int = 400, + dither: float = 0.0, + force_mono: bool = True, + **kwargs, + ): + super().__init__( + sample_rate=sample_rate, + force_mono=force_mono, + **kwargs, + ) + self.feature_size = feature_size + self.n_fft = n_fft + self.hop_length = hop_length + self.chunk_length = chunk_length + self.n_samples = chunk_length * sample_rate + self.nb_max_frames = self.n_samples // hop_length + self.dither = dither + self.mel_filters = mel_filter_bank( + num_frequency_bins=1 + n_fft // 2, + num_mel_filters=feature_size, + min_frequency=0.0, + max_frequency=8000.0, + sampling_rate=sample_rate, + norm="slaney", + mel_scale="slaney", + ) + + def _preprocess( + self, + audio: list[torch.Tensor], + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + do_normalize: Optional[bool] = None, + device: Optional[str] = "cpu", + **kwargs, + ) -> BatchFeature: + # Default max_length to n_samples (chunk_length * sample_rate) + if max_length is None: + max_length = self.n_samples + + # Use base class for truncation + padding + result = super()._preprocess( + audio, + padding=padding if padding is not None else True, + max_length=max_length, + truncation=truncation if truncation is not None else True, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=None, # we handle conversion after feature extraction + ) + + audio_tensors = result["audio"] + + # Zero-mean unit-variance normalization (before spectrogram) + if do_normalize: + audio_tensors = [(t - t.mean()) / torch.sqrt(t.var() + 1e-7) for t in audio_tensors] + + # Stack into batch for spectrogram extraction + waveform_batch = torch.stack(audio_tensors, dim=0).to(device, torch.float32) + + # Extract log-mel spectrogram + input_features = self._extract_fbank_features(waveform_batch, device) + + return BatchFeature(data={"input_features": input_features}, tensor_type=return_tensors) + + def _extract_fbank_features(self, waveform: torch.Tensor, device: str = "cpu") -> torch.Tensor: + """ + Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation. 
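        Shape note (derived from the defaults above, not part of the original docstring): with `n_fft=400`,
        `hop_length=160` and the default `center=True`, a padded batch of shape `(batch_size, 480000)`
        (30 s at 16 kHz) produces `480000 // 160 + 1 = 3001` STFT frames; the last frame is dropped, so the
        returned log-mel features have shape `(batch_size, feature_size, 3000)`.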
+ """ + window = torch.hann_window(self.n_fft, device=device) + + if self.dither != 0.0: + waveform = waveform + self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device) + + stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + mel_spec = mel_filters.T @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + if waveform.dim() == 2: + max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] + log_spec = torch.maximum(log_spec, max_val - 8.0) + else: + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + + if device != "cpu": + log_spec = log_spec.detach().cpu() + + return log_spec + + +__all__ = ["WhisperAudioProcessor"] diff --git a/src/transformers/processing_base.py b/src/transformers/preprocessing_base.py similarity index 81% rename from src/transformers/processing_base.py rename to src/transformers/preprocessing_base.py index ff9a7158ff56..1282d1d35ac1 100644 --- a/src/transformers/processing_base.py +++ b/src/transformers/preprocessing_base.py @@ -19,6 +19,7 @@ import copy import json import os +from copy import deepcopy from typing import Any, TypeVar import numpy as np @@ -36,13 +37,13 @@ logger = logging.get_logger(__name__) -ProcessingMixinType = TypeVar("ProcessingMixinType", bound="ProcessingMixin") +PreprocessingMixinType = TypeVar("PreprocessingMixinType", bound="PreprocessingMixin") -class ProcessingMixin(PushToHubMixin): +class PreprocessingMixin(PushToHubMixin): """ Base mixin providing saving/loading functionality shared by - ImageProcessingMixin and FeatureExtractionMixin. + ImageProcessingMixin, AudioProcessingMixin and FeatureExtractionMixin. Subclasses must set the following class attributes: _config_name: str — config file name (e.g. IMAGE_PROCESSOR_NAME) @@ -86,9 +87,80 @@ def __init__(self, **kwargs): logger.error(f"Can't set {key} with value {value} for {self}") raise err + def _init_kwargs_from_valid_kwargs(self, kwargs: dict): + """ + Initialize instance attributes from `valid_kwargs` annotations. + + For each key in `self.valid_kwargs.__annotations__`, pops it from `kwargs` + and sets it on the instance (or deep-copies the class default). + Also sets `self._valid_kwargs_names`. + """ + for key in self.valid_kwargs.__annotations__: + kwarg = kwargs.pop(key, None) + if kwarg is not None: + setattr(self, key, kwarg) + else: + setattr(self, key, deepcopy(getattr(self, key, None))) + self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys()) + + def filter_out_unused_kwargs(self, kwargs: dict) -> dict: + """ + Filter out the unused kwargs from the kwargs dictionary. + """ + if self.unused_kwargs is None: + return kwargs + + for kwarg_name in self.unused_kwargs: + if kwarg_name in kwargs: + logger.warning_once(f"This processor does not use the `{kwarg_name}` parameter. It will be ignored.") + kwargs.pop(kwarg_name) + return kwargs + + @classmethod + def from_dict(cls, config_dict: dict[str, Any], **kwargs): + """ + Instantiates a processor from a Python dictionary of parameters. + + Args: + config_dict (`dict[str, Any]`): + Dictionary that will be used to instantiate the processor object. + kwargs (`dict[str, Any]`): + Additional parameters from which to initialize the processor object. + + Returns: + A processor of type [`~PreprocessingMixin`]. 
+ """ + config_dict = config_dict.copy() + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + # Use valid_kwargs pattern when available (image/audio processors) + if hasattr(cls, "valid_kwargs") and hasattr(cls.valid_kwargs, "__annotations__"): + config_dict.update({k: v for k, v in kwargs.items() if k in cls.valid_kwargs.__annotations__}) + processor = cls(**config_dict) + + # Apply extra kwargs to instance (BC for remote code) + extra_keys = [] + for key in reversed(list(kwargs.keys())): + if hasattr(processor, key) and key not in cls.valid_kwargs.__annotations__: + setattr(processor, key, kwargs.pop(key, None)) + extra_keys.append(key) + if extra_keys: + logger.warning_once( + f"Processor {cls.__name__}: kwargs {extra_keys} were applied for backward compatibility. " + f"To avoid this warning, add them to valid_kwargs." + ) + else: + processor = cls(**config_dict) + + logger.info(f"Processor {processor}") + if return_unused_kwargs: + return processor, kwargs + else: + return processor + @classmethod def from_pretrained( - cls: type[ProcessingMixinType], + cls: type[PreprocessingMixinType], pretrained_model_name_or_path: str | os.PathLike, cache_dir: str | os.PathLike | None = None, force_download: bool = False, @@ -96,7 +168,7 @@ def from_pretrained( token: str | bool | None = None, revision: str = "main", **kwargs, - ) -> ProcessingMixinType: + ) -> PreprocessingMixinType: r""" Instantiate a processor from a pretrained model name or path. @@ -107,7 +179,7 @@ def from_pretrained( - a string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co. - a path to a *directory* containing a processor file saved using the - [`~ProcessingMixin.save_pretrained`] method, e.g., `./my_model_directory/`. + [`~PreprocessingMixin.save_pretrained`] method, e.g., `./my_model_directory/`. - a path or url to a saved processor JSON *file*, e.g., `./my_model_directory/preprocessor_config.json`. cache_dir (`str` or `os.PathLike`, *optional*): @@ -129,7 +201,7 @@ def from_pretrained( loaded values. Returns: - A processor of type [`~ProcessingMixin`]. + A processor of type [`~PreprocessingMixin`]. """ kwargs["cache_dir"] = cache_dir kwargs["force_download"] = force_download @@ -146,7 +218,7 @@ def from_pretrained( def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs): """ Save a processor object to the directory `save_directory`, so that it can be re-loaded using the - [`~ProcessingMixin.from_pretrained`] class method. + [`~PreprocessingMixin.from_pretrained`] class method. Args: save_directory (`str` or `os.PathLike`): @@ -319,6 +391,7 @@ def to_dict(self) -> dict[str, Any]: """ output = copy.deepcopy(self.__dict__) output[self._type_key] = self.__class__.__name__ + output.pop("_valid_kwargs_names", None) for key in self._excluded_dict_keys: if key in output: del output[key] @@ -334,7 +407,7 @@ def from_json_file(cls, json_file: str | os.PathLike): Path to the JSON file containing the parameters. Returns: - A processor of type [`~ProcessingMixin`]: The processor object instantiated from that JSON file. + A processor of type [`~PreprocessingMixin`]: The processor object instantiated from that JSON file. 
""" with open(json_file, encoding="utf-8") as reader: text = reader.read() diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index fb1bd18c6239..cbbdd63d8110 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -408,6 +408,7 @@ class AudioKwargs(TypedDict, total=False): - `'np'`: Return NumPy `np.ndarray` objects. """ + sample_rate: Annotated[int | None, positive_int()] sampling_rate: Annotated[int | None, positive_int()] raw_speech: Union["np.ndarray", list[float], list["np.ndarray"], list[list[float]]] | None padding: Annotated[bool | str | PaddingStrategy | None, padding_validator()] @@ -416,6 +417,8 @@ class AudioKwargs(TypedDict, total=False): pad_to_multiple_of: Annotated[int | None, positive_int()] return_attention_mask: bool | None return_tensors: Annotated[str | TensorType | None, tensor_type_validator()] + do_normalize: bool | None + device: str | None class ProcessingKwargs(TypedDict, total=False): From 9b018cb31375802e200f66ec36bd7d3ef1ec6810 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 4 Mar 2026 14:00:50 +0100 Subject: [PATCH 03/28] draft update --- src/transformers/__init__.py | 3 + src/transformers/audio_processing_backends.py | 194 ++++++++++ src/transformers/audio_processing_utils.py | 145 +++---- src/transformers/image_processing_utils.py | 12 +- .../wav2vec2/audio_processing_wav2vec2.py | 11 +- .../whisper/audio_processing_whisper.py | 15 +- src/transformers/preprocessing_base.py | 4 + tests/test_audio_processing_common.py | 360 ++++++++++++++++++ 8 files changed, 628 insertions(+), 116 deletions(-) create mode 100644 src/transformers/audio_processing_backends.py create mode 100644 tests/test_audio_processing_common.py diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c66b077cac36..9c86fd7f83c8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -355,6 +355,7 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] else: + _import_structure["audio_processing_backends"] = ["NumpyBackend", "TorchBackend"] _import_structure["model_debugging_utils"] = [ "model_addition_debugger_context", ] @@ -477,6 +478,8 @@ if TYPE_CHECKING: # All modeling imports # Models + from .audio_processing_backends import NumpyBackend as NumpyBackend + from .audio_processing_backends import TorchBackend as TorchBackend from .backbone_utils import BackboneConfigMixin, BackboneMixin from .cache_utils import Cache as Cache from .cache_utils import DynamicCache as DynamicCache diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py new file mode 100644 index 000000000000..a6276af11644 --- /dev/null +++ b/src/transformers/audio_processing_backends.py @@ -0,0 +1,194 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import numpy as np + +from .audio_processing_utils import BaseAudioProcessor +from .feature_extraction_utils import BatchFeature +from .utils import logging +from .utils.import_utils import requires + + +logger = logging.get_logger(__name__) + + +class NumpyBackend(BaseAudioProcessor): + """NumPy backend for portable CPU-only audio processing.""" + + @property + def backend(self) -> str: + return "numpy" + + def process_audio(self, audio_el): + """ + Process a single raw audio input into a np.ndarray. + + Handles mono conversion (averaging channels) and ensures numpy format. + """ + if not isinstance(audio_el, np.ndarray): + audio_el = np.asarray(audio_el) + + if self.force_mono: + audio_el = audio_el.mean(axis=1) if audio_el.ndim > 1 else audio_el + + return audio_el + + def pad(self, audio: np.ndarray, max_length: int) -> np.ndarray: + """Pad a single audio array to a target length using np.pad.""" + current_length = audio.shape[-1] + if current_length >= max_length: + return audio + + if self.padding_value is None: + raise ValueError( + "Asking to pad but the audio processor does not have a padding value. Please select a value to use" + " as `padding_value`. For example: `audio_processor.padding_value = 0.0`." + ) + + pad_length = max_length - current_length + if self.padding_side == "right": + pad_width = [(0, 0)] * (audio.ndim - 1) + [(0, pad_length)] + elif self.padding_side == "left": + pad_width = [(0, 0)] * (audio.ndim - 1) + [(pad_length, 0)] + else: + raise ValueError(f"Invalid padding side: {self.padding_side}") + + return np.pad(audio, pad_width, mode="constant", constant_values=self.padding_value) + + def _preprocess( + self, + audio: list[np.ndarray], + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + **kwargs, + ) -> BatchFeature: + """Preprocess using NumPy backend: truncation, padding, stacking.""" + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + is_batched = len(audio) > 1 + + if truncation and max_length is None: + raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.") + + if is_batched and not truncation and max_length is not None and max(audio_el.shape[-1] for audio_el in audio) > max_length: + logger.warning( + f"Truncation is set to False but `max_length` is set to {max_length} with the longest audio being " + f"{max(audio_el.shape[-1] for audio_el in audio)}. We will set truncation to True." + ) + truncation = True + + if truncation: + audio = [audio_el[..., :max_length] for audio_el in audio] + + if max_length is None: + max_length = max(audio_el.shape[-1] for audio_el in audio) + + if padding: + audio = [self.pad(audio_el, max_length) for audio_el in audio] + + audio = np.stack(audio, axis=0) if return_tensors else audio + return BatchFeature(data={"audio": audio}, tensor_type=return_tensors) + + +@requires(backends=("torch",)) +class TorchBackend(BaseAudioProcessor): + """Torch backend for audio processing.""" + + @property + def backend(self) -> str: + return "torch" + + def process_audio(self, audio_el): + """ + Process a single raw audio input into a torch.Tensor. + + Handles mono conversion (averaging channels) and numpy-to-torch conversion. 
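        For example (illustrative): a stereo NumPy input of shape `(n_samples, 2)` is averaged to a mono
        `torch.Tensor` of shape `(n_samples,)` when `force_mono=True`.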
+ """ + import torch + + if self.force_mono: + audio_el = audio_el.mean(axis=1) if audio_el.ndim > 1 else audio_el + + if isinstance(audio_el, np.ndarray): + audio_el = torch.from_numpy(audio_el) + + return audio_el + + def pad(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": + """Pad a single audio tensor to a target length using torch.nn.functional.pad.""" + import torch.nn.functional as F + + current_length = audio.shape[-1] + if current_length >= max_length: + return audio + + if self.padding_value is None: + raise ValueError( + "Asking to pad but the audio processor does not have a padding value. Please select a value to use" + " as `padding_value`. For example: `audio_processor.padding_value = 0.0`." + ) + + if self.padding_side == "right": + pad_args = (0, max_length - current_length) + elif self.padding_side == "left": + pad_args = (max_length - current_length, 0) + else: + raise ValueError(f"Invalid padding side: {self.padding_side}") + + return F.pad(audio, pad_args, "constant", self.padding_value) + + def _preprocess( + self, + audio: list["torch.Tensor"], + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + **kwargs, + ) -> BatchFeature: + """Preprocess using Torch backend: truncation, padding, stacking.""" + import torch + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + is_batched = len(audio) > 1 + + if truncation and max_length is None: + raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.") + + if is_batched and not truncation and max_length is not None and max(audio_el.shape[-1] for audio_el in audio) > max_length: + logger.warning( + f"Truncation is set to False but `max_length` is set to {max_length} with the longest audio being " + f"{max(audio_el.shape[-1] for audio_el in audio)}. We will set truncation to True." + ) + truncation = True + + if truncation: + audio = [audio_el[..., :max_length] for audio_el in audio] + + if max_length is None: + max_length = max(audio_el.shape[-1] for audio_el in audio) + + if padding: + audio = [self.pad(audio_el, max_length) for audio_el in audio] + + audio = torch.stack(audio, dim=0) if return_tensors else audio + return BatchFeature(data={"audio": audio}, tensor_type=return_tensors) diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 6095097526c1..4f1cbc1d6939 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,29 +13,20 @@ # limitations under the License. 
from functools import lru_cache -from typing import Optional, Union, Unpack +from typing import Unpack -import numpy as np from huggingface_hub.dataclasses import validate_typed_dict from .audio_processing_base import AudioProcessingMixin from .audio_utils import AudioInput, make_list_of_audio from .feature_extraction_utils import BatchFeature -from .image_utils import validate_kwargs from .processing_utils import AudioKwargs from .utils import TensorType, logging -from .utils.import_utils import is_torch_available, requires - - -if is_torch_available(): - import torch - import torch.nn.functional as F logger = logging.get_logger(__name__) -@requires(backends=("torch",)) class BaseAudioProcessor(AudioProcessingMixin): model_input_names = ["audio"] valid_kwargs = AudioKwargs @@ -56,15 +46,39 @@ def __init__( super().__init__(**kwargs) - kwargs = self.filter_out_unused_kwargs(kwargs) - self._init_kwargs_from_valid_kwargs(kwargs) - def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: return self.preprocess(audio, *args, **kwargs) + def process_audio(self, *args, **kwargs): + """ + Process a single raw audio input into the backend's working format. + + Implemented by backend subclasses (e.g., `TorchBackend`). Converts a raw input + (NumPy array) to the backend's internal format (e.g., `torch.Tensor`), handles + mono conversion if needed. + """ + raise NotImplementedError + + def _preprocess(self, *args, **kwargs): + """ + Perform the actual batch audio preprocessing (truncation, padding, stacking). + + Implemented by backend subclasses (e.g., `TorchBackend`). Receives a list of + already-prepared audio tensors and applies the configured preprocessing operations. + Returns a `BatchFeature` with the processed audio values. + """ + raise NotImplementedError + + def pad(self, *args, **kwargs): + """ + Pad a single audio tensor to a target length. + + Implemented by backend subclasses (e.g., `TorchBackend`). 
+ """ + raise NotImplementedError + def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same - validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names) # Perform type validation on received kwargs validate_typed_dict(self.valid_kwargs, kwargs) @@ -90,11 +104,11 @@ def _further_process_kwargs( def _validate_preprocess_kwargs( self, - sample_rate: Optional[int] = None, - max_length: Optional[int] = None, - truncation: Optional[bool] = None, - pad_to_multiple_of: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + sample_rate: int | None = None, + max_length: int | None = None, + truncation: bool | None = None, + pad_to_multiple_of: int | None = None, + return_tensors: str | TensorType | None = None, **kwargs, ): """ @@ -112,13 +126,19 @@ def _preprocess_audio_like_inputs( self, audio: AudioInput, *args, - sample_rate: Optional[int] = None, + sample_rate: int | None = None, **kwargs: Unpack[AudioKwargs], ) -> BatchFeature: audio = self._prepare_audio_like_inputs(audio=audio, sample_rate=sample_rate) return self._preprocess(audio, *args, **kwargs) - def _prepare_audio_like_inputs(self, audio: AudioInput, sample_rate: Optional[int] = None) -> list["torch.Tensor"]: + def _prepare_audio_structure(self, audio: AudioInput, sample_rate: int | None = None) -> list: + """ + Prepare the audio structure for processing: handle URL inputs, validate sample rate, + and flatten into a list of audio arrays. + + Analogous to `_prepare_images_structure` in the image processing pipeline. + """ if not (isinstance(audio, str) or (isinstance(audio, (list, tuple)) and all(isinstance(el, str) for el in audio))): # NOTE: we want to force the user to either: # 1. pass the sample rate when provided audio is array-type, to avoid silent errors that might be hard to debug @@ -139,71 +159,18 @@ def _prepare_audio_like_inputs(self, audio: AudioInput, sample_rate: Optional[in audio = [audio] audio = make_list_of_audio(audio) - - if self.force_mono: - # TODO: audio proc, to change - audio = [a.mean(axis=1) if a.ndim > 1 else a for a in audio] - - audio = [torch.from_numpy(audio_el) if isinstance(audio_el, np.ndarray) else audio_el for audio_el in audio] - return audio - def _preprocess( - self, - audio: list["torch.Tensor"], - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - **kwargs, - ) -> BatchFeature: - if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - is_batched = len(audio) > 1 - - if truncation and max_length is None: - raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.") - - if is_batched and not truncation and max_length is not None and max(audio_el.shape[-1] for audio_el in audio) > max_length: - logger.warning( - f"Truncation is set to False but `max_length` is set to {max_length} with the longest audio being " - f"{max(audio_el.shape[-1] for audio_el in audio)}. We will set truncation to True." 
- ) - truncation = True - - if truncation: - audio = [audio_el[..., :max_length] for audio_el in audio] - - if max_length is None: - max_length = max(audio_el.shape[-1] for audio_el in audio) - - if padding: - audio = [self.pad(audio_el, max_length) for audio_el in audio] - - audio = torch.stack(audio, dim=0) if return_tensors else audio - return BatchFeature(data={"audio": audio}, tensor_type=return_tensors) - - def pad(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": - current_length = audio.shape[-1] - if current_length >= max_length: - return audio - - if self.padding_value is None: - raise ValueError( - "Asking to pad but the audio processor does not have a padding value. Please select a value to use" - " as `padding_value`. For example: `audio_processor.padding_value = 0.0`." - ) - - if self.padding_side == "right": - pad_args = (0, max_length - current_length) - elif self.padding_side == "left": - pad_args = (max_length - current_length, 0) - else: - raise ValueError(f"Invalid padding side: {self.padding_side}") + def _prepare_audio_like_inputs(self, audio: AudioInput, *args, sample_rate: int | None = None, **kwargs) -> list: + """ + Prepare audio-like inputs for processing by structuring and then converting each + audio item via `process_audio`. - return F.pad(audio, pad_args, "constant", self.padding_value) + Analogous to `_prepare_image_like_inputs` in the image processing pipeline. + """ + audio = self._prepare_audio_structure(audio, sample_rate=sample_rate) + audio = [self.process_audio(audio_el) for audio_el in audio] + return audio def to_dict(self): return super().to_dict() @@ -211,11 +178,11 @@ def to_dict(self): @lru_cache(maxsize=10) def validate_preprocess_arguments( - sample_rate: Optional[int] = None, - max_length: Optional[int] = None, - truncation: Optional[bool] = None, - pad_to_multiple_of: Optional[int] = None, - return_tensors: Optional[Union[str, TensorType]] = None, + sample_rate: int | None = None, + max_length: int | None = None, + truncation: bool | None = None, + pad_to_multiple_of: int | None = None, + return_tensors: str | TensorType | None = None, ): """ Checks validity of typically used arguments in a `BaseAudioProcessor` `preprocess` method. 
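Note on the resulting layering (editorial sketch, not part of the patch): `preprocess` above is now backend-agnostic and delegates to the three hooks `process_audio`, `pad` and `_preprocess`, which the backend classes introduced in this commit implement. A minimal model-level processor built on the torch backend could look like the following; `MyAudioProcessor` and the random inputs are illustrative only:

```python
import numpy as np

from transformers.audio_processing_backends import TorchBackend


class MyAudioProcessor(TorchBackend):
    # process_audio / pad / _preprocess come from TorchBackend,
    # the backend-agnostic preprocess() template from BaseAudioProcessor
    model_input_names = ["audio"]


processor = MyAudioProcessor(sample_rate=16000, force_mono=True)
batch = processor(
    [np.random.randn(8000).astype(np.float32), np.random.randn(16000).astype(np.float32)],
    sample_rate=16000,
    return_tensors="pt",
)
# batch["audio"] is right-padded to the longest input: shape (2, 16000)
```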
diff --git a/src/transformers/image_processing_utils.py b/src/transformers/image_processing_utils.py index 9756866b6333..9fb1d9761ee1 100644 --- a/src/transformers/image_processing_utils.py +++ b/src/transformers/image_processing_utils.py @@ -14,7 +14,6 @@ import math from collections.abc import Iterable -from copy import deepcopy from functools import partial from typing import Any @@ -193,20 +192,11 @@ class MyImageProcessor(TorchvisionBackend): def __init__(self, **kwargs: Unpack[ImagesKwargs]): super().__init__(**kwargs) - attributes = {} - for key in self.valid_kwargs.__annotations__: - kwarg = kwargs.pop(key, None) - if kwarg is not None: - attributes[key] = kwarg - else: - attributes[key] = deepcopy(getattr(self, key, None)) + attributes = {key: getattr(self, key) for key in self._valid_kwargs_names} attributes = self._standardize_kwargs(**attributes) for key, value in attributes.items(): setattr(self, key, value) - # get valid kwargs names - self._valid_kwargs_names = list(self.valid_kwargs.__annotations__.keys()) - def __call__(self, images: ImageInput, *args, **kwargs: Unpack[ImagesKwargs]) -> BatchFeature: """Preprocess an image or a batch of images.""" return self.preprocess(images, *args, **kwargs) diff --git a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py index bcbcdd3f3f63..3d7fa2817bad 100644 --- a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,20 +15,18 @@ Audio processor class for Wav2Vec2 """ -from typing import Optional, Union import torch -from ...audio_processing_utils import BaseAudioProcessor -from ...audio_utils import AudioInput +from ...audio_processing_backends import TorchBackend from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, logging +from ...utils import logging logger = logging.get_logger(__name__) -class Wav2Vec2AudioProcessor(BaseAudioProcessor): +class Wav2Vec2AudioProcessor(TorchBackend): r""" Constructs a Wav2Vec2 audio processor. @@ -69,7 +66,7 @@ def _preprocess( truncation, pad_to_multiple_of, return_tensors, - do_normalize: Optional[bool] = None, + do_normalize: bool | None = None, **kwargs, ) -> BatchFeature: if do_normalize is None: diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index 1edfbc3d14b2..09ed776c0039 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2025 The HuggingFace Inc. team. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,21 +15,19 @@ Audio processor class for Whisper """ -from typing import Optional, Union -import numpy as np import torch -from ...audio_processing_utils import BaseAudioProcessor -from ...audio_utils import AudioInput, mel_filter_bank +from ...audio_processing_backends import TorchBackend +from ...audio_utils import mel_filter_bank from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, logging +from ...utils import logging logger = logging.get_logger(__name__) -class WhisperAudioProcessor(BaseAudioProcessor): +class WhisperAudioProcessor(TorchBackend): r""" Constructs a Whisper audio processor. @@ -97,8 +94,8 @@ def _preprocess( truncation, pad_to_multiple_of, return_tensors, - do_normalize: Optional[bool] = None, - device: Optional[str] = "cpu", + do_normalize: bool | None = None, + device: str | None = "cpu", **kwargs, ) -> BatchFeature: # Default max_length to n_samples (chunk_length * sample_rate) diff --git a/src/transformers/preprocessing_base.py b/src/transformers/preprocessing_base.py index 1282d1d35ac1..d994f4811e32 100644 --- a/src/transformers/preprocessing_base.py +++ b/src/transformers/preprocessing_base.py @@ -79,6 +79,10 @@ def __init__(self, **kwargs): kwargs.pop(key, None) # Pop "processor_class", should not be saved in config kwargs.pop("processor_class", None) + + if hasattr(self, "valid_kwargs") and hasattr(self.valid_kwargs, "__annotations__"): + self._init_kwargs_from_valid_kwargs(kwargs) + # Additional attributes without default values for key, value in kwargs.items(): try: diff --git a/tests/test_audio_processing_common.py b/tests/test_audio_processing_common.py new file mode 100644 index 000000000000..9bc72955f5c6 --- /dev/null +++ b/tests/test_audio_processing_common.py @@ -0,0 +1,360 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile + +import numpy as np + +from transformers.testing_utils import ( + check_json_file_has_correct_format, + require_torch, +) +from transformers.utils import is_torch_available + + +if is_torch_available(): + import torch + + +def prepare_audio_inputs( + batch_size, + min_length=400, + max_length=2000, + num_channels=1, + equal_length=False, + numpify=False, + torchify=False, +): + """This function prepares a list of numpy arrays, or a list of PyTorch tensors if one specifies torchify=True. + + One can specify whether the audio inputs are of the same length or not. 
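    Example (illustrative):

    ```python
    audio_inputs = prepare_audio_inputs(batch_size=3, min_length=400, max_length=2000)
    # -> list of 3 mono float32 NumPy arrays whose lengths are drawn from [400, 2000)
    ```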
+ """ + + assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" + + audio_inputs = [] + for _ in range(batch_size): + if equal_length: + length = max_length + else: + length = np.random.randint(min_length, max_length) + + if num_channels > 1: + audio_inputs.append(np.random.randn(length, num_channels).astype(np.float32)) + else: + audio_inputs.append(np.random.randn(length).astype(np.float32)) + + if torchify: + audio_inputs = [torch.from_numpy(audio) for audio in audio_inputs] + + return audio_inputs + + +class AudioProcessingTestMixin: + """Mixin class for testing audio processors, analogous to ``ImageProcessingTestMixin``. + + Subclasses must set the following in ``setUp``: + + * ``self.audio_processing_classes`` – ``dict[str, type]`` mapping backend name to class + * ``self.audio_processor_dict`` – kwargs to instantiate the processor + * ``self.audio_processor_tester`` – object with ``prepare_audio_inputs()`` and ``batch_size`` + """ + + # ─── Serialization ──────────────────────────────────────────────── + + def test_audio_processor_to_json_string(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processor = audio_processing_class(**self.audio_processor_dict) + obj = json.loads(audio_processor.to_json_string()) + for key, value in self.audio_processor_dict.items(): + self.assertEqual(obj[key], value) + + def test_audio_processor_to_json_file(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processor_first = audio_processing_class(**self.audio_processor_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + json_file_path = os.path.join(tmpdirname, "audio_processor.json") + audio_processor_first.to_json_file(json_file_path) + audio_processor_second = audio_processing_class.from_json_file(json_file_path) + + self.assertEqual(audio_processor_second.to_dict(), audio_processor_first.to_dict()) + + def test_audio_processor_from_and_save_pretrained(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processor_first = audio_processing_class(**self.audio_processor_dict) + + with tempfile.TemporaryDirectory() as tmpdirname: + saved_file = audio_processor_first.save_pretrained(tmpdirname)[0] + check_json_file_has_correct_format(saved_file) + audio_processor_second = audio_processing_class.from_pretrained(tmpdirname) + + self.assertEqual(audio_processor_second.to_dict(), audio_processor_first.to_dict()) + + def test_init_without_params(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processor = audio_processing_class() + self.assertIsNotNone(audio_processor) + + # ─── Backend equivalence ────────────────────────────────────────── + + @require_torch + def test_backends_equivalence(self): + if len(self.audio_processing_classes) < 2: + self.skipTest(reason="Skipping backends equivalence test as there are less than 2 backends") + + audio_input = np.random.randn(16000).astype(np.float32) + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + + encodings = {} + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processor = audio_processing_class(**self.audio_processor_dict) + encodings[backend_name] = audio_processor(audio_input, sample_rate=sample_rate, return_tensors="pt") + + backend_names = list(encodings.keys()) + reference_backend = backend_names[0] + reference_key = 
list(encodings[reference_backend].keys())[0] + reference_values = encodings[reference_backend][reference_key] + for backend_name in backend_names[1:]: + torch.testing.assert_close(reference_values, encodings[backend_name][reference_key], atol=1e-5, rtol=1e-5) + + @require_torch + def test_backends_equivalence_batched(self): + if len(self.audio_processing_classes) < 2: + self.skipTest(reason="Skipping backends equivalence test as there are less than 2 backends") + + audio_inputs = self.audio_processor_tester.prepare_audio_inputs(equal_length=False) + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + + encodings = {} + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processor = audio_processing_class(**self.audio_processor_dict) + encodings[backend_name] = audio_processor(audio_inputs, sample_rate=sample_rate, return_tensors="pt") + + backend_names = list(encodings.keys()) + reference_backend = backend_names[0] + reference_key = list(encodings[reference_backend].keys())[0] + reference_values = encodings[reference_backend][reference_key] + for backend_name in backend_names[1:]: + torch.testing.assert_close(reference_values, encodings[backend_name][reference_key], atol=1e-5, rtol=1e-5) + + # ─── Cross-backend save / load ──────────────────────────────────── + + def test_save_load_backends(self): + """Test that we can load audio processors saved by one backend with another.""" + if len(self.audio_processing_classes) < 2: + self.skipTest("Skipping backend save/load test as there are less than 2 backends") + + backend_names = list(self.audio_processing_classes.keys()) + + for backend1 in backend_names: + processor1 = self.audio_processing_classes[backend1](**self.audio_processor_dict) + + for backend2 in backend_names: + if backend1 == backend2: + continue + + with tempfile.TemporaryDirectory() as tmpdirname: + processor1.save_pretrained(tmpdirname) + processor2 = self.audio_processing_classes[backend2].from_pretrained(tmpdirname) + + dict1 = processor1.to_dict() + dict2 = processor2.to_dict() + common_keys = set(dict1) & set(dict2) + self.assertEqual( + {k: dict1[k] for k in common_keys}, + {k: dict2[k] for k in common_keys}, + f"Backends {backend1} and {backend2} differ in common keys", + ) + + # ─── Input type tests ───────────────────────────────────────────── + + @require_torch + def test_call_numpy(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processing = audio_processing_class(**self.audio_processor_dict) + audio_inputs = self.audio_processor_tester.prepare_audio_inputs(equal_length=False) + for audio in audio_inputs: + self.assertIsInstance(audio, np.ndarray) + + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + + # Test not batched input + encoded = audio_processing(audio_inputs[0], sample_rate=sample_rate, return_tensors="pt") + output_key = list(encoded.keys())[0] + self.assertEqual(len(encoded[output_key].shape), 2) # (1, length) + + # Test batched + encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") + self.assertEqual(encoded[output_key].shape[0], self.audio_processor_tester.batch_size) + + @require_torch + def test_call_pytorch(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processing = audio_processing_class(**self.audio_processor_dict) + audio_inputs = self.audio_processor_tester.prepare_audio_inputs(equal_length=False, torchify=True) + + for audio in 
audio_inputs: + self.assertIsInstance(audio, torch.Tensor) + + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + + # Test not batched input + encoded = audio_processing(audio_inputs[0], sample_rate=sample_rate, return_tensors="pt") + output_key = list(encoded.keys())[0] + self.assertEqual(len(encoded[output_key].shape), 2) + + # Test batched + encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") + self.assertEqual(encoded[output_key].shape[0], self.audio_processor_tester.batch_size) + + @require_torch + def test_call_multichannel_force_mono(self): + """Test that multi-channel audio is correctly averaged to mono.""" + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + processor_dict = {**self.audio_processor_dict, "force_mono": True} + audio_processing = audio_processing_class(**processor_dict) + + audio_inputs = prepare_audio_inputs( + batch_size=self.audio_processor_tester.batch_size, + num_channels=2, + min_length=self.audio_processor_tester.min_length, + max_length=self.audio_processor_tester.max_length, + equal_length=True, + ) + + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") + output_key = list(encoded.keys())[0] + # After force_mono, output should be 2D: (batch, length) + self.assertEqual(len(encoded[output_key].shape), 2) + + # ─── Padding tests ──────────────────────────────────────────────── + + @require_torch + def test_padding_right(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + processor_dict = {**self.audio_processor_dict, "padding_side": "right"} + audio_processing = audio_processing_class(**processor_dict) + + audio_inputs = [ + np.random.randn(100).astype(np.float32), + np.random.randn(200).astype(np.float32), + ] + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") + output_key = list(encoded.keys())[0] + self.assertEqual(encoded[output_key].shape[-1], 200) + + @require_torch + def test_padding_left(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + processor_dict = {**self.audio_processor_dict, "padding_side": "left"} + audio_processing = audio_processing_class(**processor_dict) + + audio_inputs = [ + np.random.randn(100).astype(np.float32), + np.random.randn(200).astype(np.float32), + ] + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") + output_key = list(encoded.keys())[0] + self.assertEqual(encoded[output_key].shape[-1], 200) + + # ─── Truncation tests ───────────────────────────────────────────── + + @require_torch + def test_truncation(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processing = audio_processing_class(**self.audio_processor_dict) + + audio_inputs = [ + np.random.randn(500).astype(np.float32), + np.random.randn(1000).astype(np.float32), + ] + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + encoded = audio_processing( + audio_inputs, sample_rate=sample_rate, truncation=True, max_length=300, return_tensors="pt" + ) + output_key = list(encoded.keys())[0] + self.assertEqual(encoded[output_key].shape[-1], 300) + + @require_torch + def test_truncation_without_max_length_raises(self): + for 
backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processing = audio_processing_class(**self.audio_processor_dict) + + audio_inputs = [np.random.randn(500).astype(np.float32)] + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + with self.assertRaises(ValueError): + audio_processing( + audio_inputs, sample_rate=sample_rate, truncation=True, max_length=None, return_tensors="pt" + ) + + # ─── pad_to_multiple_of ─────────────────────────────────────────── + + @require_torch + def test_pad_to_multiple_of(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processing = audio_processing_class(**self.audio_processor_dict) + + audio_inputs = [np.random.randn(100).astype(np.float32)] + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + encoded = audio_processing( + audio_inputs, + sample_rate=sample_rate, + truncation=True, + max_length=150, + pad_to_multiple_of=64, + return_tensors="pt", + ) + output_key = list(encoded.keys())[0] + # max_length=150 rounded up to next multiple of 64 → 192 + self.assertEqual(encoded[output_key].shape[-1] % 64, 0) + + # ─── Sample rate validation ─────────────────────────────────────── + + def test_wrong_sample_rate_raises(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processing = audio_processing_class(**self.audio_processor_dict) + + audio_inputs = [np.random.randn(100).astype(np.float32)] + expected_sr = self.audio_processor_dict.get("sample_rate", 16000) + with self.assertRaises(ValueError): + audio_processing(audio_inputs, sample_rate=expected_sr + 1000, return_tensors="pt") + + # ─── Dtype casting ──────────────────────────────────────────────── + + @require_torch + def test_cast_dtype(self): + for backend_name, audio_processing_class in self.audio_processing_classes.items(): + audio_processing = audio_processing_class(**self.audio_processor_dict) + + audio_inputs = self.audio_processor_tester.prepare_audio_inputs(equal_length=True) + sample_rate = self.audio_processor_dict.get("sample_rate", 16000) + + encoding = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") + output_key = list(encoding.keys())[0] + self.assertEqual(encoding[output_key].dtype, torch.float32) + + encoding = encoding.to(torch.float16) + self.assertEqual(encoding[output_key].dtype, torch.float16) + + encoding = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt").to( + "cpu", torch.bfloat16 + ) + self.assertEqual(encoding[output_key].device, torch.device("cpu")) + self.assertEqual(encoding[output_key].dtype, torch.bfloat16) From 4d1af7d94e4df73c472eba4d8994d7ebfb6116a6 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 4 Mar 2026 19:16:17 +0100 Subject: [PATCH 04/28] keep drafting --- src/transformers/__init__.py | 4 +- src/transformers/audio_processing_backends.py | 419 +++++++++++++++--- src/transformers/audio_processing_base.py | 8 +- src/transformers/audio_processing_utils.py | 229 +++++++--- src/transformers/audio_utils.py | 136 ++++++ .../wav2vec2/audio_processing_wav2vec2.py | 85 +--- .../whisper/audio_processing_whisper.py | 155 +------ 7 files changed, 709 insertions(+), 327 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9c86fd7f83c8..82b20dc39684 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -355,7 +355,7 @@ _import_structure["utils.dummy_pt_objects"] = [name for name in 
dir(dummy_pt_objects) if not name.startswith("_")] else: - _import_structure["audio_processing_backends"] = ["NumpyBackend", "TorchBackend"] + _import_structure["audio_processing_backends"] = ["NumpyAudioBackend", "NumpyBackend", "TorchAudioBackend", "TorchBackend"] _import_structure["model_debugging_utils"] = [ "model_addition_debugger_context", ] @@ -478,7 +478,9 @@ if TYPE_CHECKING: # All modeling imports # Models + from .audio_processing_backends import NumpyAudioBackend as NumpyAudioBackend from .audio_processing_backends import NumpyBackend as NumpyBackend + from .audio_processing_backends import TorchAudioBackend as TorchAudioBackend from .audio_processing_backends import TorchBackend as TorchBackend from .backbone_utils import BackboneConfigMixin, BackboneMixin from .cache_utils import Cache as Cache diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index a6276af11644..6cd61107a4fc 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -16,15 +16,20 @@ import numpy as np from .audio_processing_utils import BaseAudioProcessor +from .audio_utils import SpectrogramConfig, NormalizationConfig from .feature_extraction_utils import BatchFeature -from .utils import logging +from .utils import logging, is_torch_available from .utils.import_utils import requires logger = logging.get_logger(__name__) -class NumpyBackend(BaseAudioProcessor): +if is_torch_available(): + import torch + + +class NumpyAudioBackend(BaseAudioProcessor): """NumPy backend for portable CPU-only audio processing.""" @property @@ -67,47 +72,183 @@ def pad(self, audio: np.ndarray, max_length: int) -> np.ndarray: return np.pad(audio, pad_width, mode="constant", constant_values=self.padding_value) - def _preprocess( + def pad_values( self, audio: list[np.ndarray], - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - **kwargs, - ) -> BatchFeature: - """Preprocess using NumPy backend: truncation, padding, stacking.""" + *, + max_length: int | None = None, + truncation: bool = False, + pad_to_multiple_of: int | None = None, + ) -> list[np.ndarray]: + """Truncate and/or pad raw audio values (stage 3).""" if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - is_batched = len(audio) > 1 + if truncation: + if max_length is None: + raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.") + audio = [a[..., :max_length] for a in audio] - if truncation and max_length is None: - raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.") + if max_length is None: + max_length = max(a.shape[-1] for a in audio) - if is_batched and not truncation and max_length is not None and max(audio_el.shape[-1] for audio_el in audio) > max_length: - logger.warning( - f"Truncation is set to False but `max_length` is set to {max_length} with the longest audio being " - f"{max(audio_el.shape[-1] for audio_el in audio)}. We will set truncation to True." 
- ) - truncation = True + if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - if truncation: - audio = [audio_el[..., :max_length] for audio_el in audio] + audio = [self.pad(a, max_length) for a in audio] + return audio + def values_normalize( + self, + audio: list[np.ndarray], + *, + normalization_config: NormalizationConfig, + ) -> list[np.ndarray]: + """Normalize raw audio values (stage 4). Supports zero-mean-unit-var.""" + if normalization_config.method == "zero_mean_unit_var": + return [ + (a - np.mean(a)) / (np.std(a) + 1e-7) + for a in audio + ] + raise ValueError(f"Unknown normalization method: {normalization_config.method}") + + def extract_spectrogram( + self, + audio: list[np.ndarray], + *, + spectrogram_config: SpectrogramConfig, + ) -> list[np.ndarray]: + """Extract audio features (stage 5). Override in model-specific subclasses.""" + raise NotImplementedError( + f"{self.__class__.__name__} does not implement `extract_spectrogram`. " + "Override this method in your model-specific audio processor." + ) + + def feature_normalize( + self, + features: list[np.ndarray], + *, + feature_normalization_config: NormalizationConfig, + ) -> list[np.ndarray]: + """Normalize extracted features (stage 6). Supports zero-mean-unit-var.""" + if feature_normalization_config.method == "zero_mean_unit_var": + return [ + (f - np.mean(f)) / (np.std(f) + 1e-7) + for f in features + ] + raise ValueError(f"Unknown normalization method: {feature_normalization_config.method}") + + def pad_features( + self, + features: list[np.ndarray], + *, + max_length: int | None = None, + pad_to_multiple_of: int | None = None, + ) -> list[np.ndarray]: + """Pad 2D features to a target length (stage 7).""" if max_length is None: - max_length = max(audio_el.shape[-1] for audio_el in audio) + max_length = max(f.shape[-1] for f in features) - if padding: - audio = [self.pad(audio_el, max_length) for audio_el in audio] - - audio = np.stack(audio, axis=0) if return_tensors else audio - return BatchFeature(data={"audio": audio}, tensor_type=return_tensors) + if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + padded = [] + for f in features: + current_length = f.shape[-1] + if current_length >= max_length: + padded.append(f[..., :max_length]) + else: + pad_length = max_length - current_length + if f.ndim == 2: + pad_width = [(0, 0), (0, pad_length)] + else: + pad_width = [(0, 0)] * (f.ndim - 1) + [(0, pad_length)] + padded.append(np.pad(f, pad_width, mode="constant", constant_values=0.0)) + return padded -@requires(backends=("torch",)) -class TorchBackend(BaseAudioProcessor): + def _preprocess( + self, + audio: list[np.ndarray], + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + do_pad_values=None, + do_values_normalize=None, + normalization_config=None, + spectrogram_config=None, + do_feature_normalize=None, + feature_normalization_config=None, + do_pad_features=None, + **kwargs, + ) -> BatchFeature: + """Preprocess using NumPy backend: 5-stage pipeline (stages 3-7).""" + # Default do_values_normalize to True if a normalization config is provided + if do_values_normalize is None: + do_values_normalize = normalization_config is not None + + # Determine normalize_before_pad for values + values_normalize_before_pad = ( + normalization_config.normalize_before_pad if normalization_config is not 
None else True + ) + feature_normalize_before_pad = ( + feature_normalization_config.normalize_before_pad if feature_normalization_config is not None else True + ) + + # --- Stages 3 & 4: Values padding and normalization --- + if values_normalize_before_pad: + # Stage 4 before 3: normalize then pad + if do_values_normalize and normalization_config is not None: + audio = self.values_normalize(audio, normalization_config=normalization_config) + if do_pad_values or padding: + audio = self.pad_values( + audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of + ) + else: + # Stage 3 before 4: pad then normalize + if do_pad_values or padding: + audio = self.pad_values( + audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of + ) + if do_values_normalize and normalization_config is not None: + audio = self.values_normalize(audio, normalization_config=normalization_config) + + # --- Stage 5: Feature extraction --- + if spectrogram_config is not None: + features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) + else: + features = audio + + # --- Stages 6 & 7: Feature normalization and padding --- + if feature_normalize_before_pad: + # Stage 6 before 7: normalize then pad + if do_feature_normalize and feature_normalization_config is not None: + features = self.feature_normalize( + features, feature_normalization_config=feature_normalization_config + ) + if do_pad_features: + features = self.pad_features( + features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of + ) + else: + # Stage 7 before 6: pad then normalize + if do_pad_features: + features = self.pad_features( + features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of + ) + if do_feature_normalize and feature_normalization_config is not None: + features = self.feature_normalize( + features, feature_normalization_config=feature_normalization_config + ) + + # Stack into batch + output_key = self.model_input_names[0] + stacked = np.stack(features, axis=0) if return_tensors else features + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +class TorchAudioBackend(BaseAudioProcessor): """Torch backend for audio processing.""" @property @@ -153,42 +294,208 @@ def pad(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": return F.pad(audio, pad_args, "constant", self.padding_value) - def _preprocess( + def pad_values( self, audio: list["torch.Tensor"], - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - **kwargs, - ) -> BatchFeature: - """Preprocess using Torch backend: truncation, padding, stacking.""" - import torch - + *, + max_length: int | None = None, + truncation: bool = False, + pad_to_multiple_of: int | None = None, + ) -> list["torch.Tensor"]: + """Truncate and/or pad raw audio values (stage 3).""" if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - is_batched = len(audio) > 1 + if truncation: + if max_length is None: + raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.") + audio = [a[..., :max_length] for a in audio] + + if max_length is None: + max_length = max(a.shape[-1] for a in audio) + + if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + audio = [self.pad(a, max_length) 
for a in audio] + return audio - if truncation and max_length is None: - raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.") + def values_normalize( + self, + audio: list["torch.Tensor"], + *, + normalization_config: NormalizationConfig, + ) -> list["torch.Tensor"]: + """Normalize raw audio values (stage 4). Supports zero-mean-unit-var.""" + import torch - if is_batched and not truncation and max_length is not None and max(audio_el.shape[-1] for audio_el in audio) > max_length: - logger.warning( - f"Truncation is set to False but `max_length` is set to {max_length} with the longest audio being " - f"{max(audio_el.shape[-1] for audio_el in audio)}. We will set truncation to True." + if normalization_config.method == "zero_mean_unit_var": + return [ + (a - torch.mean(a)) / (torch.std(a) + 1e-7) + for a in audio + ] + raise ValueError(f"Unknown normalization method: {normalization_config.method}") + + def extract_spectrogram( + self, + audio: list["torch.Tensor"], + *, + spectrogram_config: SpectrogramConfig, + ) -> list["torch.Tensor"]: + """Extract log-mel spectrogram features using the provided config and mel_filters.""" + import torch + + if not hasattr(self, "mel_filters"): + raise NotImplementedError( + f"{self.__class__.__name__} does not have `mel_filters`. " + "Either set `mel_filters` or override `extract_spectrogram`." ) - truncation = True - if truncation: - audio = [audio_el[..., :max_length] for audio_el in audio] + stft_cfg = spectrogram_config.stft_config + n_fft = stft_cfg.n_fft + hop_length = stft_cfg.hop_length + + waveform = torch.stack(audio, dim=0) + device = waveform.device + window = torch.hann_window(n_fft, device=device) + + stft = torch.stft(waveform, n_fft, hop_length, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** stft_cfg.power + + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + mel_spec = mel_filters.T @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + if waveform.dim() == 2: + max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] + log_spec = torch.maximum(log_spec, max_val - 8.0) + else: + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + + return [log_spec[i] for i in range(log_spec.shape[0])] + + def feature_normalize( + self, + features: list["torch.Tensor"], + *, + feature_normalization_config: NormalizationConfig, + ) -> list["torch.Tensor"]: + """Normalize extracted features (stage 6). 
Supports zero-mean-unit-var.""" + import torch + + if feature_normalization_config.method == "zero_mean_unit_var": + return [ + (f - torch.mean(f)) / (torch.std(f) + 1e-7) + for f in features + ] + raise ValueError(f"Unknown normalization method: {feature_normalization_config.method}") + + def pad_features( + self, + features: list["torch.Tensor"], + *, + max_length: int | None = None, + pad_to_multiple_of: int | None = None, + ) -> list["torch.Tensor"]: + """Pad 2D features to a target length (stage 7).""" + import torch.nn.functional as F if max_length is None: - max_length = max(audio_el.shape[-1] for audio_el in audio) + max_length = max(f.shape[-1] for f in features) - if padding: - audio = [self.pad(audio_el, max_length) for audio_el in audio] + if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + padded = [] + for f in features: + current_length = f.shape[-1] + if current_length >= max_length: + padded.append(f[..., :max_length]) + else: + pad_length = max_length - current_length + padded.append(F.pad(f, (0, pad_length), "constant", 0.0)) + return padded + + def _preprocess( + self, + audio: list["torch.Tensor"], + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + do_pad_values=None, + do_values_normalize=None, + normalization_config=None, + spectrogram_config=None, + do_feature_normalize=None, + feature_normalization_config=None, + do_pad_features=None, + **kwargs, + ) -> BatchFeature: + """Preprocess using Torch backend: 5-stage pipeline (stages 3-7).""" + import torch - audio = torch.stack(audio, dim=0) if return_tensors else audio - return BatchFeature(data={"audio": audio}, tensor_type=return_tensors) + # Default do_values_normalize to True if a normalization config is provided + if do_values_normalize is None: + do_values_normalize = normalization_config is not None + + # Determine normalize_before_pad for values + values_normalize_before_pad = ( + normalization_config.normalize_before_pad if normalization_config is not None else True + ) + feature_normalize_before_pad = ( + feature_normalization_config.normalize_before_pad if feature_normalization_config is not None else True + ) + + # --- Stages 3 & 4: Values padding and normalization --- + if values_normalize_before_pad: + # Stage 4 before 3: normalize then pad + if do_values_normalize and normalization_config is not None: + audio = self.values_normalize(audio, normalization_config=normalization_config) + if do_pad_values or padding: + audio = self.pad_values( + audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of + ) + else: + # Stage 3 before 4: pad then normalize + if do_pad_values or padding: + audio = self.pad_values( + audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of + ) + if do_values_normalize and normalization_config is not None: + audio = self.values_normalize(audio, normalization_config=normalization_config) + + # --- Stage 5: Feature extraction --- + if spectrogram_config is not None: + features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) + else: + features = audio + + # --- Stages 6 & 7: Feature normalization and padding --- + if feature_normalize_before_pad: + # Stage 6 before 7: normalize then pad + if do_feature_normalize and feature_normalization_config is not None: + features = self.feature_normalize( + features, 
feature_normalization_config=feature_normalization_config + ) + if do_pad_features: + features = self.pad_features( + features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of + ) + else: + # Stage 7 before 6: pad then normalize + if do_pad_features: + features = self.pad_features( + features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of + ) + if do_feature_normalize and feature_normalization_config is not None: + features = self.feature_normalize( + features, feature_normalization_config=feature_normalization_config + ) + + # Stack into batch + output_key = self.model_input_names[0] + stacked = torch.stack(features, dim=0) if return_tensors else features + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) \ No newline at end of file diff --git a/src/transformers/audio_processing_base.py b/src/transformers/audio_processing_base.py index 2d4b06a68678..c8fd46a98bef 100644 --- a/src/transformers/audio_processing_base.py +++ b/src/transformers/audio_processing_base.py @@ -83,17 +83,19 @@ def get_audio_processor_dict( """ return cls._get_config_dict(pretrained_model_name_or_path, **kwargs) - def fetch_audio(self, audio_url_or_urls: str | list[str] | list[list[str]]): + def fetch_audio(self, audio_url_or_urls: str | list[str] | list[list[str]], sampling_rate: int | None = None): """ Convert a single or a list of urls into the corresponding `np.ndarray` objects. If a single url is passed, the return value will be a single object. If a list is passed a list of objects is returned. """ + if sampling_rate is None: + sampling_rate = getattr(self, "sample_rate", 16000) if isinstance(audio_url_or_urls, list): - return [self.fetch_audio(x) for x in audio_url_or_urls] + return [self.fetch_audio(x, sampling_rate=sampling_rate) for x in audio_url_or_urls] elif isinstance(audio_url_or_urls, str): - return load_audio(audio_url_or_urls) + return load_audio(audio_url_or_urls, sampling_rate=sampling_rate) elif is_valid_audio(audio_url_or_urls): return audio_url_or_urls else: diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 4f1cbc1d6939..662fbcd45e07 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
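Both backends above implement the same stage ordering; only the array library differs. As a rough, illustrative recomposition of that control flow (the stage numbers and the `normalize_before_pad` switch come from the `_preprocess` implementations above; the helper bodies below are simplified stand-ins, not the library code):

```python
# Illustrative sketch of the stage ordering in NumpyAudioBackend/TorchAudioBackend._preprocess.
import numpy as np


def _pad(batch, max_length):
    # stages 3 / 7: right-pad every element to max_length
    return [np.pad(x, (0, max(0, max_length - x.shape[-1]))) for x in batch]


def _normalize(batch):
    # stages 4 / 6: zero-mean unit-variance, matching the "zero_mean_unit_var" method
    return [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in batch]


def run_pipeline(audio, max_length, normalize_before_pad=True, extract=None):
    if normalize_before_pad:   # stage 4, then stage 3
        audio = _pad(_normalize(audio), max_length)
    else:                      # stage 3, then stage 4
        audio = _normalize(_pad(audio, max_length))
    if extract is not None:    # stage 5, e.g. a log-mel spectrogram; stages 6/7 would run on its output
        audio = extract(audio)
    return np.stack(audio, axis=0)


batch = run_pipeline([np.random.randn(100), np.random.randn(80)], max_length=128)
print(batch.shape)  # (2, 128)
```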
-from functools import lru_cache from typing import Unpack from huggingface_hub.dataclasses import validate_typed_dict from .audio_processing_base import AudioProcessingMixin -from .audio_utils import AudioInput, make_list_of_audio +from .audio_utils import AudioInput, SpectrogramConfig, NormalizationConfig, make_list_of_audio, mel_filter_bank from .feature_extraction_utils import BatchFeature from .processing_utils import AudioKwargs from .utils import TensorType, logging @@ -27,33 +26,94 @@ logger = logging.get_logger(__name__) +class AudioProcessingKwargs(AudioKwargs, total=False): + """Extended keyword arguments for the audio processing pipeline.""" + + do_pad_values: bool | None + do_values_normalize: bool | None + normalization_config: dict | NormalizationConfig | None + spectrogram_config: dict | SpectrogramConfig | None + do_feature_normalize: bool | None + feature_normalization_config: dict | NormalizationConfig | None + do_pad_features: bool | None + do_resample: bool | None + + class BaseAudioProcessor(AudioProcessingMixin): model_input_names = ["audio"] - valid_kwargs = AudioKwargs + valid_kwargs = AudioProcessingKwargs unused_kwargs = None padding = True padding_side = "right" padding_value = 0.0 + max_length = None + truncation = None + + sample_rate: int = None + force_mono: bool = None + + # Pipeline stage defaults + do_pad_values = None + do_values_normalize = None + normalization_config = None + spectrogram_config = None + do_feature_normalize = None + feature_normalization_config = None + do_pad_features = None + do_resample = False def __init__( self, - sample_rate: int, - force_mono: bool, + sample_rate: int | None = None, + force_mono: bool | None = None, **kwargs, ): - self.sample_rate = sample_rate - self.force_mono = force_mono + if sample_rate is not None: + self.sample_rate = sample_rate + if self.sample_rate is None: + raise ValueError( + f"`sample_rate` must be set either as a class attribute on {self.__class__.__name__} " + "or passed to __init__." + ) + + if force_mono is not None: + self.force_mono = force_mono + if self.force_mono is None: + raise ValueError( + f"`force_mono` must be set either as a class attribute on {self.__class__.__name__} " + "or passed to __init__." + ) super().__init__(**kwargs) - def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: + # Standardize init attributes (coerce dicts to config dataclasses) + attributes = {key: getattr(self, key) for key in self._valid_kwargs_names} + attributes = self._standardize_kwargs(**attributes) + for key, value in attributes.items(): + setattr(self, key, value) + + # Derive max_length and mel_filters from spectrogram_config + if self.spectrogram_config is not None: + sc = self.spectrogram_config + if not hasattr(self, "mel_filters"): + self.mel_filters = mel_filter_bank( + num_frequency_bins=1 + sc.stft_config.n_fft // 2, + num_mel_filters=sc.mel_scale_config.n_mels, + min_frequency=sc.mel_scale_config.f_min, + max_frequency=sc.mel_scale_config.f_max if sc.mel_scale_config.f_max is not None else self.sample_rate / 2, + sampling_rate=self.sample_rate, + norm=sc.mel_scale_config.norm, + mel_scale=sc.mel_scale_config.mel_scale, + ) + + def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingKwargs]) -> BatchFeature: return self.preprocess(audio, *args, **kwargs) def process_audio(self, *args, **kwargs): """ Process a single raw audio input into the backend's working format. - Implemented by backend subclasses (e.g., `TorchBackend`). 
Converts a raw input + Implemented by backend subclasses (e.g., `TorchAudioBackend`). Converts a raw input (NumPy array) to the backend's internal format (e.g., `torch.Tensor`), handles mono conversion if needed. """ @@ -61,9 +121,9 @@ def process_audio(self, *args, **kwargs): def _preprocess(self, *args, **kwargs): """ - Perform the actual batch audio preprocessing (truncation, padding, stacking). + Perform the actual batch audio preprocessing pipeline (stages 3-7). - Implemented by backend subclasses (e.g., `TorchBackend`). Receives a list of + Implemented by backend subclasses (e.g., `TorchAudioBackend`). Receives a list of already-prepared audio tensors and applies the configured preprocessing operations. Returns a `BatchFeature` with the processed audio values. """ @@ -73,33 +133,84 @@ def pad(self, *args, **kwargs): """ Pad a single audio tensor to a target length. - Implemented by backend subclasses (e.g., `TorchBackend`). + Implemented by backend subclasses (e.g., `TorchAudioBackend`). + """ + raise NotImplementedError + + def pad_values(self, *args, **kwargs): + """ + Pad raw audio values to a target length (pipeline stage 3). + + Implemented by backend subclasses. + """ + raise NotImplementedError + + def values_normalize(self, *args, **kwargs): + """ + Normalize raw audio values (pipeline stage 4). + + Implemented by backend subclasses. """ raise NotImplementedError - def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: - # args are not validated, but their order in the `preprocess` and `_preprocess` signatures must be the same + def extract_spectrogram(self, *args, **kwargs): + """ + Extract spectrogram from audio (pipeline stage 5). + + Implemented by model-specific processor subclasses. + """ + raise NotImplementedError + def feature_normalize(self, *args, **kwargs): + """ + Normalize extracted features (pipeline stage 6). + + Implemented by backend subclasses. + """ + raise NotImplementedError + + def pad_features(self, *args, **kwargs): + """ + Pad extracted features to a target length (pipeline stage 7). + + Implemented by backend subclasses. + """ + raise NotImplementedError + + def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingKwargs]) -> BatchFeature: + """ + Preprocess an audio or a batch of audio. + """ # Perform type validation on received kwargs validate_typed_dict(self.valid_kwargs, kwargs) - # Set default kwargs from self. This ensures that if a kwarg is not provided - # by the user, it gets its default value from the instance, or is set to None. + # Set default kwargs from self. 
for kwarg_name in self._valid_kwargs_names: kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) - # Update kwargs that need further processing before being validated - kwargs = self._further_process_kwargs(**kwargs) + # Standardize kwargs (coerce dicts to config dataclasses) + kwargs = self._standardize_kwargs(**kwargs) # Validate kwargs self._validate_preprocess_kwargs(**kwargs) return self._preprocess_audio_like_inputs(audio, *args, **kwargs) - def _further_process_kwargs( + def _standardize_kwargs( self, **kwargs, ) -> dict: + """Coerce dict configs to their dataclass form.""" + if isinstance(kwargs.get("normalization_config"), dict): + kwargs["normalization_config"] = NormalizationConfig.from_dict(kwargs["normalization_config"]) + if isinstance(kwargs.get("spectrogram_config"), dict): + kwargs["spectrogram_config"] = SpectrogramConfig.from_dict( + kwargs["spectrogram_config"] + ) + if isinstance(kwargs.get("feature_normalization_config"), dict): + kwargs["feature_normalization_config"] = NormalizationConfig.from_dict( + kwargs["feature_normalization_config"] + ) return kwargs def _validate_preprocess_kwargs( @@ -109,40 +220,52 @@ def _validate_preprocess_kwargs( truncation: bool | None = None, pad_to_multiple_of: int | None = None, return_tensors: str | TensorType | None = None, + do_values_normalize: bool | None = None, + normalization_config: NormalizationConfig | None = None, + do_feature_normalize: bool | None = None, + feature_normalization_config: NormalizationConfig | None = None, **kwargs, ): - """ - Validate the kwargs for the preprocess method. - """ - validate_preprocess_arguments( - sample_rate=sample_rate, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors=return_tensors, - ) + """Validate the kwargs for the preprocess method.""" + if do_values_normalize and normalization_config is None: + raise ValueError( + "`do_values_normalize=True` requires `normalization_config` to be set." + ) + if do_feature_normalize and feature_normalization_config is None: + raise ValueError( + "`do_feature_normalize=True` requires `feature_normalization_config` to be set." + ) + if truncation and max_length is None: + raise ValueError( + "When setting `truncation=True`, make sure that `max_length` is defined." + ) def _preprocess_audio_like_inputs( self, audio: AudioInput, *args, sample_rate: int | None = None, - **kwargs: Unpack[AudioKwargs], + **kwargs: Unpack[AudioProcessingKwargs], ) -> BatchFeature: audio = self._prepare_audio_like_inputs(audio=audio, sample_rate=sample_rate) return self._preprocess(audio, *args, **kwargs) def _prepare_audio_structure(self, audio: AudioInput, sample_rate: int | None = None) -> list: """ - Prepare the audio structure for processing: handle URL inputs, validate sample rate, + Prepare the audio structure for processing: fetch URL inputs, validate sample rate, and flatten into a list of audio arrays. Analogous to `_prepare_images_structure` in the image processing pipeline. """ - if not (isinstance(audio, str) or (isinstance(audio, (list, tuple)) and all(isinstance(el, str) for el in audio))): - # NOTE: we want to force the user to either: - # 1. pass the sample rate when provided audio is array-type, to avoid silent errors that might be hard to debug - # 2. 
pass url-type audio inputs, that we can load in the correct sample rate directly + is_url_input = isinstance(audio, str) or ( + isinstance(audio, (list, tuple)) and all(isinstance(el, str) for el in audio) + ) + + if is_url_input: + # URL inputs: load directly at the correct sample rate + audio = self.fetch_audio(audio) + else: + # Array inputs: validate that the user-provided sample rate matches the model's if sample_rate is not None: if sample_rate != self.sample_rate: raise ValueError( @@ -155,8 +278,6 @@ def _prepare_audio_structure(self, audio: AudioInput, sample_rate: int | None = f"It is strongly recommended to pass the `sample_rate` argument to `{self.__class__.__name__}()`. " "Failing to do so can result in silent errors that might be hard to debug." ) - elif isinstance(audio, str): - audio = [audio] audio = make_list_of_audio(audio) return audio @@ -173,19 +294,21 @@ def _prepare_audio_like_inputs(self, audio: AudioInput, *args, sample_rate: int return audio def to_dict(self): - return super().to_dict() - - -@lru_cache(maxsize=10) -def validate_preprocess_arguments( - sample_rate: int | None = None, - max_length: int | None = None, - truncation: bool | None = None, - pad_to_multiple_of: int | None = None, - return_tensors: str | TensorType | None = None, -): - """ - Checks validity of typically used arguments in a `BaseAudioProcessor` `preprocess` method. - Raises `ValueError` if arguments incompatibility is caught. - """ - pass + output = super().to_dict() + # Serialize config dataclasses to plain dicts for JSON persistence + for key in ("normalization_config", "spectrogram_config", "feature_normalization_config"): + if key in output and hasattr(output[key], "to_dict"): + output[key] = output[key].to_dict() + + # Filter out None values that are class defaults + filtered_dict = {} + for key, value in output.items(): + if value is None: + class_default = getattr(type(self), key, "NOT_FOUND") + # Keep None if user explicitly set it (class default is non-None) + if class_default != "NOT_FOUND" and class_default is not None: + filtered_dict[key] = value + else: + filtered_dict[key] = value + + return filtered_dict diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 85b56634afe7..679835b443ec 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -22,6 +22,7 @@ import os import warnings from collections.abc import Sequence +from dataclasses import dataclass, field, fields from io import BytesIO from typing import TYPE_CHECKING, Any, Union @@ -57,6 +58,141 @@ AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]] +@dataclass(frozen=True) +class StftConfig: + """Configuration for Short-Time Fourier Transform. + + Uses torchaudio parameter naming conventions. See + `torchaudio.transforms.MelSpectrogram` for reference. 
+ """ + + n_fft: int = 400 + win_length: int | None = None + hop_length: int | None = None + window_fn: str = "hann_window" + wkwargs: dict | None = None + power: float = 2.0 + center: bool = True + pad_mode: str = "reflect" + normalized: bool = False + onesided: bool | None = None + pad: int = 0 + + def to_dict(self) -> dict: + return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} + + @classmethod + def from_dict(cls, d: dict) -> "StftConfig": + valid_keys = {f.name for f in fields(cls)} + return cls(**{k: v for k, v in d.items() if k in valid_keys}) + + +@dataclass(frozen=True) +class MelScaleConfig: + """Configuration for mel filterbank. + + Uses torchaudio parameter naming conventions. See + `torchaudio.transforms.MelSpectrogram` for reference. + """ + + n_mels: int = 128 + f_min: float = 0.0 + f_max: float | None = None + mel_scale: str = "htk" + norm: str | None = None + + def to_dict(self) -> dict: + return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} + + @classmethod + def from_dict(cls, d: dict) -> "MelScaleConfig": + valid_keys = {f.name for f in fields(cls)} + return cls(**{k: v for k, v in d.items() if k in valid_keys}) + + +@dataclass(frozen=True) +class SpectrogramConfig: + """Configuration for spectrogram extraction, composed of STFT and mel scale sub-configs.""" + + stft_config: StftConfig = field(default_factory=StftConfig) + mel_scale_config: MelScaleConfig = field(default_factory=MelScaleConfig) + log_mode: str = "log10" + chunk_length: int | None = None + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + raise KeyError(f"Key {key} not found in SpectrogramConfig.") + + def __iter__(self): + for f in fields(self): + val = getattr(self, f.name) + if val is not None: + if hasattr(val, "to_dict"): + yield f.name, val.to_dict() + else: + yield f.name, val + + def __eq__(self, other): + if isinstance(other, dict): + return dict(self) == other + if isinstance(other, SpectrogramConfig): + return tuple(getattr(self, f.name) for f in fields(self)) == tuple( + getattr(other, f.name) for f in fields(self) + ) + return NotImplemented + + def to_dict(self) -> dict: + return dict(self) + + @classmethod + def from_dict(cls, d: dict) -> "SpectrogramConfig": + stft_config = StftConfig.from_dict(d["stft_config"]) if "stft_config" in d else StftConfig() + mel_scale_config = MelScaleConfig.from_dict(d["mel_scale_config"]) if "mel_scale_config" in d else MelScaleConfig() + return cls( + stft_config=stft_config, + mel_scale_config=mel_scale_config, + log_mode=d.get("log_mode", "log10"), + chunk_length=d.get("chunk_length"), + ) + + +@dataclass(frozen=True) +class NormalizationConfig: + """Hashable dictionary to store audio normalization configuration.""" + + method: str = "zero_mean_unit_var" + normalize_before_pad: bool = True + + def __getitem__(self, key): + if hasattr(self, key): + return getattr(self, key) + raise KeyError(f"Key {key} not found in NormalizationConfig.") + + def __iter__(self): + for f in fields(self): + val = getattr(self, f.name) + if val is not None: + yield f.name, val + + def __eq__(self, other): + if isinstance(other, dict): + return dict(self) == other + if isinstance(other, NormalizationConfig): + return tuple(getattr(self, f.name) for f in fields(self)) == tuple( + getattr(other, f.name) for f in fields(self) + ) + return NotImplemented + + def to_dict(self) -> dict: + return dict(self) + + @classmethod + def from_dict(cls, d: dict) -> 
"NormalizationConfig": + valid_keys = {f.name for f in fields(cls)} + return cls(**{k: v for k, v in d.items() if k in valid_keys}) + + def load_audio(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray: """ Loads `audio` to an np.ndarray object. diff --git a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py index 3d7fa2817bad..0258c133243a 100644 --- a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py @@ -11,89 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Audio processor class for Wav2Vec2 -""" +from ...audio_processing_backends import TorchAudioBackend +from ...audio_utils import NormalizationConfig -import torch - -from ...audio_processing_backends import TorchBackend -from ...feature_extraction_utils import BatchFeature -from ...utils import logging - - -logger = logging.get_logger(__name__) - - -class Wav2Vec2AudioProcessor(TorchBackend): - r""" - Constructs a Wav2Vec2 audio processor. - - This audio processor inherits from [`~audio_processing_utils.BaseAudioProcessor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Args: - sample_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - do_normalize (`bool`, *optional*, defaults to `True`): - Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly - improve the performance for some models, *e.g.*, - [wav2vec2-lv60](https://huggingface.co/models?search=lv60). 
- """ +class Wav2Vec2AudioProcessor(TorchAudioBackend): model_input_names = ["input_values", "attention_mask"] - def __init__( - self, - sample_rate: int = 16000, - do_normalize: bool = True, - force_mono: bool = True, - **kwargs, - ): - super().__init__( - sample_rate=sample_rate, - force_mono=force_mono, - **kwargs, - ) - self.do_normalize = do_normalize - - def _preprocess( - self, - audio: list[torch.Tensor], - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - do_normalize: bool | None = None, - **kwargs, - ) -> BatchFeature: - if do_normalize is None: - do_normalize = self.do_normalize - - # Truncation and padding via base class - result = super()._preprocess( - audio, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors=None, # we handle conversion after normalization - ) - - audio_tensors = result["audio"] - - if do_normalize: - audio_tensors = [self._zero_mean_unit_var_norm(t) for t in audio_tensors] - - input_values = torch.stack(audio_tensors, dim=0) if return_tensors else audio_tensors - return BatchFeature(data={"input_values": input_values}, tensor_type=return_tensors) - - @staticmethod - def _zero_mean_unit_var_norm(tensor: torch.Tensor) -> torch.Tensor: - """Zero-mean unit-variance normalize a tensor.""" - return (tensor - tensor.mean()) / torch.sqrt(tensor.var() + 1e-7) + sample_rate = 16000 + force_mono = True + do_values_normalize = True + normalization_config = NormalizationConfig(method="zero_mean_unit_var") __all__ = ["Wav2Vec2AudioProcessor"] diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index 09ed776c0039..4b8a7013f677 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -11,148 +11,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Audio processor class for Whisper -""" +from ...audio_processing_backends import TorchAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig -import torch - -from ...audio_processing_backends import TorchBackend -from ...audio_utils import mel_filter_bank -from ...feature_extraction_utils import BatchFeature -from ...utils import logging - - -logger = logging.get_logger(__name__) - - -class WhisperAudioProcessor(TorchBackend): - r""" - Constructs a Whisper audio processor. - - This audio processor inherits from [`~audio_processing_utils.BaseAudioProcessor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech using PyTorch's `torch.stft`. - - Args: - feature_size (`int`, *optional*, defaults to 80): - The feature dimension of the extracted features (number of mel bins). - sample_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - hop_length (`int`, *optional*, defaults to 160): - Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. - chunk_length (`int`, *optional*, defaults to 30): - The maximum number of seconds of audio used to trim and pad sequences. - n_fft (`int`, *optional*, defaults to 400): - Size of the Fourier transform. 
- dither (`float`, *optional*, defaults to 0.0): - Adds dithering (small Gaussian noise) to each frame. Use 0.0 for no dithering. - """ +class WhisperAudioProcessor(TorchAudioBackend): model_input_names = ["input_features"] - def __init__( - self, - feature_size: int = 80, - sample_rate: int = 16000, - hop_length: int = 160, - chunk_length: int = 30, - n_fft: int = 400, - dither: float = 0.0, - force_mono: bool = True, - **kwargs, - ): - super().__init__( - sample_rate=sample_rate, - force_mono=force_mono, - **kwargs, - ) - self.feature_size = feature_size - self.n_fft = n_fft - self.hop_length = hop_length - self.chunk_length = chunk_length - self.n_samples = chunk_length * sample_rate - self.nb_max_frames = self.n_samples // hop_length - self.dither = dither - self.mel_filters = mel_filter_bank( - num_frequency_bins=1 + n_fft // 2, - num_mel_filters=feature_size, - min_frequency=0.0, - max_frequency=8000.0, - sampling_rate=sample_rate, - norm="slaney", + sample_rate = 16000 + force_mono = True + truncation = True + max_length = 480000 # 30 seconds at 16000 Hz + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=400, + hop_length=160, + power=2.0, + ), + mel_scale_config=MelScaleConfig( + n_mels=80, mel_scale="slaney", - ) - - def _preprocess( - self, - audio: list[torch.Tensor], - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - do_normalize: bool | None = None, - device: str | None = "cpu", - **kwargs, - ) -> BatchFeature: - # Default max_length to n_samples (chunk_length * sample_rate) - if max_length is None: - max_length = self.n_samples - - # Use base class for truncation + padding - result = super()._preprocess( - audio, - padding=padding if padding is not None else True, - max_length=max_length, - truncation=truncation if truncation is not None else True, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors=None, # we handle conversion after feature extraction - ) - - audio_tensors = result["audio"] - - # Zero-mean unit-variance normalization (before spectrogram) - if do_normalize: - audio_tensors = [(t - t.mean()) / torch.sqrt(t.var() + 1e-7) for t in audio_tensors] - - # Stack into batch for spectrogram extraction - waveform_batch = torch.stack(audio_tensors, dim=0).to(device, torch.float32) - - # Extract log-mel spectrogram - input_features = self._extract_fbank_features(waveform_batch, device) - - return BatchFeature(data={"input_features": input_features}, tensor_type=return_tensors) - - def _extract_fbank_features(self, waveform: torch.Tensor, device: str = "cpu") -> torch.Tensor: - """ - Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation. 
- """ - window = torch.hann_window(self.n_fft, device=device) - - if self.dither != 0.0: - waveform = waveform + self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device) - - stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) - magnitudes = stft[..., :-1].abs() ** 2 - - mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) - mel_spec = mel_filters.T @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - if waveform.dim() == 2: - max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] - log_spec = torch.maximum(log_spec, max_val - 8.0) - else: - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - - if device != "cpu": - log_spec = log_spec.detach().cpu() - - return log_spec + ), + log_mode="log10", + chunk_length=30, + ) __all__ = ["WhisperAudioProcessor"] From b2397b599e1aa6f6287165a246291ae59ac74b35 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 5 Mar 2026 15:18:57 +0100 Subject: [PATCH 05/28] starting to look like something --- src/transformers/audio_processing_backends.py | 77 ++++++-- src/transformers/audio_processing_utils.py | 4 + src/transformers/audio_utils.py | 12 ++ ...rocessing_audio_spectrogram_transformer.py | 82 +++++++++ .../models/clap/audio_processing_clap.py | 162 +++++++++++++++++ .../models/clvp/audio_processing_clvp.py | 88 +++++++++ .../models/dac/audio_processing_dac.py | 24 +++ .../models/dia/audio_processing_dia.py | 25 +++ .../encodec/audio_processing_encodec.py | 24 +++ .../gemma3n/audio_processing_gemma3n.py | 170 ++++++++++++++++++ .../audio_processing_granite_speech.py | 79 ++++++++ .../audio_processing_kyutai_speech_to_text.py | 53 ++++++ .../models/lasr/audio_processing_lasr.py | 88 +++++++++ .../audio_processing_musicgen_melody.py | 106 +++++++++++ .../parakeet/audio_processing_parakeet.py | 119 ++++++++++++ .../pe_audio/audio_processing_pe_audio.py | 23 +++ .../audio_processing_phi4_multimodal.py | 117 ++++++++++++ .../pop2piano/audio_processing_pop2piano.py | 34 ++++ .../audio_processing_seamless_m4t.py | 96 ++++++++++ .../audio_processing_speech_to_text.py | 95 ++++++++++ .../speecht5/audio_processing_speecht5.py | 23 +++ .../univnet/audio_processing_univnet.py | 118 ++++++++++++ ...processing_vibevoice_acoustic_tokenizer.py | 47 +++++ .../audio_processing_voxtral_realtime.py | 38 ++++ .../wav2vec2/audio_processing_wav2vec2.py | 2 - .../whisper/audio_processing_whisper.py | 3 +- 26 files changed, 1688 insertions(+), 21 deletions(-) create mode 100644 src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py create mode 100644 src/transformers/models/clap/audio_processing_clap.py create mode 100644 src/transformers/models/clvp/audio_processing_clvp.py create mode 100644 src/transformers/models/dac/audio_processing_dac.py create mode 100644 src/transformers/models/dia/audio_processing_dia.py create mode 100644 src/transformers/models/encodec/audio_processing_encodec.py create mode 100644 src/transformers/models/gemma3n/audio_processing_gemma3n.py create mode 100644 src/transformers/models/granite_speech/audio_processing_granite_speech.py create mode 100644 src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py create mode 100644 src/transformers/models/lasr/audio_processing_lasr.py create mode 100644 src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py create mode 100644 
src/transformers/models/parakeet/audio_processing_parakeet.py create mode 100644 src/transformers/models/pe_audio/audio_processing_pe_audio.py create mode 100644 src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py create mode 100644 src/transformers/models/pop2piano/audio_processing_pop2piano.py create mode 100644 src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py create mode 100644 src/transformers/models/speech_to_text/audio_processing_speech_to_text.py create mode 100644 src/transformers/models/speecht5/audio_processing_speecht5.py create mode 100644 src/transformers/models/univnet/audio_processing_univnet.py create mode 100644 src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py create mode 100644 src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 6cd61107a4fc..c23b2272afc1 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -19,7 +19,6 @@ from .audio_utils import SpectrogramConfig, NormalizationConfig from .feature_extraction_utils import BatchFeature from .utils import logging, is_torch_available -from .utils.import_utils import requires logger = logging.get_logger(__name__) @@ -107,7 +106,7 @@ def values_normalize( """Normalize raw audio values (stage 4). Supports zero-mean-unit-var.""" if normalization_config.method == "zero_mean_unit_var": return [ - (a - np.mean(a)) / (np.std(a) + 1e-7) + (a - np.mean(a)) / np.sqrt(np.var(a) + 1e-7) for a in audio ] raise ValueError(f"Unknown normalization method: {normalization_config.method}") @@ -118,11 +117,47 @@ def extract_spectrogram( *, spectrogram_config: SpectrogramConfig, ) -> list[np.ndarray]: - """Extract audio features (stage 5). Override in model-specific subclasses.""" - raise NotImplementedError( - f"{self.__class__.__name__} does not implement `extract_spectrogram`. " - "Override this method in your model-specific audio processor." - ) + """Extract log-mel spectrogram features using the numpy spectrogram() function.""" + from .audio_utils import spectrogram as compute_spectrogram, window_function + + if not hasattr(self, "mel_filters"): + raise NotImplementedError( + f"{self.__class__.__name__} does not have `mel_filters`. " + "Either set `mel_filters` or override `extract_spectrogram`." 
+ ) + + stft_cfg = spectrogram_config.stft_config + n_fft = stft_cfg.n_fft + hop_length = stft_cfg.hop_length if stft_cfg.hop_length is not None else n_fft // 4 + win_length = stft_cfg.win_length if stft_cfg.win_length is not None else n_fft + + # Build window — map torch names like "hann_window" to audio_utils names like "hann" + window_name = stft_cfg.window_fn.replace("_window", "") + window = window_function(win_length, window_name, periodic=stft_cfg.periodic) + + features = [] + for waveform in audio: + w = waveform + if spectrogram_config.waveform_scale is not None: + w = np.squeeze(w) * spectrogram_config.waveform_scale + spec = compute_spectrogram( + w, + window=window, + frame_length=win_length, + hop_length=hop_length, + fft_length=n_fft, + power=stft_cfg.power, + center=stft_cfg.center, + pad_mode=stft_cfg.pad_mode, + preemphasis=spectrogram_config.preemphasis, + remove_dc_offset=spectrogram_config.remove_dc_offset, + mel_filters=self.mel_filters, + mel_floor=spectrogram_config.mel_floor, + log_mel=spectrogram_config.log_mode if spectrogram_config.log_mode != "log10" else "log10", + ) + features.append(spec) + + return features def feature_normalize( self, @@ -133,7 +168,7 @@ def feature_normalize( """Normalize extracted features (stage 6). Supports zero-mean-unit-var.""" if feature_normalization_config.method == "zero_mean_unit_var": return [ - (f - np.mean(f)) / (np.std(f) + 1e-7) + (f - np.mean(f)) / np.sqrt(np.var(f) + 1e-7) for f in features ] raise ValueError(f"Unknown normalization method: {feature_normalization_config.method}") @@ -217,12 +252,13 @@ def _preprocess( # --- Stage 5: Feature extraction --- if spectrogram_config is not None: features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) + if self.transpose_features: + features = [f.T for f in features] else: features = audio # --- Stages 6 & 7: Feature normalization and padding --- if feature_normalize_before_pad: - # Stage 6 before 7: normalize then pad if do_feature_normalize and feature_normalization_config is not None: features = self.feature_normalize( features, feature_normalization_config=feature_normalization_config @@ -232,7 +268,6 @@ def _preprocess( features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of ) else: - # Stage 7 before 6: pad then normalize if do_pad_features: features = self.pad_features( features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of @@ -245,6 +280,8 @@ def _preprocess( # Stack into batch output_key = self.model_input_names[0] stacked = np.stack(features, axis=0) if return_tensors else features + if self.add_channel_dim and isinstance(stacked, np.ndarray): + stacked = stacked[:, np.newaxis, :] return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) @@ -331,7 +368,7 @@ def values_normalize( if normalization_config.method == "zero_mean_unit_var": return [ - (a - torch.mean(a)) / (torch.std(a) + 1e-7) + (a - torch.mean(a)) / torch.sqrt(torch.var(a, correction=0) + 1e-7) for a in audio ] raise ValueError(f"Unknown normalization method: {normalization_config.method}") @@ -366,11 +403,15 @@ def extract_spectrogram( mel_spec = mel_filters.T @ magnitudes log_spec = torch.clamp(mel_spec, min=1e-10).log10() - if waveform.dim() == 2: + if spectrogram_config.global_log_mel_max is not None: + max_val = torch.tensor( + spectrogram_config.global_log_mel_max, device=log_spec.device, dtype=log_spec.dtype + ) + elif waveform.dim() == 2: max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] - log_spec = 
torch.maximum(log_spec, max_val - 8.0) else: - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + max_val = log_spec.max() + log_spec = torch.maximum(log_spec, max_val - 8.0) log_spec = (log_spec + 4.0) / 4.0 return [log_spec[i] for i in range(log_spec.shape[0])] @@ -386,7 +427,7 @@ def feature_normalize( if feature_normalization_config.method == "zero_mean_unit_var": return [ - (f - torch.mean(f)) / (torch.std(f) + 1e-7) + (f - torch.mean(f)) / torch.sqrt(torch.var(f, correction=0) + 1e-7) for f in features ] raise ValueError(f"Unknown normalization method: {feature_normalization_config.method}") @@ -470,12 +511,13 @@ def _preprocess( # --- Stage 5: Feature extraction --- if spectrogram_config is not None: features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) + if self.transpose_features: + features = [f.permute(*reversed(range(f.dim()))) for f in features] else: features = audio # --- Stages 6 & 7: Feature normalization and padding --- if feature_normalize_before_pad: - # Stage 6 before 7: normalize then pad if do_feature_normalize and feature_normalization_config is not None: features = self.feature_normalize( features, feature_normalization_config=feature_normalization_config @@ -485,7 +527,6 @@ def _preprocess( features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of ) else: - # Stage 7 before 6: pad then normalize if do_pad_features: features = self.pad_features( features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of @@ -498,4 +539,6 @@ def _preprocess( # Stack into batch output_key = self.model_input_names[0] stacked = torch.stack(features, dim=0) if return_tensors else features + if self.add_channel_dim and isinstance(stacked, torch.Tensor): + stacked = stacked.unsqueeze(1) return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) \ No newline at end of file diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 662fbcd45e07..cd44d9b6e0bb 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -61,6 +61,9 @@ class BaseAudioProcessor(AudioProcessingMixin): feature_normalization_config = None do_pad_features = None do_resample = False + add_channel_dim = False + pad_to_multiple_of = None + transpose_features = False def __init__( self, @@ -104,6 +107,7 @@ def __init__( sampling_rate=self.sample_rate, norm=sc.mel_scale_config.norm, mel_scale=sc.mel_scale_config.mel_scale, + triangularize_in_mel_space=sc.mel_scale_config.triangularize_in_mel_space, ) def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingKwargs]) -> BatchFeature: diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 679835b443ec..5275e8a5b3e7 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -77,6 +77,7 @@ class StftConfig: normalized: bool = False onesided: bool | None = None pad: int = 0 + periodic: bool = True def to_dict(self) -> dict: return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} @@ -100,6 +101,7 @@ class MelScaleConfig: f_max: float | None = None mel_scale: str = "htk" norm: str | None = None + triangularize_in_mel_space: bool = False def to_dict(self) -> dict: return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} @@ -118,6 +120,11 @@ class SpectrogramConfig: mel_scale_config: MelScaleConfig = field(default_factory=MelScaleConfig) log_mode: str = 
"log10" chunk_length: int | None = None + global_log_mel_max: float | None = None + preemphasis: float | None = None + remove_dc_offset: bool = False + mel_floor: float = 1e-10 + waveform_scale: float | None = None def __getitem__(self, key): if hasattr(self, key): @@ -154,6 +161,11 @@ def from_dict(cls, d: dict) -> "SpectrogramConfig": mel_scale_config=mel_scale_config, log_mode=d.get("log_mode", "log10"), chunk_length=d.get("chunk_length"), + global_log_mel_max=d.get("global_log_mel_max"), + preemphasis=d.get("preemphasis"), + remove_dc_offset=d.get("remove_dc_offset", False), + mel_floor=d.get("mel_floor", 1e-10), + waveform_scale=d.get("waveform_scale"), ) diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py new file mode 100644 index 000000000000..f4bc248465b2 --- /dev/null +++ b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py @@ -0,0 +1,82 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from ...audio_processing_backends import NumpyAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...feature_extraction_utils import BatchFeature + + +class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend): + sample_rate = 16000 + force_mono = True + max_length_frames = 1024 + transpose_features = True + + # AudioSet normalization constants + ast_mean = -4.2677393 + ast_std = 4.5689974 + + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=512, + win_length=400, + hop_length=160, + window_fn="hann_window", + power=2.0, + center=False, + periodic=False, + ), + mel_scale_config=MelScaleConfig( + n_mels=128, + f_min=20.0, + f_max=8000.0, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ), + log_mode="log", + preemphasis=0.97, + remove_dc_offset=True, + mel_floor=1.192092955078125e-07, + ) + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Extract spectrogram via generic config-based API + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + + # Generic extract_spectrogram returns (n_mels, frames); transpose to (frames, n_mels) + features = [f.T for f in features] + + # Pad or truncate to max_length_frames + padded = [] + for fbank in features: + n_frames = fbank.shape[0] + if n_frames < self.max_length_frames: + pad_amount = self.max_length_frames - n_frames + fbank = np.pad(fbank, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) + elif n_frames > self.max_length_frames: + fbank = fbank[: self.max_length_frames, :] + padded.append(fbank) + + # Normalize with AudioSet stats + normalized = [(f - self.ast_mean) / (self.ast_std * 2) for f in padded] + + # Stack into batch + output_key = self.model_input_names[0] + 
stacked = np.stack(normalized, axis=0) if return_tensors else normalized + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["AudioSpectrogramTransformerAudioProcessor"] diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py new file mode 100644 index 000000000000..4672057a1530 --- /dev/null +++ b/src/transformers/models/clap/audio_processing_clap.py @@ -0,0 +1,162 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +from ...audio_processing_backends import NumpyAudioBackend +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_utils import BatchFeature + + +class ClapAudioProcessor(NumpyAudioBackend): + sample_rate = 48000 + force_mono = True + n_fft = 1024 + hop_length = 480 + n_mels = 64 + f_min = 0 + f_max = 14000 + max_length_s = 10 + truncation_mode = "fusion" # "fusion" or "rand_trunc" + padding_mode = "repeatpad" # "repeatpad", "repeat", or "pad" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.nb_max_samples = self.max_length_s * self.sample_rate + self.mel_filters = mel_filter_bank( + num_frequency_bins=1 + self.n_fft // 2, + num_mel_filters=self.n_mels, + min_frequency=self.f_min, + max_frequency=self.f_max, + sampling_rate=self.sample_rate, + norm=None, + mel_scale="htk", + ) + self.mel_filters_slaney = mel_filter_bank( + num_frequency_bins=1 + self.n_fft // 2, + num_mel_filters=self.n_mels, + min_frequency=self.f_min, + max_frequency=self.f_max, + sampling_rate=self.sample_rate, + norm="slaney", + mel_scale="htk", + ) + + def _np_extract_fbank_features(self, waveform, mel_filters=None): + if mel_filters is None: + mel_filters = self.mel_filters + log_mel = spectrogram( + waveform, + window_function(self.n_fft, "hann"), + frame_length=self.n_fft, + hop_length=self.hop_length, + power=2.0, + mel_filters=mel_filters, + log_mel="dB", + ) + return log_mel.T + + def _random_mel_fusion(self, mel, total_frames, chunk_frames): + ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) + if len(ranges[1]) == 0: + ranges[1] = [0] + if len(ranges[2]) == 0: + ranges[2] = [0] + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + + mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :] + + mel_tensor = torch.tensor(mel[None, None, :]) + mel_shrink = torch.nn.functional.interpolate( + mel_tensor, size=[chunk_frames, 64], mode="bilinear", align_corners=False + ) + mel_shrink = mel_shrink[0][0].numpy() + return np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0) + + def _get_input_mel(self, waveform, max_length, truncation, padding): + if waveform.shape[0] > max_length: + if truncation 
== "rand_trunc": + longer = True + overflow = len(waveform) - max_length + idx = np.random.randint(0, overflow + 1) + waveform = waveform[idx : idx + max_length] + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :] + elif truncation == "fusion": + mel = self._np_extract_fbank_features(waveform, self.mel_filters) + chunk_frames = max_length // self.hop_length + 1 + total_frames = mel.shape[0] + if chunk_frames == total_frames: + input_mel = np.stack([mel, mel, mel, mel], axis=0) + longer = False + else: + input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames) + longer = True + else: + raise NotImplementedError(f"data_truncating {truncation} not implemented") + else: + longer = False + if waveform.shape[0] < max_length: + if padding == "repeat": + n_repeat = int(max_length / len(waveform)) + waveform = np.tile(waveform, n_repeat + 1)[:max_length] + if padding == "repeatpad": + n_repeat = int(max_length / len(waveform)) + waveform = np.tile(waveform, n_repeat) + waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0) + + if truncation == "fusion": + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters) + input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0) + else: + input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :] + + return input_mel, longer + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + truncation_mode = self.truncation_mode + padding_mode = self.padding_mode + nb_max_samples = max_length if max_length else self.nb_max_samples + + padded_inputs = [ + self._get_input_mel(np.squeeze(waveform), nb_max_samples, truncation_mode, padding_mode) + for waveform in audio + ] + + input_mel = [] + is_longer = [] + for mel, longer in padded_inputs: + input_mel.append(mel) + is_longer.append(longer) + + if truncation_mode == "fusion" and sum(is_longer) == 0: + rand_idx = np.random.randint(0, len(input_mel)) + is_longer[rand_idx] = True + + is_longer = [[longer] for longer in is_longer] + + input_features = {"input_features": input_mel, "is_longer": is_longer} + input_features = BatchFeature(input_features) + + if return_tensors is not None: + input_features = input_features.convert_to_tensors(return_tensors) + + return input_features + + +__all__ = ["ClapAudioProcessor"] diff --git a/src/transformers/models/clvp/audio_processing_clvp.py b/src/transformers/models/clvp/audio_processing_clvp.py new file mode 100644 index 000000000000..fdf9810a1ac7 --- /dev/null +++ b/src/transformers/models/clvp/audio_processing_clvp.py @@ -0,0 +1,88 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np + +from ...audio_processing_backends import NumpyAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...feature_extraction_utils import BatchFeature + + +class ClvpAudioProcessor(NumpyAudioBackend): + sample_rate = 22050 + force_mono = True + n_fft = 1024 + hop_length = 256 + n_mels = 80 + max_length = 132300 # 6 seconds at 22050 Hz + + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=1024, + hop_length=256, + window_fn="hann_window", + power=2.0, + ), + mel_scale_config=MelScaleConfig( + n_mels=80, + f_min=0.0, + f_max=8000.0, + norm="slaney", + mel_scale="htk", + ), + log_mode="log", + mel_floor=1e-5, + ) + + def __init__(self, mel_norms=None, **kwargs): + super().__init__(**kwargs) + self.mel_norms = mel_norms + + def extract_spectrogram(self, audio, *, spectrogram_config): + # Use the generic config-based API for the core spectrogram + features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config) + + # Apply mel_norms if provided + if self.mel_norms is not None: + features = [f / np.array(self.mel_norms)[:, None] for f in features] + + return features + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Determine the raw-audio target length + if max_length is None: + max_length = self.max_length + + # Truncate to max_length first + audio = [a[..., :max_length] for a in audio] + + # Pad raw audio: if padding=True, pad to longest in batch; otherwise pad to max_length + if padding is True or padding == "longest": + pad_length = max(a.shape[-1] for a in audio) + else: + pad_length = max_length + audio = self.pad_values(audio, max_length=pad_length, truncation=False, pad_to_multiple_of=pad_to_multiple_of) + + # Extract spectrogram via config-based API (with mel_norms applied) + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + + # Cast to float32 to match the legacy FeatureExtractor + features = [f.astype(np.float32) for f in features] + + output_key = self.model_input_names[0] + stacked = np.stack(features, axis=0) + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["ClvpAudioProcessor"] diff --git a/src/transformers/models/dac/audio_processing_dac.py b/src/transformers/models/dac/audio_processing_dac.py new file mode 100644 index 000000000000..f0a27bd57555 --- /dev/null +++ b/src/transformers/models/dac/audio_processing_dac.py @@ -0,0 +1,24 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ...audio_processing_backends import NumpyAudioBackend + + +class DacAudioProcessor(NumpyAudioBackend): + sample_rate = 16000 + force_mono = True + add_channel_dim = True + + +__all__ = ["DacAudioProcessor"] diff --git a/src/transformers/models/dia/audio_processing_dia.py b/src/transformers/models/dia/audio_processing_dia.py new file mode 100644 index 000000000000..e1b7b0301e71 --- /dev/null +++ b/src/transformers/models/dia/audio_processing_dia.py @@ -0,0 +1,25 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...audio_processing_backends import NumpyAudioBackend + + +class DiaAudioProcessor(NumpyAudioBackend): + sample_rate = 44100 + force_mono = True + add_channel_dim = True + pad_to_multiple_of = 512 + + +__all__ = ["DiaAudioProcessor"] diff --git a/src/transformers/models/encodec/audio_processing_encodec.py b/src/transformers/models/encodec/audio_processing_encodec.py new file mode 100644 index 000000000000..022a7e145313 --- /dev/null +++ b/src/transformers/models/encodec/audio_processing_encodec.py @@ -0,0 +1,24 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...audio_processing_backends import NumpyAudioBackend + + +class EncodecAudioProcessor(NumpyAudioBackend): + sample_rate = 24000 + force_mono = True + add_channel_dim = True + + +__all__ = ["EncodecAudioProcessor"] diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py new file mode 100644 index 000000000000..27d9a0898f3a --- /dev/null +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -0,0 +1,170 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
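The Dac, Dia, and Encodec processors above carry no custom logic; they only set class-level defaults consumed by the shared backend (mono conversion, raw-value padding, optional channel axis). A rough sketch of what the Dia defaults imply, assuming the base `_preprocess` applies `pad_to_multiple_of` to raw values and that `add_channel_dim` inserts a channel axis:

    # pad_to_multiple_of = 512 rounds each waveform length up to a multiple of 512
    length = 44100                            # 1 s at 44.1 kHz
    padded = ((length + 511) // 512) * 512    # -> 44544 samples
    # with add_channel_dim=True the batched output is shaped (batch, 1, padded)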
+ +import math +from collections.abc import Sequence + +import numpy as np + +from ...audio_processing_backends import NumpyAudioBackend +from ...feature_extraction_utils import BatchFeature + + +def _create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate, fft_length, norm=None): + """HTK-style mel filterbank matrix matching Gemma3n FE implementation.""" + all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length) + m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0)) + m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0)) + m_pts = np.linspace(m_min, m_max, n_mels + 2) + f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0) + f_diff = f_pts[1:] - f_pts[:-1] + slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) + zero = np.zeros(1, dtype=np.float32) + down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] + up_slopes = slopes[:, 2:] / f_diff[1:] + fb = np.maximum(zero, np.minimum(down_slopes, up_slopes)) + if norm == "slaney": + enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) + fb *= np.expand_dims(enorm, 0) + return fb + + +def _unfold(array, dimension, size, step): + """NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim.""" + if array.ndim == 1: + array = array[np.newaxis, :] + batch_size, original_length = array.shape + num_frames = (original_length - size) // step + 1 + if num_frames <= 0: + return np.zeros((batch_size, 0, size), dtype=array.dtype) + output_shape = (batch_size, num_frames, size) + output_strides = (array.strides[0], array.strides[1] * step, array.strides[1]) + return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides) + + +class Gemma3nAudioProcessor(NumpyAudioBackend): + sample_rate = 16000 + force_mono = True + frame_length = 512 # 32ms at 16kHz + hop_length = 160 # 10ms at 16kHz + n_mels = 128 + min_frequency = 125.0 + max_frequency = 7600.0 + preemphasis_coeff = 0.97 + preemphasis_htk_flavor = True + fft_overdrive = True + mel_floor = 1e-5 + max_length = 480000 # 30 seconds + truncation = True + pad_to_multiple_of = 128 + + def __init__(self, per_bin_mean=None, per_bin_stddev=None, **kwargs): + super().__init__(**kwargs) + + fft_length = 2 ** math.ceil(math.log2(self.frame_length)) + if self.fft_overdrive: + fft_length *= 2 + self.fft_length = fft_length + + hann_arange = np.arange(self.frame_length, dtype=np.float32) + self.window = (0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length))).astype(np.float32) + + self.mel_filters = _create_fb_matrix( + n_freqs=self.fft_length // 2 + 1, + f_min=self.min_frequency, + f_max=self.max_frequency, + n_mels=self.n_mels, + sample_rate=self.sample_rate, + fft_length=self.fft_length, + norm=None, + ) + + if per_bin_mean is not None: + self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, self.n_mels) + else: + self.per_bin_mean = None + + if per_bin_stddev is not None: + self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, self.n_mels) + else: + self.per_bin_stddev = None + + def _extract_spectrogram(self, waveform): + if waveform.ndim == 1: + waveform = np.expand_dims(waveform, axis=0) + + frame_size_for_unfold = self.frame_length + 1 + frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length) + + if self.preemphasis_coeff > 0.0: + if self.preemphasis_htk_flavor: + first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis_coeff) + rest_in_frame = ( + frames_to_process[..., 1:-1] - self.preemphasis_coeff * frames_to_process[..., :-2] + ) + frames = 
np.concatenate([first_in_frame, rest_in_frame], axis=-1) + else: + frames = frames_to_process[..., 1:] - self.preemphasis_coeff * frames_to_process[..., :-1] + else: + frames = frames_to_process[..., :-1] + + frames = frames * self.window + stft = np.fft.rfft(frames, n=self.fft_length, axis=-1) + magnitude_spec = np.abs(stft) + + mel_spec = np.matmul(magnitude_spec, self.mel_filters) + log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor)) + + if self.per_bin_mean is not None: + log_mel_spec = log_mel_spec - self.per_bin_mean + if self.per_bin_stddev is not None: + log_mel_spec = log_mel_spec / self.per_bin_stddev + + return log_mel_spec.squeeze(0) # (frames, n_mels) + + def extract_spectrogram(self, audio, *, spectrogram_config): + return [self._extract_spectrogram(waveform) for waveform in audio] + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Use class defaults for max_length, truncation, pad_to_multiple_of if not overridden + if max_length is None: + max_length = self.max_length + if truncation is None: + truncation = self.truncation + if pad_to_multiple_of is None: + pad_to_multiple_of = self.pad_to_multiple_of + + # Truncate first (separate from padding, matching FE behavior) + if truncation and max_length is not None: + audio = [a[..., :max_length] for a in audio] + + # Pad to longest in batch (matching FE "longest" padding strategy) + pad_length = max(a.shape[-1] for a in audio) + if pad_to_multiple_of is not None and (pad_length % pad_to_multiple_of != 0): + pad_length = ((pad_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + audio = [self.pad(a, pad_length) for a in audio] + + # Extract spectrogram + features = self.extract_spectrogram(audio, spectrogram_config=None) + + # Cast to float32 to match FE output + features = [f.astype(np.float32) for f in features] + + # Stack and return + output_key = self.model_input_names[0] + stacked = np.stack(features, axis=0) + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["Gemma3nAudioProcessor"] diff --git a/src/transformers/models/granite_speech/audio_processing_granite_speech.py b/src/transformers/models/granite_speech/audio_processing_granite_speech.py new file mode 100644 index 000000000000..a567b8d4f8fe --- /dev/null +++ b/src/transformers/models/granite_speech/audio_processing_granite_speech.py @@ -0,0 +1,79 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
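A small worked example of the Gemma3n STFT sizing above; all values follow directly from the class attributes:

    import math

    frame_length, hop_length, max_length = 512, 160, 480000
    fft_length = 2 ** math.ceil(math.log2(frame_length))  # 512
    fft_length *= 2                                        # fft_overdrive=True -> 1024
    # _unfold takes frame_length + 1 samples per frame (one extra for pre-emphasis),
    # so a full 30 s clip produces:
    num_frames = (max_length - (frame_length + 1)) // hop_length + 1  # 2997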
+ +from ...audio_processing_backends import TorchAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...feature_extraction_utils import BatchFeature + + +class GraniteSpeechAudioProcessor(TorchAudioBackend): + sample_rate = 16000 + force_mono = True + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=512, + hop_length=160, + power=2.0, + ), + mel_scale_config=MelScaleConfig( + n_mels=80, + ), + log_mode="log10", + ) + + def extract_spectrogram(self, audio, *, spectrogram_config): + import torch + + # Use parent's extract_spectrogram for basic mel spectrogram + # Parent returns list of (n_mels, frames) tensors with log10 + (x+4)/4 normalization + features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config) + + # Transpose each: (n_mels, frames) -> (frames, n_mels) + features = [f.permute(1, 0) for f in features] + + # Remove last frame if odd + features = [f[:-1] if f.shape[0] % 2 == 1 else f for f in features] + + # Frame stacking: (frames, n_mels) -> (frames//2, 2*n_mels) + features = [f.reshape(-1, 2 * f.shape[-1]) for f in features] + + return features + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + import torch + + # Pad raw audio values + if padding: + audio = self.pad_values( + audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of + ) + + # Extract spectrogram with frame stacking + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + + # Pad features to same length + max_feat_len = max(f.shape[0] for f in features) + padded = [] + for f in features: + if f.shape[0] < max_feat_len: + pad_amount = max_feat_len - f.shape[0] + f = torch.nn.functional.pad(f, (0, 0, 0, pad_amount), mode="constant", value=0.0) + padded.append(f) + + output_key = self.model_input_names[0] + stacked = torch.stack(padded, dim=0) + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["GraniteSpeechAudioProcessor"] diff --git a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py new file mode 100644 index 000000000000..8f3e3f314b9c --- /dev/null +++ b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py @@ -0,0 +1,53 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np + +from ...audio_processing_backends import NumpyAudioBackend + + +class KyutaiSpeechToTextAudioProcessor(NumpyAudioBackend): + sample_rate = 24000 + force_mono = True + add_channel_dim = True + + def __init__(self, audio_delay_seconds=2.5, audio_silence_prefix_seconds=1.0, **kwargs): + self.audio_delay_seconds = audio_delay_seconds + self.audio_silence_prefix_seconds = audio_silence_prefix_seconds + super().__init__(**kwargs) + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + result = super()._preprocess( + audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs + ) + + pad_left = int(self.audio_silence_prefix_seconds * self.sample_rate) + pad_right = int((self.audio_delay_seconds + 1.0) * self.sample_rate) + + if pad_left > 0 or pad_right > 0: + output_key = self.model_input_names[0] + data = result[output_key] + + if isinstance(data, np.ndarray): + pad_width = [(0, 0)] * (data.ndim - 1) + [(pad_left, pad_right)] + result[output_key] = np.pad(data, pad_width, mode="constant", constant_values=0.0) + else: + import torch.nn.functional as F + + result[output_key] = F.pad(data, (pad_left, pad_right), mode="constant", value=0.0) + + return result + + +__all__ = ["KyutaiSpeechToTextAudioProcessor"] diff --git a/src/transformers/models/lasr/audio_processing_lasr.py b/src/transformers/models/lasr/audio_processing_lasr.py new file mode 100644 index 000000000000..f12e3086b39f --- /dev/null +++ b/src/transformers/models/lasr/audio_processing_lasr.py @@ -0,0 +1,88 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
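The silence-prefix and delay padding in the Kyutai processor above amount to a fixed number of extra zero samples on each side; for the default arguments:

    sample_rate = 24000
    pad_left = int(1.0 * sample_rate)           # audio_silence_prefix_seconds -> 24000 samples
    pad_right = int((2.5 + 1.0) * sample_rate)  # audio_delay_seconds + 1.0    -> 84000 samples
    # every padded output therefore gains 108000 zero samples (4.5 s) around the audio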
+ +import numpy as np + +from ...audio_processing_backends import TorchAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, hertz_to_mel + + +def _linear_to_mel_weight_matrix(num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, upper_edge_hertz): + """Kaldi-style mel weight matrix matching the LASR FE implementation.""" + internal_dtype = np.float64 + bands_to_zero = 1 + nyquist_hertz = sample_rate / 2.0 + linear_frequencies = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins, dtype=internal_dtype)[bands_to_zero:] + spectrogram_bins_mel = hertz_to_mel(linear_frequencies, mel_scale="kaldi")[:, np.newaxis] + + edges = np.linspace( + hertz_to_mel(lower_edge_hertz, mel_scale="kaldi"), + hertz_to_mel(upper_edge_hertz, mel_scale="kaldi"), + num_mel_bins + 2, + dtype=internal_dtype, + ) + lower_edge_mel = edges[:-2][np.newaxis, :] + center_mel = edges[1:-1][np.newaxis, :] + upper_edge_mel = edges[2:][np.newaxis, :] + + lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel) + upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel) + mel_weights = np.maximum(0.0, np.minimum(lower_slopes, upper_slopes)) + return np.pad(mel_weights, [[bands_to_zero, 0], [0, 0]]).astype(np.float64) + + +class LasrAudioProcessor(TorchAudioBackend): + sample_rate = 16000 + force_mono = True + add_channel_dim = True + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig(n_fft=512, hop_length=160, win_length=400, power=2.0), + mel_scale_config=MelScaleConfig(n_mels=128, f_min=125.0, f_max=7500.0, mel_scale="kaldi"), + log_mode="log", + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.mel_filters = _linear_to_mel_weight_matrix( + num_mel_bins=128, + num_spectrogram_bins=512 // 2 + 1, + sample_rate=self.sample_rate, + lower_edge_hertz=125.0, + upper_edge_hertz=7500.0, + ) + + def extract_spectrogram(self, audio, *, spectrogram_config): + import torch + + stft_cfg = spectrogram_config.stft_config + n_fft = stft_cfg.n_fft + hop_length = stft_cfg.hop_length + win_length = stft_cfg.win_length or n_fft + + waveform = torch.stack(audio, dim=0).to(torch.float64) + device = waveform.device + + window = torch.hann_window(win_length, periodic=False, device=device, dtype=torch.float64) + frames = waveform.unfold(-1, win_length, hop_length) + stft = torch.fft.rfft(window * frames, n=n_fft) + power_spec = torch.abs(stft) ** 2 + + mel_filters = torch.from_numpy(self.mel_filters).to(device) + mel_spec = torch.clamp(power_spec @ mel_filters, min=1e-5) + mel_spec = torch.log(mel_spec) + + return [mel_spec[i].to(torch.float32) for i in range(mel_spec.shape[0])] + + +__all__ = ["LasrAudioProcessor"] diff --git a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py new file mode 100644 index 000000000000..62005e416256 --- /dev/null +++ b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py @@ -0,0 +1,106 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...audio_processing_backends import TorchAudioBackend +from ...feature_extraction_utils import BatchFeature +from ...utils.import_utils import requires + + +class MusicgenMelodyAudioProcessor(TorchAudioBackend): + sample_rate = 32000 + force_mono = True + n_fft = 16384 + hop_length = 4096 + n_chroma = 12 + chunk_length = 30 + + @requires(backends=("librosa", "torch")) + def __init__(self, **kwargs): + super().__init__(**kwargs) + import librosa + import numpy as np + import torch + + self.chroma_filters = torch.from_numpy( + librosa.filters.chroma(sr=self.sample_rate, n_fft=self.n_fft, tuning=0, n_chroma=self.n_chroma) + ).float() + + def extract_spectrogram(self, audio, *, spectrogram_config): + import torch + import torchaudio + + waveform = torch.stack(audio, dim=0) + device = waveform.device + batch_size = waveform.shape[0] + + # Pad if too short for FFT + if waveform.shape[-1] < self.n_fft: + pad = self.n_fft - waveform.shape[-1] + rest = 0 if pad % 2 == 0 else 1 + waveform = torch.nn.functional.pad(waveform, (pad // 2, pad // 2 + rest), "constant", 0) + + # Add channel dim for spectrogram: (batch, 1, length) + waveform = waveform.unsqueeze(1) + + # Power spectrogram (normalized) + spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, win_length=self.n_fft, hop_length=self.hop_length, + power=2, center=True, pad=0, normalized=True, + ).to(device) + spec = spec_transform(waveform).squeeze(1) + + # Chroma features + chroma_filters = self.chroma_filters.to(device) + raw_chroma = torch.einsum("cf, ...ft->...ct", chroma_filters, spec) + + # Normalize with inf norm + norm_chroma = torch.nn.functional.normalize(raw_chroma, p=float("inf"), dim=-2, eps=1e-6) + + # Transpose: (batch, chroma, frames) -> (batch, frames, chroma) + norm_chroma = norm_chroma.transpose(1, 2) + + # One-hot encoding: argmax along chroma dim + idx = norm_chroma.argmax(-1, keepdim=True) + norm_chroma[:] = 0 + norm_chroma.scatter_(dim=-1, index=idx, value=1) + + return [norm_chroma[i] for i in range(batch_size)] + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + import torch + + # Pad raw audio + if padding: + audio = self.pad_values( + audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of + ) + + # Extract chroma features + features = self.extract_spectrogram(audio, spectrogram_config=None) + + # Pad features + max_feat_len = max(f.shape[0] for f in features) + padded = [] + for f in features: + if f.shape[0] < max_feat_len: + pad_amount = max_feat_len - f.shape[0] + f = torch.nn.functional.pad(f, (0, 0, 0, pad_amount), mode="constant", value=0.0) + padded.append(f) + + output_key = self.model_input_names[0] + stacked = torch.stack(padded, dim=0) + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["MusicgenMelodyAudioProcessor"] diff --git a/src/transformers/models/parakeet/audio_processing_parakeet.py b/src/transformers/models/parakeet/audio_processing_parakeet.py new file mode 100644 index 000000000000..d83bce115644 --- /dev/null +++ 
b/src/transformers/models/parakeet/audio_processing_parakeet.py @@ -0,0 +1,119 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import librosa +import torch + +from ...audio_processing_backends import TorchAudioBackend +from ...feature_extraction_utils import BatchFeature + +LOG_ZERO_GUARD_VALUE = 2**-24 +EPSILON = 1e-5 + + +class ParakeetAudioProcessor(TorchAudioBackend): + sample_rate = 16000 + force_mono = True + preemphasis = 0.97 + n_fft = 512 + hop_length = 160 + win_length = 400 + n_mels = 80 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + # Use librosa for mel filters to match the FeatureExtractor exactly + # (mel_filter_bank uses float64 internally, causing numerical differences) + mel_filters = librosa.filters.mel( + sr=self.sample_rate, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=0.0, + fmax=self.sample_rate / 2, + norm="slaney", + ) + self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32) + + def _torch_extract_fbank_features(self, waveform, device="cpu"): + """Extract log-mel spectrogram features, matching the FE implementation.""" + window = torch.hann_window(self.win_length, periodic=False, device=device) + stft = torch.stft( + waveform, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=window, + return_complex=True, + pad_mode="constant", + ) + # Match original implementation: view_as_real then sqrt(sum of squares) + magnitudes = torch.view_as_real(stft) + magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1)) + magnitudes = magnitudes.pow(2) + + # Log mel spectrogram + mel_filters = self.mel_filters.to(device) + mel_spec = mel_filters @ magnitudes + mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE) + + # (batch, n_mels, frames) -> (batch, frames, n_mels) + mel_spec = mel_spec.permute(0, 2, 1) + return mel_spec + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + device = "cpu" + + # Record original audio lengths before padding + audio_lengths = torch.tensor([a.shape[-1] for a in audio]) + + # Pad values to longest + audio = self.pad_values(audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of) + + # Stack into batch tensor + waveform = torch.stack(audio, dim=0).to(torch.float32) + + # Preemphasis (mask-aware, matching FE) + if self.preemphasis is not None: + timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) < audio_lengths.unsqueeze(1) + waveform = torch.cat( + [waveform[:, :1], waveform[:, 1:] - self.preemphasis * waveform[:, :-1]], dim=1 + ) + waveform = waveform.masked_fill(~timemask, 0.0) + + # Extract log-mel spectrogram + input_features = self._torch_extract_fbank_features(waveform, device) + + # Compute feature lengths (matching FE formula) + features_lengths = torch.floor_divide( + audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length + ) + + # Build attention mask over feature frames + attention_mask = 
torch.arange(input_features.shape[1], device=device)[None, :] < features_lengths[:, None] + + # Mask-aware normalization (matching FE exactly) + mask = attention_mask.unsqueeze(-1) + input_features_masked = input_features * mask + mean = input_features_masked.sum(dim=1) / features_lengths.unsqueeze(-1) + mean = mean.unsqueeze(1) + variance = ((input_features_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1) + std = torch.sqrt(variance).unsqueeze(1) + input_features = (input_features - mean) / (std + EPSILON) + input_features *= mask + + output_key = self.model_input_names[0] + return BatchFeature(data={output_key: input_features}, tensor_type=return_tensors) + + +__all__ = ["ParakeetAudioProcessor"] diff --git a/src/transformers/models/pe_audio/audio_processing_pe_audio.py b/src/transformers/models/pe_audio/audio_processing_pe_audio.py new file mode 100644 index 000000000000..1c8969b28ed2 --- /dev/null +++ b/src/transformers/models/pe_audio/audio_processing_pe_audio.py @@ -0,0 +1,23 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...audio_processing_backends import NumpyAudioBackend + + +class PeAudioAudioProcessor(NumpyAudioBackend): + sample_rate = 16000 + force_mono = True + + +__all__ = ["PeAudioAudioProcessor"] diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py new file mode 100644 index 000000000000..01781c69a5db --- /dev/null +++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py @@ -0,0 +1,117 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
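The mask-aware normalization in the Parakeet processor above computes per-mel statistics over valid frames only, with an unbiased variance estimate. A minimal NumPy sketch of the same computation for a single example (random data used purely for illustration):

    import numpy as np

    feats = np.random.randn(100, 80).astype(np.float32)  # (frames, n_mels)
    valid = 73                                            # number of non-padded frames
    mean = feats[:valid].mean(axis=0)
    std = feats[:valid].std(axis=0, ddof=1)               # ddof=1 matches the (n - 1) divisor
    normed = np.zeros_like(feats)                         # padded frames stay zero
    normed[:valid] = (feats[:valid] - mean) / (std + 1e-5)  # EPSILON added to std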
+ +from ...audio_processing_backends import TorchAudioBackend +from ...audio_utils import mel_filter_bank +from ...feature_extraction_utils import BatchFeature + + +class Phi4MultimodalAudioProcessor(TorchAudioBackend): + sample_rate = 16000 + force_mono = True + preemphasis = 0.97 + n_fft = 512 + hop_length = 160 + win_length = 400 + n_mels = 80 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.mel_filters = mel_filter_bank( + num_frequency_bins=self.n_fft // 2 + 1, + num_mel_filters=self.n_mels, + min_frequency=0, + max_frequency=7690, + sampling_rate=self.sample_rate, + norm=None, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ) + + def extract_spectrogram(self, audio, *, spectrogram_config): + import torch + + waveform = torch.stack(audio, dim=0) + device = waveform.device + batch_size = waveform.shape[0] + lengths = torch.tensor([a.shape[-1] for a in audio], device=device) + + # Unfold into frames + frames = waveform.unfold(-1, self.win_length, self.hop_length) + + # Frame-level masking for padded inputs + if batch_size > 1: + frames = frames.clone() + to_mask_batch_idxs = torch.arange(batch_size, device=device)[lengths != lengths.max()] + if to_mask_batch_idxs.numel() > 0: + batch_idxs_down = (lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1 + batch_idxs_up = (lengths[to_mask_batch_idxs] // self.hop_length) - 1 + offset_idx = batch_idxs_down.min() + max_idx = batch_idxs_up.max() + + mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1) + mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & ( + mask < (batch_idxs_up - offset_idx).unsqueeze(1) + ) + mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length) + masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0) + frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames + + # Pre-emphasis + frames_prev = torch.roll(frames, 1, dims=-1) + frames_prev[:, :, 0] = frames_prev[:, :, 1] + frames = (frames - self.preemphasis * frames_prev) * 32768 + + # Hamming window + FFT + fft_window = torch.hamming_window(self.win_length, periodic=False, device=device, dtype=torch.float64) + S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1) + S = S.view(batch_size, -1, S.shape[-1]).to(torch.complex64) + + spec = torch.abs(S) + spec_power = spec**2 + + # Mel filterbank + log + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + log_spec = torch.clamp(spec_power @ mel_filters, min=1.0) + log_spec = torch.log(log_spec) + + return [log_spec[i] for i in range(batch_size)] + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + import torch + + # Pad values to longest + if padding: + audio = self.pad_values( + audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of + ) + + # Extract spectrogram + features = self.extract_spectrogram(audio, spectrogram_config=None) + + # Pad features and stack + max_feat_len = max(f.shape[0] for f in features) + padded = [] + for f in features: + if f.shape[0] < max_feat_len: + pad_amount = max_feat_len - f.shape[0] + f = torch.nn.functional.pad(f, (0, 0, 0, pad_amount), mode="constant", value=0.0) + padded.append(f) + + output_key = self.model_input_names[0] + stacked = torch.stack(padded, dim=0) + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["Phi4MultimodalAudioProcessor"] diff --git 
a/src/transformers/models/pop2piano/audio_processing_pop2piano.py b/src/transformers/models/pop2piano/audio_processing_pop2piano.py new file mode 100644 index 000000000000..9cd546b15a59 --- /dev/null +++ b/src/transformers/models/pop2piano/audio_processing_pop2piano.py @@ -0,0 +1,34 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: Full Pop2Piano feature extraction requires the Essentia library for +# beat detection (RhythmExtractor2013) and scipy for beat interpolation. +# This audio processor provides the basic mel spectrogram configuration but +# does not implement the complete beat-aligned segmentation pipeline. + +from ...audio_processing_backends import NumpyAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig + + +class Pop2PianoAudioProcessor(NumpyAudioBackend): + sample_rate = 22050 + force_mono = True + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig(n_fft=4096, hop_length=1024, power=2.0), + mel_scale_config=MelScaleConfig(n_mels=512, f_min=10.0, mel_scale="htk"), + log_mode="log10", + ) + + +__all__ = ["Pop2PianoAudioProcessor"] diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py new file mode 100644 index 000000000000..3a3ef16b5cbd --- /dev/null +++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py @@ -0,0 +1,96 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np + +from ...audio_processing_backends import NumpyAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...feature_extraction_utils import BatchFeature + + +class SeamlessM4tAudioProcessor(NumpyAudioBackend): + sample_rate = 16000 + force_mono = True + stride = 2 + + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=512, + win_length=400, + hop_length=160, + window_fn="povey_window", + power=2.0, + center=False, + periodic=False, + ), + mel_scale_config=MelScaleConfig( + n_mels=80, + f_min=20.0, + f_max=8000.0, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ), + log_mode="log", + preemphasis=0.97, + remove_dc_offset=True, + mel_floor=1.192092955078125e-07, + waveform_scale=32768.0, + ) + + def feature_normalize(self, features, *, feature_normalization_config): + # Per-mel-bin normalization with ddof=1 for variance + normalized = [] + for f in features: + mean = np.expand_dims(f.mean(axis=0), 0) + var = np.expand_dims(f.var(axis=0, ddof=1), 0) + normalized.append((f - mean) / np.sqrt(var + 1e-7)) + return normalized + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Extract Kaldi-style features via generic config-based API + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + + # Generic extract_spectrogram returns (n_mels, frames); transpose to (frames, n_mels) + features = [f.T for f in features] + + # Per-mel-bin normalization + features = self.feature_normalize(features, feature_normalization_config=None) + + # Pad features to longest (pad_to_multiple_of=2 for stride) + max_len = max(f.shape[0] for f in features) + if max_len % self.stride != 0: + max_len = ((max_len // self.stride) + 1) * self.stride + padded = [] + for f in features: + if f.shape[0] < max_len: + pad_amount = max_len - f.shape[0] + f = np.pad(f, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) + padded.append(f) + + stacked = np.stack(padded, axis=0) # (batch, frames, n_mels) + batch_size, num_frames, num_channels = stacked.shape + + # Stride concatenation + remainder = num_frames % self.stride + if remainder != 0: + stacked = stacked[:, : num_frames - remainder, :] + num_frames = num_frames - remainder + + stacked = stacked.reshape(batch_size, num_frames // self.stride, num_channels * self.stride) + + output_key = self.model_input_names[0] + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["SeamlessM4tAudioProcessor"] diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py new file mode 100644 index 000000000000..b670eb2a724c --- /dev/null +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -0,0 +1,95 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
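The stride concatenation at the end of the SeamlessM4t `_preprocess` above halves the time axis and doubles the feature axis; with the defaults (n_mels=80, stride=2) the reshape behaves as in this small sketch:

    import numpy as np

    stride, n_mels = 2, 80
    features = np.zeros((1, 500, n_mels))                   # (batch, frames, n_mels), frames already even
    features = features.reshape(1, 500 // stride, n_mels * stride)
    # -> (1, 250, 160): consecutive frame pairs are concatenated along the feature axis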
+ +import numpy as np + +from ...audio_processing_backends import NumpyAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...feature_extraction_utils import BatchFeature + + +class SpeechToTextAudioProcessor(NumpyAudioBackend): + sample_rate = 16000 + force_mono = True + + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=512, + win_length=400, + hop_length=160, + window_fn="povey_window", + power=2.0, + center=False, + periodic=False, + ), + mel_scale_config=MelScaleConfig( + n_mels=80, + f_min=20.0, + f_max=8000.0, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ), + log_mode="log", + preemphasis=0.97, + remove_dc_offset=True, + mel_floor=1.192092955078125e-07, + waveform_scale=32768.0, + ) + + def __init__(self, normalize_means=True, normalize_vars=True, **kwargs): + super().__init__(**kwargs) + self.normalize_means = normalize_means + self.normalize_vars = normalize_vars + + @staticmethod + def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, padding_value=0.0): + if normalize_means: + mean = x[:input_length].mean(axis=0) + x = np.subtract(x, mean) + if normalize_vars: + std = x[:input_length].std(axis=0) + x = np.divide(x, std) + if input_length < x.shape[0]: + x[input_length:] = padding_value + return x.astype(np.float32) + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Extract Kaldi-style features via generic config-based API + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + + # Generic extract_spectrogram returns (n_mels, frames); transpose to (frames, n_mels) + features = [f.T for f in features] + lengths = [f.shape[0] for f in features] + + # Pad features to longest + max_len = max(lengths) + padded = [] + for f in features: + if f.shape[0] < max_len: + pad_amount = max_len - f.shape[0] + f = np.pad(f, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) + padded.append(f) + + # Utterance CMVN normalization + normalized = [ + self.utterance_cmvn(f, length, self.normalize_means, self.normalize_vars, self.padding_value) + for f, length in zip(padded, lengths) + ] + + output_key = self.model_input_names[0] + stacked = np.stack(normalized, axis=0) + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["SpeechToTextAudioProcessor"] diff --git a/src/transformers/models/speecht5/audio_processing_speecht5.py b/src/transformers/models/speecht5/audio_processing_speecht5.py new file mode 100644 index 000000000000..4fc4c2226d35 --- /dev/null +++ b/src/transformers/models/speecht5/audio_processing_speecht5.py @@ -0,0 +1,23 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ...audio_processing_backends import TorchAudioBackend + + +class SpeechT5AudioProcessor(TorchAudioBackend): + sample_rate = 16000 + force_mono = True + + +__all__ = ["SpeechT5AudioProcessor"] diff --git a/src/transformers/models/univnet/audio_processing_univnet.py b/src/transformers/models/univnet/audio_processing_univnet.py new file mode 100644 index 000000000000..ca8e64808c26 --- /dev/null +++ b/src/transformers/models/univnet/audio_processing_univnet.py @@ -0,0 +1,118 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from ...audio_processing_backends import NumpyAudioBackend +from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...feature_extraction_utils import BatchFeature + + +class UnivNetAudioProcessor(NumpyAudioBackend): + sample_rate = 24000 + force_mono = True + n_fft = 1024 + hop_length = 256 + n_mels = 100 + fmin = 0.0 + fmax = 12000.0 + mel_floor = 1e-9 + compression_clip_val = 1e-5 + compression_factor = 1.0 + do_normalize = False + normalize_min = -11.512925148010254 + normalize_max = 2.3143386840820312 + max_length_s = 10 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.num_max_samples = self.max_length_s * self.sample_rate + self.mel_filters = mel_filter_bank( + num_frequency_bins=1 + self.n_fft // 2, + num_mel_filters=self.n_mels, + min_frequency=self.fmin, + max_frequency=self.fmax, + sampling_rate=self.sample_rate, + norm="slaney", + mel_scale="slaney", + ) + self.window = window_function(self.n_fft, "hann", periodic=True) + + def mel_spectrogram(self, waveform): + # Reflect-pad waveform + pad_amount = int((self.n_fft - self.hop_length) / 2) + waveform = np.pad(waveform, (pad_amount, pad_amount), mode="reflect") + + # Complex spectrogram + complex_spec = spectrogram( + waveform, + window=self.window, + frame_length=self.n_fft, + hop_length=self.hop_length, + fft_length=self.n_fft, + power=None, + center=False, + mel_filters=None, + mel_floor=None, + ) + + # Custom amplitude spectrogram: sqrt(real^2 + imag^2 + mel_floor) + amplitude_spec = np.sqrt(np.real(complex_spec) ** 2 + np.imag(complex_spec) ** 2 + self.mel_floor) + + # Apply mel filter bank + mel_spec = np.matmul(self.mel_filters.T, amplitude_spec) + + # Log compression + log_mel = np.log(np.clip(mel_spec, a_min=self.compression_clip_val, a_max=None) * self.compression_factor) + + return log_mel.T # (frames, n_mels) + + def normalize(self, spectrogram_data): + return 2 * ((spectrogram_data - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1 + + def extract_spectrogram(self, audio, *, spectrogram_config): + features = [] + for waveform in audio: + waveform = np.squeeze(waveform) + mel = self.mel_spectrogram(waveform) + if self.do_normalize: + mel = self.normalize(mel) + features.append(mel) + return features + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Pad raw audio + if padding: + audio = 
self.pad_values( + audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of + ) + + # Extract mel spectrograms + features = self.extract_spectrogram(audio, spectrogram_config=None) + + # Pad features + max_feat_len = max(f.shape[0] for f in features) + padded = [] + for f in features: + if f.shape[0] < max_feat_len: + pad_amount = max_feat_len - f.shape[0] + f = np.pad(f, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) + padded.append(f) + + output_key = self.model_input_names[0] + stacked = np.stack(padded, axis=0) + return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + +__all__ = ["UnivNetAudioProcessor"] diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py new file mode 100644 index 000000000000..1d342166a432 --- /dev/null +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py @@ -0,0 +1,47 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...audio_processing_backends import TorchAudioBackend +from ...audio_utils import NormalizationConfig + + +class VibevoiceAcousticTokenizerAudioProcessor(TorchAudioBackend): + sample_rate = 24000 + force_mono = True + add_channel_dim = True + do_values_normalize = True + normalization_config = NormalizationConfig(method="rms_normalize", normalize_before_pad=True) + + def __init__(self, target_dB_FS=-25, eps=1e-6, **kwargs): + self.target_dB_FS = target_dB_FS + self.eps = eps + super().__init__(**kwargs) + + def values_normalize(self, audio, *, normalization_config): + import torch + + if normalization_config.method == "rms_normalize": + normalized = [] + for a in audio: + rms = torch.sqrt(torch.mean(a**2)) + a = a * (10 ** (self.target_dB_FS / 20) / (rms + self.eps)) + max_val = torch.max(torch.abs(a)) + if max_val > 1.0: + a = a / (max_val + self.eps) + normalized.append(a) + return normalized + return super().values_normalize(audio, normalization_config=normalization_config) + + +__all__ = ["VibevoiceAcousticTokenizerAudioProcessor"] diff --git a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py new file mode 100644 index 000000000000..15a1203b8d6c --- /dev/null +++ b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py @@ -0,0 +1,38 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...audio_processing_backends import TorchAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig + + +class VoxtralRealtimeAudioProcessor(TorchAudioBackend): + sample_rate = 16000 + force_mono = True + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=400, + hop_length=160, + power=2.0, + ), + mel_scale_config=MelScaleConfig( + n_mels=128, + mel_scale="slaney", + norm="slaney", + ), + log_mode="log10", + global_log_mel_max=1.5, + ) + + +__all__ = ["VoxtralRealtimeAudioProcessor"] diff --git a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py index 0258c133243a..a5258961c7c3 100644 --- a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py @@ -17,8 +17,6 @@ class Wav2Vec2AudioProcessor(TorchAudioBackend): - model_input_names = ["input_values", "attention_mask"] - sample_rate = 16000 force_mono = True do_values_normalize = True diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index 4b8a7013f677..2536fc75f9d1 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -17,8 +17,6 @@ class WhisperAudioProcessor(TorchAudioBackend): - model_input_names = ["input_features"] - sample_rate = 16000 force_mono = True truncation = True @@ -32,6 +30,7 @@ class WhisperAudioProcessor(TorchAudioBackend): mel_scale_config=MelScaleConfig( n_mels=80, mel_scale="slaney", + norm="slaney", ), log_mode="log10", chunk_length=30, From 8a3066da93da01cf312be80aa1fe8e508245cb40 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Mon, 9 Mar 2026 16:37:58 +0100 Subject: [PATCH 06/28] update --- src/transformers/audio_processing_backends.py | 474 +++++------------- src/transformers/audio_processing_utils.py | 317 +++++++----- src/transformers/audio_utils.py | 39 +- ...rocessing_audio_spectrogram_transformer.py | 22 +- .../gemma3n/audio_processing_gemma3n.py | 160 +++--- .../parakeet/audio_processing_parakeet.py | 111 ++-- .../audio_processing_seamless_m4t.py | 4 +- ...processing_vibevoice_acoustic_tokenizer.py | 25 +- .../wav2vec2/audio_processing_wav2vec2.py | 14 +- .../whisper/audio_processing_whisper.py | 16 +- 10 files changed, 487 insertions(+), 695 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index c23b2272afc1..aa88845a5bb3 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -6,6 +6,7 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # + # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
@@ -16,9 +17,9 @@ import numpy as np from .audio_processing_utils import BaseAudioProcessor -from .audio_utils import SpectrogramConfig, NormalizationConfig +from .audio_utils import SpectrogramConfig, mel_filter_bank from .feature_extraction_utils import BatchFeature -from .utils import logging, is_torch_available +from .utils import PaddingStrategy, TensorType, is_torch_available, is_torch_tensor, logging, to_numpy logger = logging.get_logger(__name__) @@ -33,23 +34,29 @@ class NumpyAudioBackend(BaseAudioProcessor): @property def backend(self) -> str: - return "numpy" + return "numpy" - def process_audio(self, audio_el): + def _process_audio(self, audio_el): """ Process a single raw audio input into a np.ndarray. - Handles mono conversion (averaging channels) and ensures numpy format. + Handles mono conversion (averaging channels) and numpy conversion. + Closely mirrors the torch backend logic: expects channel-first. """ if not isinstance(audio_el, np.ndarray): audio_el = np.asarray(audio_el) - if self.force_mono: - audio_el = audio_el.mean(axis=1) if audio_el.ndim > 1 else audio_el - + if audio_el.ndim > 1: + # Expecting channel-first: (channels, samples) + if self.force_mono and audio_el.shape[0] > 1: + audio_el = audio_el.mean(axis=0) + elif audio_el.shape[0] == 1: + audio_el = np.squeeze(audio_el, axis=0) + else: + raise ValueError("Audio has more than one channel but force_mono is False") return audio_el - def pad(self, audio: np.ndarray, max_length: int) -> np.ndarray: + def _pad_single(self, audio: np.ndarray, max_length: int) -> np.ndarray: """Pad a single audio array to a target length using np.pad.""" current_length = audio.shape[-1] if current_length >= max_length: @@ -71,64 +78,19 @@ def pad(self, audio: np.ndarray, max_length: int) -> np.ndarray: return np.pad(audio, pad_width, mode="constant", constant_values=self.padding_value) - def pad_values( - self, - audio: list[np.ndarray], - *, - max_length: int | None = None, - truncation: bool = False, - pad_to_multiple_of: int | None = None, - ) -> list[np.ndarray]: - """Truncate and/or pad raw audio values (stage 3).""" - if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - if truncation: - if max_length is None: - raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.") - audio = [a[..., :max_length] for a in audio] - - if max_length is None: - max_length = max(a.shape[-1] for a in audio) - - if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - audio = [self.pad(a, max_length) for a in audio] - return audio - - def values_normalize( - self, - audio: list[np.ndarray], - *, - normalization_config: NormalizationConfig, - ) -> list[np.ndarray]: - """Normalize raw audio values (stage 4). 
Supports zero-mean-unit-var.""" - if normalization_config.method == "zero_mean_unit_var": - return [ - (a - np.mean(a)) / np.sqrt(np.var(a) + 1e-7) - for a in audio - ] - raise ValueError(f"Unknown normalization method: {normalization_config.method}") - - def extract_spectrogram( + def _extract_spectrogram( self, audio: list[np.ndarray], *, spectrogram_config: SpectrogramConfig, + **kwargs, ) -> list[np.ndarray]: - """Extract log-mel spectrogram features using the numpy spectrogram() function.""" + """Compute the (power) spectrogram via STFT using the numpy backend.""" from .audio_utils import spectrogram as compute_spectrogram, window_function - if not hasattr(self, "mel_filters"): - raise NotImplementedError( - f"{self.__class__.__name__} does not have `mel_filters`. " - "Either set `mel_filters` or override `extract_spectrogram`." - ) - stft_cfg = spectrogram_config.stft_config n_fft = stft_cfg.n_fft - hop_length = stft_cfg.hop_length if stft_cfg.hop_length is not None else n_fft // 4 + hop_length = stft_cfg.hop_length win_length = stft_cfg.win_length if stft_cfg.win_length is not None else n_fft # Build window — map torch names like "hann_window" to audio_utils names like "hann" @@ -151,7 +113,7 @@ def extract_spectrogram( pad_mode=stft_cfg.pad_mode, preemphasis=spectrogram_config.preemphasis, remove_dc_offset=spectrogram_config.remove_dc_offset, - mel_filters=self.mel_filters, + mel_filters=None, mel_floor=spectrogram_config.mel_floor, log_mel=spectrogram_config.log_mode if spectrogram_config.log_mode != "log10" else "log10", ) @@ -159,47 +121,35 @@ def extract_spectrogram( return features - def feature_normalize( + def _apply_mel_scale( self, features: list[np.ndarray], *, - feature_normalization_config: NormalizationConfig, - ) -> list[np.ndarray]: - """Normalize extracted features (stage 6). Supports zero-mean-unit-var.""" - if feature_normalization_config.method == "zero_mean_unit_var": - return [ - (f - np.mean(f)) / np.sqrt(np.var(f) + 1e-7) - for f in features - ] - raise ValueError(f"Unknown normalization method: {feature_normalization_config.method}") - - def pad_features( - self, - features: list[np.ndarray], - *, - max_length: int | None = None, - pad_to_multiple_of: int | None = None, + spectrogram_config: SpectrogramConfig, + **kwargs, ) -> list[np.ndarray]: - """Pad 2D features to a target length (stage 7).""" - if max_length is None: - max_length = max(f.shape[-1] for f in features) - - if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - padded = [] - for f in features: - current_length = f.shape[-1] - if current_length >= max_length: - padded.append(f[..., :max_length]) - else: - pad_length = max_length - current_length - if f.ndim == 2: - pad_width = [(0, 0), (0, pad_length)] - else: - pad_width = [(0, 0)] * (f.ndim - 1) + [(0, pad_length)] - padded.append(np.pad(f, pad_width, mode="constant", constant_values=0.0)) - return padded + """Apply mel filterbank to spectrogram features using the numpy backend.""" + if not hasattr(self, "mel_filters"): + raise ValueError( + f"{self.__class__.__name__} does not have `mel_filters`. " + "Either set `mel_filters` or override `_apply_mel_scale`." 
+ ) + + mel_filters = self.mel_filters + return [mel_filters.T @ spec for spec in features] + + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + return mel_filter_bank( + num_frequency_bins=1 + stft_cfg.n_fft // 2, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, + sampling_rate=self.sample_rate, + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + ) def _preprocess( self, @@ -209,80 +159,21 @@ def _preprocess( truncation, pad_to_multiple_of, return_tensors, - do_pad_values=None, - do_values_normalize=None, - normalization_config=None, spectrogram_config=None, - do_feature_normalize=None, - feature_normalization_config=None, - do_pad_features=None, + do_extract_spectrogram=None, + do_batch_spectrogram=True, **kwargs, ) -> BatchFeature: - """Preprocess using NumPy backend: 5-stage pipeline (stages 3-7).""" - # Default do_values_normalize to True if a normalization config is provided - if do_values_normalize is None: - do_values_normalize = normalization_config is not None - - # Determine normalize_before_pad for values - values_normalize_before_pad = ( - normalization_config.normalize_before_pad if normalization_config is not None else True - ) - feature_normalize_before_pad = ( - feature_normalization_config.normalize_before_pad if feature_normalization_config is not None else True - ) + # pad and truncate + audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - # --- Stages 3 & 4: Values padding and normalization --- - if values_normalize_before_pad: - # Stage 4 before 3: normalize then pad - if do_values_normalize and normalization_config is not None: - audio = self.values_normalize(audio, normalization_config=normalization_config) - if do_pad_values or padding: - audio = self.pad_values( - audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of - ) - else: - # Stage 3 before 4: pad then normalize - if do_pad_values or padding: - audio = self.pad_values( - audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of - ) - if do_values_normalize and normalization_config is not None: - audio = self.values_normalize(audio, normalization_config=normalization_config) - - # --- Stage 5: Feature extraction --- - if spectrogram_config is not None: - features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) - if self.transpose_features: - features = [f.T for f in features] + if do_extract_spectrogram: + feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, do_batch_spectrogram=do_batch_spectrogram) + output = BatchFeature({"audio_features": feature}, tensor_type=return_tensors) else: - features = audio - - # --- Stages 6 & 7: Feature normalization and padding --- - if feature_normalize_before_pad: - if do_feature_normalize and feature_normalization_config is not None: - features = self.feature_normalize( - features, feature_normalization_config=feature_normalization_config - ) - if do_pad_features: - features = self.pad_features( - features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of - ) - else: - if do_pad_features: - features = self.pad_features( - features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of - ) - if do_feature_normalize and feature_normalization_config is not None: - features = self.feature_normalize( 
- features, feature_normalization_config=feature_normalization_config - ) - - # Stack into batch - output_key = self.model_input_names[0] - stacked = np.stack(features, axis=0) if return_tensors else features - if self.add_channel_dim and isinstance(stacked, np.ndarray): - stacked = stacked[:, np.newaxis, :] - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + output = BatchFeature({"audio_values": audio}, tensor_type=return_tensors) + + return output class TorchAudioBackend(BaseAudioProcessor): @@ -292,7 +183,7 @@ class TorchAudioBackend(BaseAudioProcessor): def backend(self) -> str: return "torch" - def process_audio(self, audio_el): + def _process_audio(self, audio_el): """ Process a single raw audio input into a torch.Tensor. @@ -300,15 +191,21 @@ def process_audio(self, audio_el): """ import torch - if self.force_mono: - audio_el = audio_el.mean(axis=1) if audio_el.ndim > 1 else audio_el - if isinstance(audio_el, np.ndarray): audio_el = torch.from_numpy(audio_el) + if audio_el.ndim > 1: + # TODO: we would need to ensure somewhere audio is channel first + if self.force_mono and audio_el.shape[0] > 1: + audio_el = audio_el.mean(dim=0) + elif audio_el.shape[0] == 1: + audio_el = audio_el.squeeze(0) + else: + raise ValueError("Audio has more than one channel but force_mono is False") + return audio_el - def pad(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": + def _pad_single(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": """Pad a single audio tensor to a target length using torch.nn.functional.pad.""" import torch.nn.functional as F @@ -331,132 +228,85 @@ def pad(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": return F.pad(audio, pad_args, "constant", self.padding_value) - def pad_values( - self, - audio: list["torch.Tensor"], - *, - max_length: int | None = None, - truncation: bool = False, - pad_to_multiple_of: int | None = None, - ) -> list["torch.Tensor"]: - """Truncate and/or pad raw audio values (stage 3).""" - if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - if truncation: - if max_length is None: - raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.") - audio = [a[..., :max_length] for a in audio] - - if max_length is None: - max_length = max(a.shape[-1] for a in audio) - - if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - audio = [self.pad(a, max_length) for a in audio] - return audio - - def values_normalize( - self, - audio: list["torch.Tensor"], - *, - normalization_config: NormalizationConfig, - ) -> list["torch.Tensor"]: - """Normalize raw audio values (stage 4). 
Supports zero-mean-unit-var.""" - import torch - - if normalization_config.method == "zero_mean_unit_var": - return [ - (a - torch.mean(a)) / torch.sqrt(torch.var(a, correction=0) + 1e-7) - for a in audio - ] - raise ValueError(f"Unknown normalization method: {normalization_config.method}") - - def extract_spectrogram( + def _extract_spectrogram( self, audio: list["torch.Tensor"], *, spectrogram_config: SpectrogramConfig, + **kwargs, ) -> list["torch.Tensor"]: - """Extract log-mel spectrogram features using the provided config and mel_filters.""" + """Compute the (power) spectrogram via STFT using the torch backend.""" import torch - if not hasattr(self, "mel_filters"): - raise NotImplementedError( - f"{self.__class__.__name__} does not have `mel_filters`. " - "Either set `mel_filters` or override `extract_spectrogram`." - ) - stft_cfg = spectrogram_config.stft_config n_fft = stft_cfg.n_fft hop_length = stft_cfg.hop_length + win_length = stft_cfg.win_length if stft_cfg.win_length is not None else n_fft - waveform = torch.stack(audio, dim=0) + # Stack list into batch for efficient batched STFT if not already batched + if isinstance(audio, torch.Tensor) and audio.dim() == 2: + waveform = audio + else: + waveform = torch.stack(audio) # (batch, length) device = waveform.device - window = torch.hann_window(n_fft, device=device) - stft = torch.stft(waveform, n_fft, hop_length, window=window, return_complex=True) + if spectrogram_config.preemphasis is not None: + audio_ranges = kwargs.get("audio_ranges", None) + timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) + timemask = timemask < audio_ranges.unsqueeze(1) + waveform = waveform.masked_fill(~timemask, 0.0) + + window_fn = getattr(torch, stft_cfg.window_fn, torch.hann_window) + window = window_fn(win_length, periodic=stft_cfg.periodic, device=device) + + stft = torch.stft( + waveform, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + center=stft_cfg.center, + pad_mode=stft_cfg.pad_mode, + normalized=stft_cfg.normalized, + onesided=stft_cfg.onesided, + return_complex=True, + ) magnitudes = stft[..., :-1].abs() ** stft_cfg.power - mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) - mel_spec = mel_filters.T @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - if spectrogram_config.global_log_mel_max is not None: - max_val = torch.tensor( - spectrogram_config.global_log_mel_max, device=log_spec.device, dtype=log_spec.dtype - ) - elif waveform.dim() == 2: - max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] - else: - max_val = log_spec.max() - log_spec = torch.maximum(log_spec, max_val - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - - return [log_spec[i] for i in range(log_spec.shape[0])] + return [magnitudes[i] for i in range(magnitudes.shape[0])] - def feature_normalize( + def _apply_mel_scale( self, features: list["torch.Tensor"], *, - feature_normalization_config: NormalizationConfig, + spectrogram_config: SpectrogramConfig, + **kwargs, ) -> list["torch.Tensor"]: - """Normalize extracted features (stage 6). 
Supports zero-mean-unit-var.""" + """Apply mel filterbank to spectrogram features using the torch backend.""" import torch - if feature_normalization_config.method == "zero_mean_unit_var": - return [ - (f - torch.mean(f)) / torch.sqrt(torch.var(f, correction=0) + 1e-7) - for f in features - ] - raise ValueError(f"Unknown normalization method: {feature_normalization_config.method}") - - def pad_features( - self, - features: list["torch.Tensor"], - *, - max_length: int | None = None, - pad_to_multiple_of: int | None = None, - ) -> list["torch.Tensor"]: - """Pad 2D features to a target length (stage 7).""" - import torch.nn.functional as F - - if max_length is None: - max_length = max(f.shape[-1] for f in features) + if not hasattr(self, "mel_filters"): + raise ValueError( + f"{self.__class__.__name__} does not have `mel_filters`. " + "Either set `mel_filters` or override `_apply_mel_scale`." + ) - if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + device = features[0].device + mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + return [mel_filters.T @ spec for spec in features] - padded = [] - for f in features: - current_length = f.shape[-1] - if current_length >= max_length: - padded.append(f[..., :max_length]) - else: - pad_length = max_length - current_length - padded.append(F.pad(f, (0, pad_length), "constant", 0.0)) - return padded + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + return mel_filter_bank( + num_frequency_bins=1 + stft_cfg.n_fft // 2, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, + sampling_rate=self.sample_rate, + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + ) def _preprocess( self, @@ -466,79 +316,21 @@ def _preprocess( truncation, pad_to_multiple_of, return_tensors, - do_pad_values=None, - do_values_normalize=None, - normalization_config=None, spectrogram_config=None, - do_feature_normalize=None, - feature_normalization_config=None, - do_pad_features=None, + do_extract_spectrogram=None, + do_batch_spectrogram=True, **kwargs, ) -> BatchFeature: - """Preprocess using Torch backend: 5-stage pipeline (stages 3-7).""" import torch - # Default do_values_normalize to True if a normalization config is provided - if do_values_normalize is None: - do_values_normalize = normalization_config is not None + # pad and truncate + audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - # Determine normalize_before_pad for values - values_normalize_before_pad = ( - normalization_config.normalize_before_pad if normalization_config is not None else True - ) - feature_normalize_before_pad = ( - feature_normalization_config.normalize_before_pad if feature_normalization_config is not None else True - ) - - # --- Stages 3 & 4: Values padding and normalization --- - if values_normalize_before_pad: - # Stage 4 before 3: normalize then pad - if do_values_normalize and normalization_config is not None: - audio = self.values_normalize(audio, normalization_config=normalization_config) - if do_pad_values or padding: - audio = self.pad_values( - audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of - ) + if do_extract_spectrogram: + audio = torch.stack(audio) if do_batch_spectrogram else 
audio + feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, do_batch_spectrogram=do_batch_spectrogram) + output = BatchFeature({"audio_features": feature}, tensor_type=return_tensors) else: - # Stage 3 before 4: pad then normalize - if do_pad_values or padding: - audio = self.pad_values( - audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of - ) - if do_values_normalize and normalization_config is not None: - audio = self.values_normalize(audio, normalization_config=normalization_config) - - # --- Stage 5: Feature extraction --- - if spectrogram_config is not None: - features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) - if self.transpose_features: - features = [f.permute(*reversed(range(f.dim()))) for f in features] - else: - features = audio - - # --- Stages 6 & 7: Feature normalization and padding --- - if feature_normalize_before_pad: - if do_feature_normalize and feature_normalization_config is not None: - features = self.feature_normalize( - features, feature_normalization_config=feature_normalization_config - ) - if do_pad_features: - features = self.pad_features( - features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of - ) - else: - if do_pad_features: - features = self.pad_features( - features, max_length=max_length, pad_to_multiple_of=pad_to_multiple_of - ) - if do_feature_normalize and feature_normalization_config is not None: - features = self.feature_normalize( - features, feature_normalization_config=feature_normalization_config - ) - - # Stack into batch - output_key = self.model_input_names[0] - stacked = torch.stack(features, dim=0) if return_tensors else features - if self.add_channel_dim and isinstance(stacked, torch.Tensor): - stacked = stacked.unsqueeze(1) - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) \ No newline at end of file + output = BatchFeature({"audio_values": audio}, tensor_type=return_tensors) + + return output \ No newline at end of file diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index cd44d9b6e0bb..8226e573cddd 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from dataclasses import fields, replace from typing import Unpack from huggingface_hub.dataclasses import validate_typed_dict from .audio_processing_base import AudioProcessingMixin -from .audio_utils import AudioInput, SpectrogramConfig, NormalizationConfig, make_list_of_audio, mel_filter_bank +from .audio_utils import AudioInput, SpectrogramConfig, make_list_of_audio, mel_filter_bank from .feature_extraction_utils import BatchFeature from .processing_utils import AudioKwargs -from .utils import TensorType, logging +from .utils import PaddingStrategy, TensorType, logging logger = logging.get_logger(__name__) @@ -31,10 +32,9 @@ class AudioProcessingKwargs(AudioKwargs, total=False): do_pad_values: bool | None do_values_normalize: bool | None - normalization_config: dict | NormalizationConfig | None spectrogram_config: dict | SpectrogramConfig | None + do_extract_spectrogram: bool | None do_feature_normalize: bool | None - feature_normalization_config: dict | NormalizationConfig | None do_pad_features: bool | None do_resample: bool | None @@ -43,11 +43,13 @@ class BaseAudioProcessor(AudioProcessingMixin): model_input_names = ["audio"] valid_kwargs = AudioProcessingKwargs unused_kwargs = None + feature_size = 1 padding = True padding_side = "right" padding_value = 0.0 max_length = None truncation = None + return_attention_mask = True sample_rate: int = None force_mono: bool = None @@ -55,10 +57,11 @@ class BaseAudioProcessor(AudioProcessingMixin): # Pipeline stage defaults do_pad_values = None do_values_normalize = None - normalization_config = None + normalize_before_pad = True spectrogram_config = None + do_extract_spectrogram = None do_feature_normalize = None - feature_normalization_config = None + feature_normalize_before_pad = True do_pad_features = None do_resample = False add_channel_dim = False @@ -95,126 +98,231 @@ def __init__( for key, value in attributes.items(): setattr(self, key, value) - # Derive max_length and mel_filters from spectrogram_config - if self.spectrogram_config is not None: - sc = self.spectrogram_config + # Derive mel_filters from spectrogram_config if mel_scale_config is set + # TODO: maybe the mel spectrogram initialization should be lazy? + if self.spectrogram_config is not None and self.spectrogram_config.mel_scale_config is not None: if not hasattr(self, "mel_filters"): - self.mel_filters = mel_filter_bank( - num_frequency_bins=1 + sc.stft_config.n_fft // 2, - num_mel_filters=sc.mel_scale_config.n_mels, - min_frequency=sc.mel_scale_config.f_min, - max_frequency=sc.mel_scale_config.f_max if sc.mel_scale_config.f_max is not None else self.sample_rate / 2, - sampling_rate=self.sample_rate, - norm=sc.mel_scale_config.norm, - mel_scale=sc.mel_scale_config.mel_scale, - triangularize_in_mel_space=sc.mel_scale_config.triangularize_in_mel_space, - ) + self.mel_filters = self._mel_filter_bank(self.spectrogram_config) def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingKwargs]) -> BatchFeature: return self.preprocess(audio, *args, **kwargs) - def process_audio(self, *args, **kwargs): + def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingKwargs]) -> BatchFeature: """ - Process a single raw audio input into the backend's working format. - - Implemented by backend subclasses (e.g., `TorchAudioBackend`). Converts a raw input - (NumPy array) to the backend's internal format (e.g., `torch.Tensor`), handles - mono conversion if needed. + Preprocess an audio or a batch of audio. 
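+
+        Example (illustrative sketch, not part of the public docs yet; any concrete backend subclass
+        works, here the `Wav2Vec2AudioProcessor` introduced in this refactor, whose output key is
+        `audio_values`):
+
+        ```python
+        import numpy as np
+
+        from transformers.models.wav2vec2.audio_processing_wav2vec2 import Wav2Vec2AudioProcessor
+
+        processor = Wav2Vec2AudioProcessor()
+        waveform = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz
+        batch = processor.preprocess(waveform, sample_rate=16000)  # equivalent to processor(waveform, sample_rate=16000)
+        print(batch["audio_values"][0].shape)
+        ```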
""" - raise NotImplementedError + # Perform type validation on received kwargs + validate_typed_dict(self.valid_kwargs, kwargs) + + # Set default kwargs from self. + for kwarg_name in self._valid_kwargs_names: + kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + + # Standardize kwargs (coerce dicts to config dataclasses) + kwargs = self._standardize_kwargs(**kwargs) + + # Validate kwargs + self._validate_preprocess_kwargs(**kwargs) + + return self._preprocess_audio_like_inputs(audio, *args, **kwargs) + + def _preprocess_audio_like_inputs( + self, + audio: AudioInput, + *args, + sample_rate: int | None = None, + **kwargs: Unpack[AudioProcessingKwargs], + ) -> BatchFeature: + audio = self._prepare_audio_like_inputs(audio=audio, sample_rate=sample_rate) + return self._preprocess(audio, *args, **kwargs) def _preprocess(self, *args, **kwargs): + raise NotImplementedError + + def _prepare_audio_like_inputs(self, audio: AudioInput, *args, sample_rate: int | None = None, **kwargs) -> list: """ - Perform the actual batch audio preprocessing pipeline (stages 3-7). + Prepare audio-like inputs for processing by structuring and then converting each + audio item via `process_audio`. - Implemented by backend subclasses (e.g., `TorchAudioBackend`). Receives a list of - already-prepared audio tensors and applies the configured preprocessing operations. - Returns a `BatchFeature` with the processed audio values. + Analogous to `_prepare_image_like_inputs` in the image processing pipeline. """ - raise NotImplementedError + audio = self._prepare_audio_structure(audio, sample_rate=sample_rate) + audio = [self.process_audio(audio_el) for audio_el in audio] + return audio - def pad(self, *args, **kwargs): + def _prepare_audio_structure(self, audio: AudioInput, sample_rate: int | None = None) -> list: """ - Pad a single audio tensor to a target length. + Prepare the audio structure for processing: fetch URL inputs, validate sample rate, + and flatten into a list of audio arrays. - Implemented by backend subclasses (e.g., `TorchAudioBackend`). + Analogous to `_prepare_images_structure` in the image processing pipeline. """ - raise NotImplementedError + is_url_input = isinstance(audio, str) or ( + isinstance(audio, (list, tuple)) and all(isinstance(el, str) for el in audio) + ) - def pad_values(self, *args, **kwargs): + if is_url_input: + # URL inputs: load directly at the correct sample rate + audio = self.fetch_audio(audio) + else: + # Array inputs: validate that the user-provided sample rate matches the model's + if sample_rate is not None: + if sample_rate != self.sample_rate: + raise ValueError( + f"The model corresponding to this audio processor: {self.__class__.__name__} was trained using a" + f" sample rate of {self.sample_rate}. Please make sure that the provided `audio` input" + f" was sampled with {self.sample_rate} and not {sample_rate}." + ) + else: + logger.warning( + f"It is strongly recommended to pass the `sample_rate` argument to `{self.__class__.__name__}()`. " + "Failing to do so can result in silent errors that might be hard to debug." + ) + + audio = make_list_of_audio(audio) + return audio + + def _process_audio(self, *args, **kwargs): """ - Pad raw audio values to a target length (pipeline stage 3). + Process a single raw audio input into the backend's working format. - Implemented by backend subclasses. + Implemented by backend subclasses (e.g., `TorchAudioBackend`). 
Converts a raw input
+        (NumPy array) to the backend's internal format (e.g., `torch.Tensor`), handles
+        mono conversion if needed.
         """
         raise NotImplementedError
 
-    def values_normalize(self, *args, **kwargs):
+    def process_audio(self, *args, **kwargs):
+        return self._process_audio(*args, **kwargs)
+
+    def pad(
+        self,
+        audio: AudioInput,  # TODO: this type hint makes it unclear that an iterable of audio arrays is expected
+        padding: bool | str | PaddingStrategy = True,
+        max_length: int | None = None,
+        truncation: bool = False,
+        pad_to_multiple_of: int | None = None,
+    ):
+        padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
+
+        if truncation:
+            if max_length is None:
+                # TODO: maybe this check should happen in the _validate_preprocess_kwargs method
+                raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.")
+            trunc_length = max_length
+            if pad_to_multiple_of is not None and (trunc_length % pad_to_multiple_of != 0):
+                trunc_length = ((trunc_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+            audio = [self._truncate_single(audio_el, max_length=trunc_length) for audio_el in audio]
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = max(audio_el.shape[-1] for audio_el in audio)
+            padding_strategy = PaddingStrategy.MAX_LENGTH
+
+        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+        if padding_strategy != PaddingStrategy.DO_NOT_PAD:
+            audio = [self._pad_single(audio_el, max_length=max_length) for audio_el in audio]
+
+        return audio
+
+    def _truncate_single(self, audio_el, max_length: int):
+        """Truncate a single audio element to max_length along the time axis."""
+        if audio_el.shape[-1] > max_length:
+            return audio_el[..., :max_length]
+        return audio_el
+
+    def _pad_single(self, audio, max_length: int) -> AudioInput:
         """
-        Pad raw audio values to a target length (pipeline stage 3).
+        Pad a single input (on the left or right) up to a predefined length or the max length in the batch.
 
         Implemented by backend subclasses.
         """
         raise NotImplementedError
 
-    def values_normalize(self, *args, **kwargs):
+    def extract_spectrogram(self, audio, *, do_batch_spectrogram: bool = True, spectrogram_config: SpectrogramConfig | None = None, **kwargs):
         """
-        Normalize raw audio values (pipeline stage 4).
+        Both the numpy and torch backends implement this method in a batched or sequential manner.
+        It is batched by default, but can be set to run sequentially.
+        This extracts a plain spectrogram, or a mel spectrogram if a mel config is provided.
 
-        Implemented by backend subclasses.
+        Any extra kwargs whose names match ``SpectrogramConfig`` fields will
+        override the corresponding value on the config for this call.
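+
+        Example (illustrative sketch; `processor` is assumed to be a backend subclass that defines a
+        `spectrogram_config`, and `audio` a list of waveforms already prepared by `process_audio`):
+
+        ```python
+        # Sequential (per-waveform) extraction instead of one batched STFT, plus a per-call
+        # override of the `mel_floor` field of `SpectrogramConfig` for this invocation only.
+        features = processor.extract_spectrogram(
+            audio,
+            do_batch_spectrogram=False,
+            mel_floor=1e-6,
+        )
+        ```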
""" - raise NotImplementedError + if spectrogram_config is None: + spectrogram_config = self.spectrogram_config + + config_field_names = {f.name for f in fields(SpectrogramConfig)} + overrides = {k: kwargs.pop(k) for k in list(kwargs) if k in config_field_names} + if overrides: + spectrogram_config = replace(spectrogram_config, **overrides) + + if do_batch_spectrogram: + features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) + if spectrogram_config.mel_scale_config is not None: + features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs) + else: + features = [self._extract_spectrogram(audio_el, spectrogram_config=spectrogram_config, **kwargs) for audio_el in audio] + if spectrogram_config.mel_scale_config is not None: + features = [self._apply_mel_scale(feature_el, spectrogram_config=spectrogram_config, **kwargs) for feature_el in features] + return features - def feature_normalize(self, *args, **kwargs): + def _extract_spectrogram(self, *args, **kwargs): """ - Normalize extracted features (pipeline stage 6). + Compute the (power) spectrogram via STFT. - Implemented by backend subclasses. + Implemented by backend subclasses (e.g., ``TorchAudioBackend``). """ raise NotImplementedError - def pad_features(self, *args, **kwargs): + def _apply_mel_scale(self, *args, **kwargs): """ - Pad extracted features to a target length (pipeline stage 7). + Apply mel filterbank to a spectrogram. - Implemented by backend subclasses. + Implemented by backend subclasses (e.g., ``TorchAudioBackend``). """ raise NotImplementedError - def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingKwargs]) -> BatchFeature: - """ - Preprocess an audio or a batch of audio. - """ - # Perform type validation on received kwargs - validate_typed_dict(self.valid_kwargs, kwargs) + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): + raise NotImplementedError - # Set default kwargs from self. - for kwarg_name in self._valid_kwargs_names: - kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None)) + def _get_padding_strategies(self, padding=False, max_length=None): + """Find the correct padding strategy.""" + if padding is not False: + if padding is True: + padding_strategy = PaddingStrategy.LONGEST + elif not isinstance(padding, PaddingStrategy): + padding_strategy = PaddingStrategy(padding) + elif isinstance(padding, PaddingStrategy): + padding_strategy = padding + else: + padding_strategy = PaddingStrategy.DO_NOT_PAD - # Standardize kwargs (coerce dicts to config dataclasses) - kwargs = self._standardize_kwargs(**kwargs) + if max_length is None: + if padding_strategy == PaddingStrategy.MAX_LENGTH: + raise ValueError( + f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined" + ) - # Validate kwargs - self._validate_preprocess_kwargs(**kwargs) + if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None): + raise ValueError( + "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use" + " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`." 
+ ) - return self._preprocess_audio_like_inputs(audio, *args, **kwargs) + return padding_strategy def _standardize_kwargs( self, **kwargs, ) -> dict: """Coerce dict configs to their dataclass form.""" - if isinstance(kwargs.get("normalization_config"), dict): - kwargs["normalization_config"] = NormalizationConfig.from_dict(kwargs["normalization_config"]) if isinstance(kwargs.get("spectrogram_config"), dict): kwargs["spectrogram_config"] = SpectrogramConfig.from_dict( kwargs["spectrogram_config"] ) - if isinstance(kwargs.get("feature_normalization_config"), dict): - kwargs["feature_normalization_config"] = NormalizationConfig.from_dict( - kwargs["feature_normalization_config"] - ) + if kwargs.get("spectrogram_config") is not None and kwargs.get("do_extract_spectrogram") is None: + kwargs["do_extract_spectrogram"] = True return kwargs def _validate_preprocess_kwargs( @@ -224,83 +332,18 @@ def _validate_preprocess_kwargs( truncation: bool | None = None, pad_to_multiple_of: int | None = None, return_tensors: str | TensorType | None = None, - do_values_normalize: bool | None = None, - normalization_config: NormalizationConfig | None = None, - do_feature_normalize: bool | None = None, - feature_normalization_config: NormalizationConfig | None = None, **kwargs, ): """Validate the kwargs for the preprocess method.""" - if do_values_normalize and normalization_config is None: - raise ValueError( - "`do_values_normalize=True` requires `normalization_config` to be set." - ) - if do_feature_normalize and feature_normalization_config is None: - raise ValueError( - "`do_feature_normalize=True` requires `feature_normalization_config` to be set." - ) if truncation and max_length is None: raise ValueError( "When setting `truncation=True`, make sure that `max_length` is defined." - ) - - def _preprocess_audio_like_inputs( - self, - audio: AudioInput, - *args, - sample_rate: int | None = None, - **kwargs: Unpack[AudioProcessingKwargs], - ) -> BatchFeature: - audio = self._prepare_audio_like_inputs(audio=audio, sample_rate=sample_rate) - return self._preprocess(audio, *args, **kwargs) - - def _prepare_audio_structure(self, audio: AudioInput, sample_rate: int | None = None) -> list: - """ - Prepare the audio structure for processing: fetch URL inputs, validate sample rate, - and flatten into a list of audio arrays. - - Analogous to `_prepare_images_structure` in the image processing pipeline. - """ - is_url_input = isinstance(audio, str) or ( - isinstance(audio, (list, tuple)) and all(isinstance(el, str) for el in audio) - ) - - if is_url_input: - # URL inputs: load directly at the correct sample rate - audio = self.fetch_audio(audio) - else: - # Array inputs: validate that the user-provided sample rate matches the model's - if sample_rate is not None: - if sample_rate != self.sample_rate: - raise ValueError( - f"The model corresponding to this audio processor: {self.__class__.__name__} was trained using a" - f" sample rate of {self.sample_rate}. Please make sure that the provided `audio` input" - f" was sampled with {self.sample_rate} and not {sample_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sample_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." 
- ) - - audio = make_list_of_audio(audio) - return audio - - def _prepare_audio_like_inputs(self, audio: AudioInput, *args, sample_rate: int | None = None, **kwargs) -> list: - """ - Prepare audio-like inputs for processing by structuring and then converting each - audio item via `process_audio`. - - Analogous to `_prepare_image_like_inputs` in the image processing pipeline. - """ - audio = self._prepare_audio_structure(audio, sample_rate=sample_rate) - audio = [self.process_audio(audio_el) for audio_el in audio] - return audio + ) def to_dict(self): output = super().to_dict() # Serialize config dataclasses to plain dicts for JSON persistence - for key in ("normalization_config", "spectrogram_config", "feature_normalization_config"): + for key in ("spectrogram_config",): if key in output and hasattr(output[key], "to_dict"): output[key] = output[key].to_dict() diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 5275e8a5b3e7..e71b87f33deb 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -117,7 +117,7 @@ class SpectrogramConfig: """Configuration for spectrogram extraction, composed of STFT and mel scale sub-configs.""" stft_config: StftConfig = field(default_factory=StftConfig) - mel_scale_config: MelScaleConfig = field(default_factory=MelScaleConfig) + mel_scale_config: MelScaleConfig | None = None log_mode: str = "log10" chunk_length: int | None = None global_log_mel_max: float | None = None @@ -155,7 +155,7 @@ def to_dict(self) -> dict: @classmethod def from_dict(cls, d: dict) -> "SpectrogramConfig": stft_config = StftConfig.from_dict(d["stft_config"]) if "stft_config" in d else StftConfig() - mel_scale_config = MelScaleConfig.from_dict(d["mel_scale_config"]) if "mel_scale_config" in d else MelScaleConfig() + mel_scale_config = MelScaleConfig.from_dict(d["mel_scale_config"]) if "mel_scale_config" in d else None return cls( stft_config=stft_config, mel_scale_config=mel_scale_config, @@ -169,41 +169,6 @@ def from_dict(cls, d: dict) -> "SpectrogramConfig": ) -@dataclass(frozen=True) -class NormalizationConfig: - """Hashable dictionary to store audio normalization configuration.""" - - method: str = "zero_mean_unit_var" - normalize_before_pad: bool = True - - def __getitem__(self, key): - if hasattr(self, key): - return getattr(self, key) - raise KeyError(f"Key {key} not found in NormalizationConfig.") - - def __iter__(self): - for f in fields(self): - val = getattr(self, f.name) - if val is not None: - yield f.name, val - - def __eq__(self, other): - if isinstance(other, dict): - return dict(self) == other - if isinstance(other, NormalizationConfig): - return tuple(getattr(self, f.name) for f in fields(self)) == tuple( - getattr(other, f.name) for f in fields(self) - ) - return NotImplemented - - def to_dict(self) -> dict: - return dict(self) - - @classmethod - def from_dict(cls, d: dict) -> "NormalizationConfig": - valid_keys = {f.name for f in fields(cls)} - return cls(**{k: v for k, v in d.items() if k in valid_keys}) - def load_audio(audio: str | np.ndarray, sampling_rate=16000, timeout=None) -> np.ndarray: """ diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py index f4bc248465b2..85ba48c1d06f 100644 --- a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py +++ 
b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py @@ -22,8 +22,10 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend): sample_rate = 16000 force_mono = True + do_extract_spectrogram = True + max_length_frames = 1024 - transpose_features = True + do_normalize = True # AudioSet normalization constants ast_mean = -4.2677393 @@ -52,11 +54,10 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend): mel_floor=1.192092955078125e-07, ) - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Extract spectrogram via generic config-based API - features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + def extract_spectrogram(self, audio, **kwargs): + features = super().extract_spectrogram(audio, **kwargs) - # Generic extract_spectrogram returns (n_mels, frames); transpose to (frames, n_mels) + # (n_mels, frames) -> (frames, n_mels) features = [f.T for f in features] # Pad or truncate to max_length_frames @@ -71,12 +72,13 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of padded.append(fbank) # Normalize with AudioSet stats - normalized = [(f - self.ast_mean) / (self.ast_std * 2) for f in padded] + return [(f - self.ast_mean) / (self.ast_std * 2) for f in padded] + + def _preprocess(self, audio, **kwargs): + output = super()._preprocess(audio, **kwargs) + # TODO: it is wrongly named input_values in the original feature extractor + return BatchFeature({"audio_values": output["audio_features"]}) - # Stack into batch - output_key = self.model_input_names[0] - stacked = np.stack(normalized, axis=0) if return_tensors else normalized - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) __all__ = ["AudioSpectrogramTransformerAudioProcessor"] diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py index 27d9a0898f3a..87dd86f1cec4 100644 --- a/src/transformers/models/gemma3n/audio_processing_gemma3n.py +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -13,11 +13,11 @@ # limitations under the License. 
import math -from collections.abc import Sequence import numpy as np from ...audio_processing_backends import NumpyAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature @@ -56,112 +56,136 @@ def _unfold(array, dimension, size, step): class Gemma3nAudioProcessor(NumpyAudioBackend): sample_rate = 16000 force_mono = True - frame_length = 512 # 32ms at 16kHz - hop_length = 160 # 10ms at 16kHz - n_mels = 128 - min_frequency = 125.0 - max_frequency = 7600.0 - preemphasis_coeff = 0.97 - preemphasis_htk_flavor = True - fft_overdrive = True - mel_floor = 1e-5 max_length = 480000 # 30 seconds truncation = True pad_to_multiple_of = 128 + preemphasis_htk_flavor = True + + # n_fft = 1024 (512 frame_length → next power of 2 → 512 → ×2 fft_overdrive) + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=1024, + win_length=512, + hop_length=160, + power=1.0, + center=False, + ), + mel_scale_config=MelScaleConfig( + n_mels=128, + f_min=125.0, + f_max=7600.0, + mel_scale="htk", + ), + mel_floor=1e-5, + log_mode="log", + preemphasis=0.97, + ) def __init__(self, per_bin_mean=None, per_bin_stddev=None, **kwargs): super().__init__(**kwargs) - fft_length = 2 ** math.ceil(math.log2(self.frame_length)) - if self.fft_overdrive: - fft_length *= 2 - self.fft_length = fft_length - - hann_arange = np.arange(self.frame_length, dtype=np.float32) - self.window = (0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length))).astype(np.float32) - - self.mel_filters = _create_fb_matrix( - n_freqs=self.fft_length // 2 + 1, - f_min=self.min_frequency, - f_max=self.max_frequency, - n_mels=self.n_mels, - sample_rate=self.sample_rate, - fft_length=self.fft_length, - norm=None, - ) + # Pre-compute window from stft_config + win_length = self.spectrogram_config.stft_config.win_length + hann_arange = np.arange(win_length, dtype=np.float32) + self.window = (0.5 * (1 - np.cos(2 * np.pi * hann_arange / win_length))).astype(np.float32) + n_mels = self.spectrogram_config.mel_scale_config.n_mels if per_bin_mean is not None: - self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, self.n_mels) + self.per_bin_mean = np.array(per_bin_mean).reshape(1, n_mels) else: self.per_bin_mean = None if per_bin_stddev is not None: - self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, self.n_mels) + self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, n_mels) else: self.per_bin_stddev = None - def _extract_spectrogram(self, waveform): - if waveform.ndim == 1: - waveform = np.expand_dims(waveform, axis=0) - - frame_size_for_unfold = self.frame_length + 1 - frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length) + def _mel_filter_bank(self, spectrogram_config): + """Custom HTK-style mel filterbank matching the original Gemma3n FE.""" + sc = spectrogram_config + msc = sc.mel_scale_config + return _create_fb_matrix( + n_freqs=sc.stft_config.n_fft // 2 + 1, + f_min=msc.f_min, + f_max=msc.f_max, + n_mels=msc.n_mels, + sample_rate=self.sample_rate, + fft_length=sc.stft_config.n_fft, + ) - if self.preemphasis_coeff > 0.0: - if self.preemphasis_htk_flavor: - first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis_coeff) - rest_in_frame = ( - frames_to_process[..., 1:-1] - self.preemphasis_coeff * frames_to_process[..., :-2] - ) - frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1) + def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): + """Custom 
STFT with HTK-flavor preemphasis.""" + stft_cfg = spectrogram_config.stft_config + preemphasis = spectrogram_config.preemphasis + + features = [] + for waveform in audio: + if waveform.ndim == 1: + waveform = np.expand_dims(waveform, axis=0) + + frame_size_for_unfold = stft_cfg.win_length + 1 + frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=stft_cfg.hop_length) + + if preemphasis is not None and preemphasis > 0.0: + if self.preemphasis_htk_flavor: + first_in_frame = frames_to_process[..., :1] * (1.0 - preemphasis) + rest_in_frame = ( + frames_to_process[..., 1:-1] - preemphasis * frames_to_process[..., :-2] + ) + frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1) + else: + frames = frames_to_process[..., 1:] - preemphasis * frames_to_process[..., :-1] else: - frames = frames_to_process[..., 1:] - self.preemphasis_coeff * frames_to_process[..., :-1] - else: - frames = frames_to_process[..., :-1] + frames = frames_to_process[..., :-1] - frames = frames * self.window - stft = np.fft.rfft(frames, n=self.fft_length, axis=-1) - magnitude_spec = np.abs(stft) + frames = frames * self.window + stft = np.fft.rfft(frames, n=stft_cfg.n_fft, axis=-1) + magnitude_spec = np.abs(stft) + features.append(magnitude_spec.squeeze(0)) # (frames, n_freqs) - mel_spec = np.matmul(magnitude_spec, self.mel_filters) - log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor)) + return features - if self.per_bin_mean is not None: - log_mel_spec = log_mel_spec - self.per_bin_mean - if self.per_bin_stddev is not None: - log_mel_spec = log_mel_spec / self.per_bin_stddev + def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): + """Apply mel filterbank, log compression, and per-bin normalization.""" + result = [] + for mag_spec in features: + mel_spec = np.matmul(mag_spec, self.mel_filters) + log_mel_spec = np.log(np.maximum(mel_spec, spectrogram_config.mel_floor)) - return log_mel_spec.squeeze(0) # (frames, n_mels) + if self.per_bin_mean is not None: + log_mel_spec = log_mel_spec - self.per_bin_mean + if self.per_bin_stddev is not None: + log_mel_spec = log_mel_spec / self.per_bin_stddev - def extract_spectrogram(self, audio, *, spectrogram_config): - return [self._extract_spectrogram(waveform) for waveform in audio] + result.append(log_mel_spec.astype(np.float32)) + return result - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Use class defaults for max_length, truncation, pad_to_multiple_of if not overridden + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, + spectrogram_config=None, do_extract_spectrogram=None, **kwargs): if max_length is None: max_length = self.max_length if truncation is None: truncation = self.truncation if pad_to_multiple_of is None: pad_to_multiple_of = self.pad_to_multiple_of + if spectrogram_config is None: + spectrogram_config = self.spectrogram_config - # Truncate first (separate from padding, matching FE behavior) + # Truncate then pad to longest in batch (matching FE "longest" padding strategy) if truncation and max_length is not None: audio = [a[..., :max_length] for a in audio] - # Pad to longest in batch (matching FE "longest" padding strategy) pad_length = max(a.shape[-1] for a in audio) if pad_to_multiple_of is not None and (pad_length % pad_to_multiple_of != 0): pad_length = ((pad_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - audio = [self.pad(a, pad_length) for a in audio] - - # Extract 
spectrogram - features = self.extract_spectrogram(audio, spectrogram_config=None) + audio = [self._pad_single(a, pad_length) for a in audio] - # Cast to float32 to match FE output - features = [f.astype(np.float32) for f in features] + # Extract spectrogram via orchestrator (_extract_spectrogram + _apply_mel_scale) + if do_extract_spectrogram is not False and spectrogram_config is not None: + features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) + else: + features = audio - # Stack and return output_key = self.model_input_names[0] stacked = np.stack(features, axis=0) return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) diff --git a/src/transformers/models/parakeet/audio_processing_parakeet.py b/src/transformers/models/parakeet/audio_processing_parakeet.py index d83bce115644..c7bcd8f3cd05 100644 --- a/src/transformers/models/parakeet/audio_processing_parakeet.py +++ b/src/transformers/models/parakeet/audio_processing_parakeet.py @@ -16,6 +16,7 @@ import torch from ...audio_processing_backends import TorchAudioBackend +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature LOG_ZERO_GUARD_VALUE = 2**-24 @@ -31,89 +32,35 @@ class ParakeetAudioProcessor(TorchAudioBackend): win_length = 400 n_mels = 80 - def __init__(self, **kwargs): - super().__init__(**kwargs) - # Use librosa for mel filters to match the FeatureExtractor exactly - # (mel_filter_bank uses float64 internally, causing numerical differences) - mel_filters = librosa.filters.mel( - sr=self.sample_rate, - n_fft=self.n_fft, - n_mels=self.n_mels, - fmin=0.0, - fmax=self.sample_rate / 2, - norm="slaney", - ) - self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32) - - def _torch_extract_fbank_features(self, waveform, device="cpu"): - """Extract log-mel spectrogram features, matching the FE implementation.""" - window = torch.hann_window(self.win_length, periodic=False, device=device) - stft = torch.stft( - waveform, - self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=window, - return_complex=True, + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=512, + hop_length=160, + win_length=400, + window_fn="hann_window", + periodic=False, pad_mode="constant", + power=2.0, + ), + mel_scale_config=MelScaleConfig( + n_mels=80, + f_min=0.0, + norm="slaney", + ), + preemphasis=0.97, + ) + + def _mel_filter_bank(self, spectrogram_config): + """Use librosa for mel filters to match the FeatureExtractor exactly + (mel_filter_bank uses float64 internally, causing numerical differences).""" + msc = spectrogram_config.mel_scale_config + return librosa.filters.mel( + sr=self.sample_rate, + n_fft=spectrogram_config.stft_config.n_fft, + n_mels=msc.n_mels, + fmin=msc.f_min, + fmax=msc.f_max if msc.f_max is not None else self.sample_rate / 2, + norm=msc.norm, ) - # Match original implementation: view_as_real then sqrt(sum of squares) - magnitudes = torch.view_as_real(stft) - magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1)) - magnitudes = magnitudes.pow(2) - - # Log mel spectrogram - mel_filters = self.mel_filters.to(device) - mel_spec = mel_filters @ magnitudes - mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE) - - # (batch, n_mels, frames) -> (batch, frames, n_mels) - mel_spec = mel_spec.permute(0, 2, 1) - return mel_spec - - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - device = "cpu" - - # Record 
original audio lengths before padding - audio_lengths = torch.tensor([a.shape[-1] for a in audio]) - - # Pad values to longest - audio = self.pad_values(audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of) - - # Stack into batch tensor - waveform = torch.stack(audio, dim=0).to(torch.float32) - - # Preemphasis (mask-aware, matching FE) - if self.preemphasis is not None: - timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) < audio_lengths.unsqueeze(1) - waveform = torch.cat( - [waveform[:, :1], waveform[:, 1:] - self.preemphasis * waveform[:, :-1]], dim=1 - ) - waveform = waveform.masked_fill(~timemask, 0.0) - - # Extract log-mel spectrogram - input_features = self._torch_extract_fbank_features(waveform, device) - - # Compute feature lengths (matching FE formula) - features_lengths = torch.floor_divide( - audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length - ) - - # Build attention mask over feature frames - attention_mask = torch.arange(input_features.shape[1], device=device)[None, :] < features_lengths[:, None] - - # Mask-aware normalization (matching FE exactly) - mask = attention_mask.unsqueeze(-1) - input_features_masked = input_features * mask - mean = input_features_masked.sum(dim=1) / features_lengths.unsqueeze(-1) - mean = mean.unsqueeze(1) - variance = ((input_features_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1) - std = torch.sqrt(variance).unsqueeze(1) - input_features = (input_features - mean) / (std + EPSILON) - input_features *= mask - - output_key = self.model_input_names[0] - return BatchFeature(data={output_key: input_features}, tensor_type=return_tensors) - __all__ = ["ParakeetAudioProcessor"] diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py index 3a3ef16b5cbd..127f595cf3e8 100644 --- a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py @@ -48,7 +48,7 @@ class SeamlessM4tAudioProcessor(NumpyAudioBackend): waveform_scale=32768.0, ) - def feature_normalize(self, features, *, feature_normalization_config): + def feature_normalize(self, features): # Per-mel-bin normalization with ddof=1 for variance normalized = [] for f in features: @@ -65,7 +65,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of features = [f.T for f in features] # Per-mel-bin normalization - features = self.feature_normalize(features, feature_normalization_config=None) + features = self.feature_normalize(features) # Pad features to longest (pad_to_multiple_of=2 for stride) max_len = max(f.shape[0] for f in features) diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py index 1d342166a432..28049d0eccc3 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py @@ -13,7 +13,6 @@ # limitations under the License. 
from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import NormalizationConfig class VibevoiceAcousticTokenizerAudioProcessor(TorchAudioBackend): @@ -21,27 +20,25 @@ class VibevoiceAcousticTokenizerAudioProcessor(TorchAudioBackend): force_mono = True add_channel_dim = True do_values_normalize = True - normalization_config = NormalizationConfig(method="rms_normalize", normalize_before_pad=True) + normalize_before_pad = True def __init__(self, target_dB_FS=-25, eps=1e-6, **kwargs): self.target_dB_FS = target_dB_FS self.eps = eps super().__init__(**kwargs) - def values_normalize(self, audio, *, normalization_config): + def values_normalize(self, audio): import torch - if normalization_config.method == "rms_normalize": - normalized = [] - for a in audio: - rms = torch.sqrt(torch.mean(a**2)) - a = a * (10 ** (self.target_dB_FS / 20) / (rms + self.eps)) - max_val = torch.max(torch.abs(a)) - if max_val > 1.0: - a = a / (max_val + self.eps) - normalized.append(a) - return normalized - return super().values_normalize(audio, normalization_config=normalization_config) + normalized = [] + for a in audio: + rms = torch.sqrt(torch.mean(a**2)) + a = a * (10 ** (self.target_dB_FS / 20) / (rms + self.eps)) + max_val = torch.max(torch.abs(a)) + if max_val > 1.0: + a = a / (max_val + self.eps) + normalized.append(a) + return normalized __all__ = ["VibevoiceAcousticTokenizerAudioProcessor"] diff --git a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py index a5258961c7c3..66467620f39d 100644 --- a/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py +++ b/src/transformers/models/wav2vec2/audio_processing_wav2vec2.py @@ -12,15 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
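For readers tracing the VibeVoice change above, the RMS-to-target-dBFS normalization in values_normalize can be sketched on its own. The defaults mirror the diff (target_dB_FS=-25, eps=1e-6); the free-function form is illustrative, not the library API.

    import torch

    def rms_normalize(waveform: torch.Tensor, target_dB_FS: float = -25.0, eps: float = 1e-6) -> torch.Tensor:
        rms = torch.sqrt(torch.mean(waveform**2))
        # Scale so the RMS level lands at target_dB_FS (decibels relative to full scale)
        waveform = waveform * (10 ** (target_dB_FS / 20) / (rms + eps))
        # Rescale if the peak would clip beyond 1.0
        peak = torch.max(torch.abs(waveform))
        if peak > 1.0:
            waveform = waveform / (peak + eps)
        return waveform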
+import torch + from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import NormalizationConfig class Wav2Vec2AudioProcessor(TorchAudioBackend): sample_rate = 16000 force_mono = True - do_values_normalize = True - normalization_config = NormalizationConfig(method="zero_mean_unit_var") + do_normalize = True + + def _process_audio(self, audio_el): + audio_el = super()._process_audio(audio_el) + + if self.do_normalize: + audio_el = (audio_el - audio_el.mean()) / torch.sqrt(audio_el.var(correction=0) + 1e-7) + + return audio_el __all__ = ["Wav2Vec2AudioProcessor"] diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index 2536fc75f9d1..a120cb4790f7 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -33,8 +33,22 @@ class WhisperAudioProcessor(TorchAudioBackend): norm="slaney", ), log_mode="log10", - chunk_length=30, ) + def extract_spectrogram(self, audio, **kwargs): + import torch + + features = super().extract_spectrogram(audio, **kwargs) + spectrogram_config = kwargs.get("spectrogram_config", self.spectrogram_config) + mel_floor = spectrogram_config.mel_floor + processed = [] + for spec in features: + log_spec = torch.clamp(spec, min=mel_floor).log10() + max_val = log_spec.max() + log_spec = torch.maximum(log_spec, max_val - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + processed.append(log_spec) + return processed + __all__ = ["WhisperAudioProcessor"] From 0decb663ff870c124df08a7ab95406461fd3f8dc Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Tue, 10 Mar 2026 10:58:44 +0100 Subject: [PATCH 07/28] update --- src/transformers/audio_processing_backends.py | 111 +++++++++++++++++- src/transformers/audio_processing_utils.py | 14 ++- ...rocessing_audio_spectrogram_transformer.py | 16 ++- .../models/clap/audio_processing_clap.py | 4 +- .../models/clvp/audio_processing_clvp.py | 4 +- .../models/dac/audio_processing_dac.py | 12 ++ .../models/dia/audio_processing_dia.py | 14 +++ .../encodec/audio_processing_encodec.py | 12 ++ .../gemma3n/audio_processing_gemma3n.py | 22 +++- .../audio_processing_granite_speech.py | 45 +------ .../audio_processing_kyutai_speech_to_text.py | 35 ++++-- .../models/lasr/audio_processing_lasr.py | 2 +- .../audio_processing_musicgen_melody.py | 6 +- .../parakeet/audio_processing_parakeet.py | 98 +++++++++++----- .../audio_processing_phi4_multimodal.py | 110 ++++++++++------- .../audio_processing_seamless_m4t.py | 2 +- .../audio_processing_speech_to_text.py | 2 +- .../univnet/audio_processing_univnet.py | 6 +- ...processing_vibevoice_acoustic_tokenizer.py | 35 +++--- ...extraction_vibevoice_acoustic_tokenizer.py | 2 +- .../audio_processing_voxtral_realtime.py | 22 ++++ .../whisper/audio_processing_whisper.py | 11 +- 22 files changed, 395 insertions(+), 190 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index aa88845a5bb3..842e8f55dab1 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -17,7 +17,7 @@ import numpy as np from .audio_processing_utils import BaseAudioProcessor -from .audio_utils import SpectrogramConfig, mel_filter_bank +from .audio_utils import SpectrogramConfig, amplitude_to_db, mel_filter_bank, power_to_db from .feature_extraction_utils import BatchFeature from .utils import PaddingStrategy, TensorType, is_torch_available, 
is_torch_tensor, logging, to_numpy @@ -115,7 +115,7 @@ def _extract_spectrogram( remove_dc_offset=spectrogram_config.remove_dc_offset, mel_filters=None, mel_floor=spectrogram_config.mel_floor, - log_mel=spectrogram_config.log_mode if spectrogram_config.log_mode != "log10" else "log10", + log_mel=None, ) features.append(spec) @@ -138,6 +138,47 @@ def _apply_mel_scale( mel_filters = self.mel_filters return [mel_filters.T @ spec for spec in features] + def _normalize_magnitude( + self, + features: list[np.ndarray], + *, + spectrogram_config: SpectrogramConfig, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: float | None = None, + dtype: np.dtype = np.float32, + **kwargs, + ) -> list[np.ndarray]: + """Apply magnitude normalization (log, log10, or dB scaling) to spectrogram features. + + Mirrors the normalization logic in `audio_utils.spectrogram()`. + """ + log_mel = spectrogram_config.log_mode + mel_floor = spectrogram_config.mel_floor + power = spectrogram_config.stft_config.power + + if log_mel is None: + return features + + # Clamp to mel_floor before taking log + result = [np.maximum(mel_floor, spec) for spec in features] + + if log_mel == "log": + result = [np.log(spec).astype(dtype) for spec in result] + elif log_mel == "log10": + result = [np.log10(spec).astype(dtype) for spec in result] + elif log_mel == "dB": + if power == 1.0: + result = [amplitude_to_db(spec, reference, min_value, db_range).astype(dtype) for spec in result] + elif power == 2.0: + result = [power_to_db(spec, reference, min_value, db_range).astype(dtype) for spec in result] + else: + raise ValueError(f"Cannot use log_mel option 'dB' with power {power}") + else: + raise ValueError(f"Unknown log_mel option: {log_mel}") + + return result + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config @@ -252,9 +293,10 @@ def _extract_spectrogram( if spectrogram_config.preemphasis is not None: audio_ranges = kwargs.get("audio_ranges", None) - timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) - timemask = timemask < audio_ranges.unsqueeze(1) - waveform = waveform.masked_fill(~timemask, 0.0) + if audio_ranges is not None: + timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) + timemask = timemask < audio_ranges.unsqueeze(1) + waveform = waveform.masked_fill(~timemask, 0.0) window_fn = getattr(torch, stft_cfg.window_fn, torch.hann_window) window = window_fn(win_length, periodic=stft_cfg.periodic, device=device) @@ -295,6 +337,65 @@ def _apply_mel_scale( mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) return [mel_filters.T @ spec for spec in features] + def _normalize_magnitude( + self, + features: list["torch.Tensor"], + *, + spectrogram_config: SpectrogramConfig, + reference: float = 1.0, + min_value: float = 1e-10, + db_range: float | None = None, + dtype: "torch.dtype | None" = None, + **kwargs, + ) -> list["torch.Tensor"]: + """Apply magnitude normalization (log, log10, or dB scaling) to spectrogram features. + + Mirrors the normalization logic in `audio_utils.spectrogram()`. 
+ """ + import torch + + log_mel = spectrogram_config.log_mode + mel_floor = spectrogram_config.mel_floor + power = spectrogram_config.stft_config.power + + if dtype is None: + dtype = torch.float32 + + if log_mel is None: + return features + + # Clamp to mel_floor before taking log + result = [torch.clamp(spec, min=mel_floor) for spec in features] + + if log_mel == "log": + result = [torch.log(spec).to(dtype) for spec in result] + elif log_mel == "log10": + result = [torch.log10(spec).to(dtype) for spec in result] + elif log_mel == "dB": + if reference <= 0.0: + raise ValueError("reference must be greater than zero") + if min_value <= 0.0: + raise ValueError("min_value must be greater than zero") + reference = max(min_value, reference) + multiplier = 10.0 if power == 2.0 else 20.0 if power == 1.0 else None + if multiplier is None: + raise ValueError(f"Cannot use log_mel option 'dB' with power {power}") + log_ref = np.log10(reference) + processed = [] + for spec in result: + spec = torch.clamp(spec, min=min_value) + spec = multiplier * (torch.log10(spec) - log_ref) + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + spec = torch.clamp(spec, min=spec.max() - db_range) + processed.append(spec.to(dtype)) + result = processed + else: + raise ValueError(f"Unknown log_mel option: {log_mel}") + + return result + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 8226e573cddd..6666dd7346ab 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -31,10 +31,8 @@ class AudioProcessingKwargs(AudioKwargs, total=False): """Extended keyword arguments for the audio processing pipeline.""" do_pad_values: bool | None - do_values_normalize: bool | None spectrogram_config: dict | SpectrogramConfig | None do_extract_spectrogram: bool | None - do_feature_normalize: bool | None do_pad_features: bool | None do_resample: bool | None @@ -55,8 +53,6 @@ class BaseAudioProcessor(AudioProcessingMixin): force_mono: bool = None # Pipeline stage defaults - do_pad_values = None - do_values_normalize = None normalize_before_pad = True spectrogram_config = None do_extract_spectrogram = None @@ -261,10 +257,12 @@ def extract_spectrogram(self, audio, *, do_batch_spectrogram: bool = True, spect features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) if spectrogram_config.mel_scale_config is not None: features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs) + features = self._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) else: features = [self._extract_spectrogram(audio_el, spectrogram_config=spectrogram_config, **kwargs) for audio_el in audio] if spectrogram_config.mel_scale_config is not None: features = [self._apply_mel_scale(feature_el, spectrogram_config=spectrogram_config, **kwargs) for feature_el in features] + features = [self._normalize_magnitude(feature_el, spectrogram_config=spectrogram_config, **kwargs) for feature_el in features] return features def _extract_spectrogram(self, *args, **kwargs): @@ -283,6 +281,14 @@ def _apply_mel_scale(self, *args, **kwargs): """ raise NotImplementedError + def _normalize_magnitude(self, *args, **kwargs): + """ + Apply magnitude normalization (log, log10, or dB 
scaling) to spectrogram features. + + Implemented by backend subclasses (e.g., ``TorchAudioBackend``). + """ + raise NotImplementedError + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): raise NotImplementedError diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py index 85ba48c1d06f..b20db82fd2ee 100644 --- a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py @@ -22,7 +22,6 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend): sample_rate = 16000 force_mono = True - do_extract_spectrogram = True max_length_frames = 1024 do_normalize = True @@ -54,8 +53,9 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend): mel_floor=1.192092955078125e-07, ) - def extract_spectrogram(self, audio, **kwargs): - features = super().extract_spectrogram(audio, **kwargs) + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Compute spectrogram per-sample (no audio padding beforehand) + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) # (n_mels, frames) -> (frames, n_mels) features = [f.T for f in features] @@ -72,13 +72,11 @@ def extract_spectrogram(self, audio, **kwargs): padded.append(fbank) # Normalize with AudioSet stats - return [(f - self.ast_mean) / (self.ast_std * 2) for f in padded] - - def _preprocess(self, audio, **kwargs): - output = super()._preprocess(audio, **kwargs) - # TODO: it is wrongly named input_values in the original feature extractor - return BatchFeature({"audio_values": output["audio_features"]}) + if self.do_normalize: + padded = [(f - self.ast_mean) / (self.ast_std * 2) for f in padded] + stacked = np.stack(padded, axis=0) + return BatchFeature({"audio_values": stacked}, tensor_type=return_tensors) __all__ = ["AudioSpectrogramTransformerAudioProcessor"] diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py index 4672057a1530..5119470489c8 100644 --- a/src/transformers/models/clap/audio_processing_clap.py +++ b/src/transformers/models/clap/audio_processing_clap.py @@ -29,7 +29,7 @@ class ClapAudioProcessor(NumpyAudioBackend): f_min = 0 f_max = 14000 max_length_s = 10 - truncation_mode = "fusion" # "fusion" or "rand_trunc" + truncation_mode = "rand_trunc" # "fusion" or "rand_trunc" padding_mode = "repeatpad" # "repeatpad", "repeat", or "pad" def __init__(self, **kwargs): @@ -150,7 +150,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of is_longer = [[longer] for longer in is_longer] - input_features = {"input_features": input_mel, "is_longer": is_longer} + input_features = {"audio_features": input_mel, "is_longer": is_longer} input_features = BatchFeature(input_features) if return_tensors is not None: diff --git a/src/transformers/models/clvp/audio_processing_clvp.py b/src/transformers/models/clvp/audio_processing_clvp.py index fdf9810a1ac7..082e049a9cca 100644 --- a/src/transformers/models/clvp/audio_processing_clvp.py +++ b/src/transformers/models/clvp/audio_processing_clvp.py @@ -72,7 +72,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of pad_length = max(a.shape[-1] for a 
in audio) else: pad_length = max_length - audio = self.pad_values(audio, max_length=pad_length, truncation=False, pad_to_multiple_of=pad_to_multiple_of) + audio = self.pad(audio, padding=True, max_length=pad_length) # Extract spectrogram via config-based API (with mel_norms applied) features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) @@ -80,7 +80,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of # Cast to float32 to match the legacy FeatureExtractor features = [f.astype(np.float32) for f in features] - output_key = self.model_input_names[0] + output_key = "audio_features" stacked = np.stack(features, axis=0) return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) diff --git a/src/transformers/models/dac/audio_processing_dac.py b/src/transformers/models/dac/audio_processing_dac.py index f0a27bd57555..077f00ea3697 100644 --- a/src/transformers/models/dac/audio_processing_dac.py +++ b/src/transformers/models/dac/audio_processing_dac.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + from ...audio_processing_backends import NumpyAudioBackend +from ...feature_extraction_utils import BatchFeature class DacAudioProcessor(NumpyAudioBackend): @@ -20,5 +23,14 @@ class DacAudioProcessor(NumpyAudioBackend): force_mono = True add_channel_dim = True + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + lengths = [a.shape[-1] for a in audio] + audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + padded_length = max(a.shape[-1] for a in audio) + padding_mask = np.array([[1] * l + [0] * (padded_length - l) for l in lengths]) + stacked = np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + output = BatchFeature({"audio_values": stacked, "padding_mask": padding_mask}, tensor_type=return_tensors) + return output + __all__ = ["DacAudioProcessor"] diff --git a/src/transformers/models/dia/audio_processing_dia.py b/src/transformers/models/dia/audio_processing_dia.py index e1b7b0301e71..ef1a0b38c6d0 100644 --- a/src/transformers/models/dia/audio_processing_dia.py +++ b/src/transformers/models/dia/audio_processing_dia.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
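The codec-style processors in this patch (DAC above, Dia and Encodec below) all build the padding mask the same way; a minimal sketch with illustrative names:

    import numpy as np

    def build_padding_mask(lengths: list[int], padded_length: int) -> np.ndarray:
        # 1 for real samples, 0 for right-side padding
        return np.array([[1] * l + [0] * (padded_length - l) for l in lengths])

    build_padding_mask([3, 5], 5)
    # array([[1, 1, 1, 0, 0],
    #        [1, 1, 1, 1, 1]])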
+import numpy as np + from ...audio_processing_backends import NumpyAudioBackend +from ...feature_extraction_utils import BatchFeature class DiaAudioProcessor(NumpyAudioBackend): @@ -21,5 +24,16 @@ class DiaAudioProcessor(NumpyAudioBackend): add_channel_dim = True pad_to_multiple_of = 512 + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + if pad_to_multiple_of is None: + pad_to_multiple_of = self.pad_to_multiple_of + lengths = [a.shape[-1] for a in audio] + audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + padded_length = max(a.shape[-1] for a in audio) + padding_mask = np.array([[1] * l + [0] * (padded_length - l) for l in lengths]) + stacked = np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + output = BatchFeature({"audio_values": stacked, "padding_mask": padding_mask}, tensor_type=return_tensors) + return output + __all__ = ["DiaAudioProcessor"] diff --git a/src/transformers/models/encodec/audio_processing_encodec.py b/src/transformers/models/encodec/audio_processing_encodec.py index 022a7e145313..89376fbe7d5b 100644 --- a/src/transformers/models/encodec/audio_processing_encodec.py +++ b/src/transformers/models/encodec/audio_processing_encodec.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np + from ...audio_processing_backends import NumpyAudioBackend +from ...feature_extraction_utils import BatchFeature class EncodecAudioProcessor(NumpyAudioBackend): @@ -20,5 +23,14 @@ class EncodecAudioProcessor(NumpyAudioBackend): force_mono = True add_channel_dim = True + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + lengths = [a.shape[-1] for a in audio] + audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + padded_length = max(a.shape[-1] for a in audio) + padding_mask = np.array([[1] * l + [0] * (padded_length - l) for l in lengths]) + stacked = np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + output = BatchFeature({"audio_values": stacked, "padding_mask": padding_mask}, tensor_type=return_tensors) + return output + __all__ = ["EncodecAudioProcessor"] diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py index 87dd86f1cec4..d24bb7bb9e6f 100644 --- a/src/transformers/models/gemma3n/audio_processing_gemma3n.py +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -175,6 +175,9 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of if truncation and max_length is not None: audio = [a[..., :max_length] for a in audio] + # Record original audio lengths before padding (for computing the frame mask) + original_lengths = [a.shape[-1] for a in audio] + pad_length = max(a.shape[-1] for a in audio) if pad_to_multiple_of is not None and (pad_length % pad_to_multiple_of != 0): pad_length = ((pad_length // pad_to_multiple_of) + 1) * pad_to_multiple_of @@ -186,9 +189,24 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of else: features = audio - output_key = self.model_input_names[0] + output_key = "audio_features" stacked = np.stack(features, axis=0) - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + # Compute per-sample spectrogram frame counts from original (pre-padding) audio lengths + frame_size_for_unfold = 
spectrogram_config.stft_config.win_length + 1 # 513 + hop_length = spectrogram_config.stft_config.hop_length # 160 + frame_counts = [(orig_len - frame_size_for_unfold) // hop_length + 1 for orig_len in original_lengths] + + # Build mask: 1 for real frames, 0 for padded frames + max_frames = stacked.shape[1] + input_features_mask = np.array( + [[1] * fc + [0] * (max_frames - fc) for fc in frame_counts], dtype=np.int32 + ) + + return BatchFeature( + data={output_key: stacked, "input_features_mask": input_features_mask}, + tensor_type=return_tensors, + ) __all__ = ["Gemma3nAudioProcessor"] diff --git a/src/transformers/models/granite_speech/audio_processing_granite_speech.py b/src/transformers/models/granite_speech/audio_processing_granite_speech.py index a567b8d4f8fe..4e250c535b62 100644 --- a/src/transformers/models/granite_speech/audio_processing_granite_speech.py +++ b/src/transformers/models/granite_speech/audio_processing_granite_speech.py @@ -23,6 +23,7 @@ class GraniteSpeechAudioProcessor(TorchAudioBackend): spectrogram_config = SpectrogramConfig( stft_config=StftConfig( n_fft=512, + win_length=400, hop_length=160, power=2.0, ), @@ -31,49 +32,7 @@ class GraniteSpeechAudioProcessor(TorchAudioBackend): ), log_mode="log10", ) - - def extract_spectrogram(self, audio, *, spectrogram_config): - import torch - - # Use parent's extract_spectrogram for basic mel spectrogram - # Parent returns list of (n_mels, frames) tensors with log10 + (x+4)/4 normalization - features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config) - - # Transpose each: (n_mels, frames) -> (frames, n_mels) - features = [f.permute(1, 0) for f in features] - - # Remove last frame if odd - features = [f[:-1] if f.shape[0] % 2 == 1 else f for f in features] - - # Frame stacking: (frames, n_mels) -> (frames//2, 2*n_mels) - features = [f.reshape(-1, 2 * f.shape[-1]) for f in features] - - return features - - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - import torch - - # Pad raw audio values - if padding: - audio = self.pad_values( - audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of - ) - - # Extract spectrogram with frame stacking - features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) - - # Pad features to same length - max_feat_len = max(f.shape[0] for f in features) - padded = [] - for f in features: - if f.shape[0] < max_feat_len: - pad_amount = max_feat_len - f.shape[0] - f = torch.nn.functional.pad(f, (0, 0, 0, pad_amount), mode="constant", value=0.0) - padded.append(f) - - output_key = self.model_input_names[0] - stacked = torch.stack(padded, dim=0) - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + __all__ = ["GraniteSpeechAudioProcessor"] diff --git a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py index 8f3e3f314b9c..fcee9eee0313 100644 --- a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py @@ -15,6 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend +from ...feature_extraction_utils import BatchFeature class KyutaiSpeechToTextAudioProcessor(NumpyAudioBackend): @@ -28,26 +29,34 @@ def __init__(self, audio_delay_seconds=2.5, 
audio_silence_prefix_seconds=1.0, ** super().__init__(**kwargs) def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - result = super()._preprocess( - audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs - ) + # Track lengths for padding_mask + lengths = [a.shape[-1] for a in audio] + # Pad audio to batch longest + audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + padded_length = max(a.shape[-1] for a in audio) + + # Create padding_mask (1 for real audio, 0 for padding) + padding_mask = np.array([[1] * l + [0] * (padded_length - l) for l in lengths]) + + # Stack audio with channel dim + stacked = np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + + # Add silence prefix (left) and delay (right) padding pad_left = int(self.audio_silence_prefix_seconds * self.sample_rate) pad_right = int((self.audio_delay_seconds + 1.0) * self.sample_rate) if pad_left > 0 or pad_right > 0: - output_key = self.model_input_names[0] - data = result[output_key] - - if isinstance(data, np.ndarray): - pad_width = [(0, 0)] * (data.ndim - 1) + [(pad_left, pad_right)] - result[output_key] = np.pad(data, pad_width, mode="constant", constant_values=0.0) - else: - import torch.nn.functional as F + # Pad audio + audio_pad_width = [(0, 0), (0, 0), (pad_left, pad_right)] + stacked = np.pad(stacked, audio_pad_width, mode="constant", constant_values=0.0) - result[output_key] = F.pad(data, (pad_left, pad_right), mode="constant", value=0.0) + # Pad padding_mask + mask_pad_width = [(0, 0), (pad_left, pad_right)] + padding_mask = np.pad(padding_mask, mask_pad_width, mode="constant", constant_values=0) - return result + output = BatchFeature({"audio_values": stacked, "padding_mask": padding_mask}, tensor_type=return_tensors) + return output __all__ = ["KyutaiSpeechToTextAudioProcessor"] diff --git a/src/transformers/models/lasr/audio_processing_lasr.py b/src/transformers/models/lasr/audio_processing_lasr.py index f12e3086b39f..7df73470e135 100644 --- a/src/transformers/models/lasr/audio_processing_lasr.py +++ b/src/transformers/models/lasr/audio_processing_lasr.py @@ -62,7 +62,7 @@ def __init__(self, **kwargs): upper_edge_hertz=7500.0, ) - def extract_spectrogram(self, audio, *, spectrogram_config): + def extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): import torch stft_cfg = spectrogram_config.stft_config diff --git a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py index 62005e416256..bdeebd2f55f8 100644 --- a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py @@ -82,9 +82,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of # Pad raw audio if padding: - audio = self.pad_values( - audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of - ) + audio = self.pad(audio, padding=True, max_length=max_length) # Extract chroma features features = self.extract_spectrogram(audio, spectrogram_config=None) @@ -98,7 +96,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of f = torch.nn.functional.pad(f, (0, 0, 0, pad_amount), mode="constant", value=0.0) padded.append(f) - output_key = self.model_input_names[0] + output_key = "audio_features" stacked = torch.stack(padded, dim=0) return 
BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) diff --git a/src/transformers/models/parakeet/audio_processing_parakeet.py b/src/transformers/models/parakeet/audio_processing_parakeet.py index c7bcd8f3cd05..c93df2e8d0e7 100644 --- a/src/transformers/models/parakeet/audio_processing_parakeet.py +++ b/src/transformers/models/parakeet/audio_processing_parakeet.py @@ -16,7 +16,6 @@ import torch from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature LOG_ZERO_GUARD_VALUE = 2**-24 @@ -32,35 +31,76 @@ class ParakeetAudioProcessor(TorchAudioBackend): win_length = 400 n_mels = 80 - spectrogram_config = SpectrogramConfig( - stft_config=StftConfig( - n_fft=512, - hop_length=160, - win_length=400, - window_fn="hann_window", - periodic=False, - pad_mode="constant", - power=2.0, - ), - mel_scale_config=MelScaleConfig( - n_mels=80, - f_min=0.0, - norm="slaney", - ), - preemphasis=0.97, - ) - - def _mel_filter_bank(self, spectrogram_config): - """Use librosa for mel filters to match the FeatureExtractor exactly - (mel_filter_bank uses float64 internally, causing numerical differences).""" - msc = spectrogram_config.mel_scale_config - return librosa.filters.mel( + def __init__(self, **kwargs): + super().__init__(**kwargs) + mel_filters = librosa.filters.mel( sr=self.sample_rate, - n_fft=spectrogram_config.stft_config.n_fft, - n_mels=msc.n_mels, - fmin=msc.f_min, - fmax=msc.f_max if msc.f_max is not None else self.sample_rate / 2, - norm=msc.norm, + n_fft=self.n_fft, + n_mels=self.n_mels, + fmin=0.0, + fmax=self.sample_rate / 2, + norm="slaney", + ) + self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32) + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Pad raw audio + lengths = [a.shape[-1] for a in audio] + audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + + # Stack into batch + waveform = torch.stack(audio) # (batch, length) + audio_lengths = torch.tensor(lengths) + + # Preemphasis with masking for padded regions + if self.preemphasis is not None: + timemask = torch.arange(waveform.shape[1]).unsqueeze(0) < audio_lengths.unsqueeze(1) + waveform = torch.cat( + [waveform[:, :1], waveform[:, 1:] - self.preemphasis * waveform[:, :-1]], dim=1 + ) + waveform = waveform.masked_fill(~timemask, 0.0) + + # STFT + window = torch.hann_window(self.win_length, periodic=False) + stft = torch.stft( + waveform, + self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=window, + return_complex=True, + pad_mode="constant", + ) + # Match FE: view_as_real -> pow(2).sum(-1).sqrt().pow(2) + magnitudes = torch.view_as_real(stft) + magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1)) + magnitudes = magnitudes.pow(2) + + # Mel spectrogram + log + mel_spec = self.mel_filters @ magnitudes + mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE) + + # (batch, mels, frames) -> (batch, frames, mels) + mel_spec = mel_spec.permute(0, 2, 1) + + # Per-utterance normalization + features_lengths = torch.floor_divide( + audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length + ) + attention_mask = torch.arange(mel_spec.shape[1])[None, :] < features_lengths[:, None] + mask = attention_mask.unsqueeze(-1) + mel_masked = mel_spec * mask + mean = mel_masked.sum(dim=1) / features_lengths.unsqueeze(-1) + mean = mean.unsqueeze(1) + variance = ((mel_masked - mean) 
** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1) + std = torch.sqrt(variance).unsqueeze(1) + mel_spec = (mel_spec - mean) / (std + EPSILON) + mel_spec *= mask + + return BatchFeature( + data={"audio_features": mel_spec}, + tensor_type=return_tensors, ) + __all__ = ["ParakeetAudioProcessor"] diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py index 01781c69a5db..f7b48a48823c 100644 --- a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py @@ -25,42 +25,45 @@ class Phi4MultimodalAudioProcessor(TorchAudioBackend): hop_length = 160 win_length = 400 n_mels = 80 + mel_min_frequency = 0 + mel_max_frequency = 7690 + audio_compression_rate = 8 + audio_downsample_rate = 1 + audio_feat_stride = 1 def __init__(self, **kwargs): super().__init__(**kwargs) self.mel_filters = mel_filter_bank( num_frequency_bins=self.n_fft // 2 + 1, num_mel_filters=self.n_mels, - min_frequency=0, - max_frequency=7690, + min_frequency=self.mel_min_frequency, + max_frequency=self.mel_max_frequency, sampling_rate=self.sample_rate, - norm=None, - mel_scale="kaldi", triangularize_in_mel_space=True, + mel_scale="kaldi", ) - def extract_spectrogram(self, audio, *, spectrogram_config): + def extract_spectrogram(self, audio, **kwargs): import torch - waveform = torch.stack(audio, dim=0) - device = waveform.device + waveform = torch.stack(audio) # (batch, length) batch_size = waveform.shape[0] - lengths = torch.tensor([a.shape[-1] for a in audio], device=device) + audio_lengths = kwargs.get("audio_lengths") - # Unfold into frames + fft_window = torch.hamming_window(self.win_length, periodic=False, dtype=torch.float64) frames = waveform.unfold(-1, self.win_length, self.hop_length) - # Frame-level masking for padded inputs - if batch_size > 1: + # Mask frames that overlap the boundary between real audio and padding + if batch_size > 1 and audio_lengths is not None: frames = frames.clone() - to_mask_batch_idxs = torch.arange(batch_size, device=device)[lengths != lengths.max()] + to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()] if to_mask_batch_idxs.numel() > 0: - batch_idxs_down = (lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1 - batch_idxs_up = (lengths[to_mask_batch_idxs] // self.hop_length) - 1 + batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1 + batch_idxs_up = (audio_lengths[to_mask_batch_idxs] // self.hop_length) - 1 offset_idx = batch_idxs_down.min() max_idx = batch_idxs_up.max() - mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1) + mask = torch.arange(max_idx - offset_idx).expand(to_mask_batch_idxs.shape[0], -1) mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & ( mask < (batch_idxs_up - offset_idx).unsqueeze(1) ) @@ -68,50 +71,73 @@ def extract_spectrogram(self, audio, *, spectrogram_config): masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0) frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames - # Pre-emphasis + # Pre-emphasis on frames with scaling frames_prev = torch.roll(frames, 1, dims=-1) frames_prev[:, :, 0] = frames_prev[:, :, 1] frames = (frames - self.preemphasis * frames_prev) * 32768 - # Hamming window + FFT - fft_window = torch.hamming_window(self.win_length, periodic=False, 
device=device, dtype=torch.float64) + # FFT S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1) - S = S.view(batch_size, -1, S.shape[-1]).to(torch.complex64) + S = S.view(frames.shape[0], -1, S.shape[-1]) + S = S.to(torch.complex64) - spec = torch.abs(S) - spec_power = spec**2 + spec_power = torch.abs(S) ** 2 # Mel filterbank + log - mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) + mel_filters = torch.from_numpy(self.mel_filters).to(torch.float32) log_spec = torch.clamp(spec_power @ mel_filters, min=1.0) log_spec = torch.log(log_spec) return [log_spec[i] for i in range(batch_size)] - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + def _compute_audio_embed_size(self, audio_frames): import torch - # Pad values to longest - if padding: - audio = self.pad_values( - audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of - ) + integer = audio_frames // self.audio_compression_rate + remainder = audio_frames % self.audio_compression_rate + result = integer + (remainder > 0).to(integer.dtype) + + integer = result // self.audio_downsample_rate + remainder = result % self.audio_downsample_rate + result = integer + (remainder > 0).to(integer.dtype) + + return result + + def _preprocess( + self, + audio, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + **kwargs, + ) -> BatchFeature: + import torch + + # Capture original lengths before padding + audio_lengths = torch.tensor([a.shape[-1] for a in audio]) + + # Pad and truncate + audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) # Extract spectrogram - features = self.extract_spectrogram(audio, spectrogram_config=None) - - # Pad features and stack - max_feat_len = max(f.shape[0] for f in features) - padded = [] - for f in features: - if f.shape[0] < max_feat_len: - pad_amount = max_feat_len - f.shape[0] - f = torch.nn.functional.pad(f, (0, 0, 0, pad_amount), mode="constant", value=0.0) - padded.append(f) - - output_key = self.model_input_names[0] - stacked = torch.stack(padded, dim=0) - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + features = self.extract_spectrogram(audio, audio_lengths=audio_lengths) + + # Compute audio_embed_sizes + feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1 + feature_lengths = feature_lengths * self.audio_feat_stride + audio_embed_sizes = self._compute_audio_embed_size(feature_lengths) + + data = {"audio_features": features, "audio_embed_sizes": audio_embed_sizes} + + # Attention mask for batched inputs with different lengths + if len(audio_lengths) > 1: + feature_attention_mask = torch.arange(0, feature_lengths.max())[None, :] < feature_lengths[:, None] + data["audio_attention_mask"] = feature_attention_mask + + output = BatchFeature(data, tensor_type=return_tensors) + return output __all__ = ["Phi4MultimodalAudioProcessor"] diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py index 127f595cf3e8..ea5700cbfd73 100644 --- a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py @@ -89,7 +89,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of stacked = stacked.reshape(batch_size, num_frames // self.stride, num_channels * self.stride) 
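A minimal sketch of the stride-based frame stacking performed by the reshape above; the shapes and the stride value here are illustrative, not taken from the config.

    import numpy as np

    batch_size, num_frames, num_channels, stride = 2, 8, 80, 2
    features = np.random.randn(batch_size, num_frames, num_channels)
    # Group `stride` consecutive frames: the time axis shrinks, the feature axis grows
    stacked = features.reshape(batch_size, num_frames // stride, num_channels * stride)
    assert stacked.shape == (2, 4, 160)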
- output_key = self.model_input_names[0] + output_key = "audio_features" return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py index b670eb2a724c..d60310a98494 100644 --- a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -87,7 +87,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of for f, length in zip(padded, lengths) ] - output_key = self.model_input_names[0] + output_key = "audio_features" stacked = np.stack(normalized, axis=0) return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) diff --git a/src/transformers/models/univnet/audio_processing_univnet.py b/src/transformers/models/univnet/audio_processing_univnet.py index ca8e64808c26..2b801be4a075 100644 --- a/src/transformers/models/univnet/audio_processing_univnet.py +++ b/src/transformers/models/univnet/audio_processing_univnet.py @@ -94,9 +94,7 @@ def extract_spectrogram(self, audio, *, spectrogram_config): def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): # Pad raw audio if padding: - audio = self.pad_values( - audio, max_length=max_length, truncation=truncation, pad_to_multiple_of=pad_to_multiple_of - ) + audio = self.pad(audio, padding=True, max_length=max_length) # Extract mel spectrograms features = self.extract_spectrogram(audio, spectrogram_config=None) @@ -110,7 +108,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of f = np.pad(f, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) padded.append(f) - output_key = self.model_input_names[0] + output_key = "audio_features" stacked = np.stack(padded, axis=0) return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py index 28049d0eccc3..34ddf266e524 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py @@ -12,33 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import torch + from ...audio_processing_backends import TorchAudioBackend +from ...feature_extraction_utils import BatchFeature class VibevoiceAcousticTokenizerAudioProcessor(TorchAudioBackend): sample_rate = 24000 force_mono = True add_channel_dim = True - do_values_normalize = True - normalize_before_pad = True - - def __init__(self, target_dB_FS=-25, eps=1e-6, **kwargs): - self.target_dB_FS = target_dB_FS - self.eps = eps - super().__init__(**kwargs) - - def values_normalize(self, audio): - import torch - - normalized = [] - for a in audio: - rms = torch.sqrt(torch.mean(a**2)) - a = a * (10 ** (self.target_dB_FS / 20) / (rms + self.eps)) - max_val = torch.max(torch.abs(a)) - if max_val > 1.0: - a = a / (max_val + self.eps) - normalized.append(a) - return normalized + + target_dB_FS = -25 + eps = 1e-6 + + def _process_audio(self, audio_el): + audio_el = super()._process_audio(audio_el) + rms = torch.sqrt(torch.mean(audio_el**2)) + audio_el = audio_el * (10 ** (self.target_dB_FS / 20) / (rms + self.eps)) + max_val = torch.max(torch.abs(audio_el)) + if max_val > 1.0: + audio_el = audio_el / (max_val + self.eps) + return audio_el __all__ = ["VibevoiceAcousticTokenizerAudioProcessor"] diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py index 6e0c82762283..c42db2bc0d74 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py @@ -136,7 +136,7 @@ def __call__( output_values["padding_mask"] = output_values.pop("attention_mask") # add channel dimension - output_values["input_values"] = output_values["input_values"][:, None, :] + # output_values["input_values"] = output_values["input_values"][:, None, :] return output_values diff --git a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py index 15a1203b8d6c..edc598f31e3f 100644 --- a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py @@ -34,5 +34,27 @@ class VoxtralRealtimeAudioProcessor(TorchAudioBackend): global_log_mel_max=1.5, ) + def extract_spectrogram(self, audio, **kwargs): + import torch + + features = super().extract_spectrogram(audio, **kwargs) + spectrogram_config = kwargs.get("spectrogram_config", self.spectrogram_config) + global_log_mel_max = spectrogram_config.global_log_mel_max + + processed = [] + for spec in features: + if global_log_mel_max is not None: + spec_max = torch.tensor( + global_log_mel_max, + device=spec.device, + dtype=spec.dtype, + ) + else: + spec_max = spec.max() + spec = torch.maximum(spec, spec_max - 8.0) + spec = (spec + 4.0) / 4.0 + processed.append(spec) + return processed + __all__ = ["VoxtralRealtimeAudioProcessor"] diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index a120cb4790f7..801ad0f99c3f 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -39,15 +39,12 @@ def extract_spectrogram(self, audio, **kwargs): import torch features = super().extract_spectrogram(audio, **kwargs) - 
spectrogram_config = kwargs.get("spectrogram_config", self.spectrogram_config) - mel_floor = spectrogram_config.mel_floor processed = [] for spec in features: - log_spec = torch.clamp(spec, min=mel_floor).log10() - max_val = log_spec.max() - log_spec = torch.maximum(log_spec, max_val - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - processed.append(log_spec) + max_val = spec.max() + spec = torch.maximum(spec, max_val - 8.0) + spec = (spec + 4.0) / 4.0 + processed.append(spec) return processed From 173981491d312ca8565f6a29adf833540957ed2a Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Tue, 10 Mar 2026 19:00:10 +0100 Subject: [PATCH 08/28] update --- src/transformers/audio_processing_backends.py | 119 ++++++++---------- src/transformers/audio_processing_utils.py | 26 ++-- .../audio_processing_granite_speech.py | 8 +- 3 files changed, 63 insertions(+), 90 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 842e8f55dab1..0406d9cc2805 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -14,19 +14,28 @@ # limitations under the License. +import sys +from pathlib import Path + import numpy as np from .audio_processing_utils import BaseAudioProcessor -from .audio_utils import SpectrogramConfig, amplitude_to_db, mel_filter_bank, power_to_db +from .audio_utils import SpectrogramConfig, amplitude_to_db, power_to_db from .feature_extraction_utils import BatchFeature -from .utils import PaddingStrategy, TensorType, is_torch_available, is_torch_tensor, logging, to_numpy +from .utils import PaddingStrategy, TensorType, is_torch_available, logging logger = logging.get_logger(__name__) +_WORKSPACE_ROOT = str(Path(__file__).resolve().parents[3]) +if _WORKSPACE_ROOT not in sys.path: + sys.path.insert(0, _WORKSPACE_ROOT) + +from spectrograms import numpy_mel_spectrogram as _np_spec if is_torch_available(): import torch + from spectrograms import torch_mel_spectrogram as _torch_spec class NumpyAudioBackend(BaseAudioProcessor): @@ -86,40 +95,24 @@ def _extract_spectrogram( **kwargs, ) -> list[np.ndarray]: """Compute the (power) spectrogram via STFT using the numpy backend.""" - from .audio_utils import spectrogram as compute_spectrogram, window_function - stft_cfg = spectrogram_config.stft_config - n_fft = stft_cfg.n_fft - hop_length = stft_cfg.hop_length - win_length = stft_cfg.win_length if stft_cfg.win_length is not None else n_fft - - # Build window — map torch names like "hann_window" to audio_utils names like "hann" - window_name = stft_cfg.window_fn.replace("_window", "") - window = window_function(win_length, window_name, periodic=stft_cfg.periodic) - - features = [] - for waveform in audio: - w = waveform - if spectrogram_config.waveform_scale is not None: - w = np.squeeze(w) * spectrogram_config.waveform_scale - spec = compute_spectrogram( - w, - window=window, - frame_length=win_length, - hop_length=hop_length, - fft_length=n_fft, + + return _np_spec._extract_spectrogram( + audio, + self.sample_rate, + n_fft=stft_cfg.n_fft, + win_length=stft_cfg.win_length, + hop_length=stft_cfg.hop_length, + window_fn=stft_cfg.window_fn, power=stft_cfg.power, center=stft_cfg.center, pad_mode=stft_cfg.pad_mode, + normalized=stft_cfg.normalized, + pad=stft_cfg.pad, + periodic=stft_cfg.periodic, preemphasis=spectrogram_config.preemphasis, remove_dc_offset=spectrogram_config.remove_dc_offset, - mel_filters=None, - mel_floor=spectrogram_config.mel_floor, - log_mel=None, ) - 
features.append(spec) - - return features def _apply_mel_scale( self, @@ -129,14 +122,7 @@ def _apply_mel_scale( **kwargs, ) -> list[np.ndarray]: """Apply mel filterbank to spectrogram features using the numpy backend.""" - if not hasattr(self, "mel_filters"): - raise ValueError( - f"{self.__class__.__name__} does not have `mel_filters`. " - "Either set `mel_filters` or override `_apply_mel_scale`." - ) - - mel_filters = self.mel_filters - return [mel_filters.T @ spec for spec in features] + return _np_spec._apply_mel_scale(features, self.mel_filters, mel_floor=spectrogram_config.mel_floor) def _normalize_magnitude( self, @@ -182,7 +168,7 @@ def _normalize_magnitude( def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - return mel_filter_bank( + return _np_spec.mel_filter_bank( num_frequency_bins=1 + stft_cfg.n_fft // 2, num_mel_filters=mel_cfg.n_mels, min_frequency=mel_cfg.f_min, @@ -190,6 +176,7 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): sampling_rate=self.sample_rate, norm=mel_cfg.norm, mel_scale=mel_cfg.mel_scale, + triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, ) def _preprocess( @@ -205,11 +192,14 @@ def _preprocess( do_batch_spectrogram=True, **kwargs, ) -> BatchFeature: + import numpy as np + # pad and truncate audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) if do_extract_spectrogram: - feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, do_batch_spectrogram=do_batch_spectrogram) + audio = np.stack(audio) if do_batch_spectrogram else audio + feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) output = BatchFeature({"audio_features": feature}, tensor_type=return_tensors) else: output = BatchFeature({"audio_values": audio}, tensor_type=return_tensors) @@ -280,40 +270,38 @@ def _extract_spectrogram( import torch stft_cfg = spectrogram_config.stft_config - n_fft = stft_cfg.n_fft - hop_length = stft_cfg.hop_length - win_length = stft_cfg.win_length if stft_cfg.win_length is not None else n_fft - # Stack list into batch for efficient batched STFT if not already batched if isinstance(audio, torch.Tensor) and audio.dim() == 2: waveform = audio else: - waveform = torch.stack(audio) # (batch, length) - device = waveform.device + waveform = torch.stack(audio) if spectrogram_config.preemphasis is not None: audio_ranges = kwargs.get("audio_ranges", None) if audio_ranges is not None: + device = waveform.device timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) timemask = timemask < audio_ranges.unsqueeze(1) waveform = waveform.masked_fill(~timemask, 0.0) - window_fn = getattr(torch, stft_cfg.window_fn, torch.hann_window) - window = window_fn(win_length, periodic=stft_cfg.periodic, device=device) - - stft = torch.stft( + magnitudes = _torch_spec._extract_spectrogram( waveform, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=window, + self.sample_rate, + n_fft=stft_cfg.n_fft, + win_length=stft_cfg.win_length, + hop_length=stft_cfg.hop_length, + window_fn=stft_cfg.window_fn, + wkwargs=stft_cfg.wkwargs, + power=stft_cfg.power, center=stft_cfg.center, pad_mode=stft_cfg.pad_mode, normalized=stft_cfg.normalized, - onesided=stft_cfg.onesided, - return_complex=True, + pad=stft_cfg.pad, + periodic=stft_cfg.periodic, + preemphasis=spectrogram_config.preemphasis, + remove_dc_offset=spectrogram_config.remove_dc_offset, ) - magnitudes = 
stft[..., :-1].abs() ** stft_cfg.power + magnitudes = magnitudes[..., :-1] return [magnitudes[i] for i in range(magnitudes.shape[0])] @@ -325,17 +313,7 @@ def _apply_mel_scale( **kwargs, ) -> list["torch.Tensor"]: """Apply mel filterbank to spectrogram features using the torch backend.""" - import torch - - if not hasattr(self, "mel_filters"): - raise ValueError( - f"{self.__class__.__name__} does not have `mel_filters`. " - "Either set `mel_filters` or override `_apply_mel_scale`." - ) - - device = features[0].device - mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) - return [mel_filters.T @ spec for spec in features] + return _torch_spec._apply_mel_scale(features, self.mel_filters, mel_floor=spectrogram_config.mel_floor) def _normalize_magnitude( self, @@ -399,7 +377,7 @@ def _normalize_magnitude( def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - return mel_filter_bank( + return _torch_spec.mel_filter_bank_torch( num_frequency_bins=1 + stft_cfg.n_fft // 2, num_mel_filters=mel_cfg.n_mels, min_frequency=mel_cfg.f_min, @@ -407,7 +385,8 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): sampling_rate=self.sample_rate, norm=mel_cfg.norm, mel_scale=mel_cfg.mel_scale, - ) + triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, + ).numpy() def _preprocess( self, @@ -429,7 +408,7 @@ def _preprocess( if do_extract_spectrogram: audio = torch.stack(audio) if do_batch_spectrogram else audio - feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, do_batch_spectrogram=do_batch_spectrogram) + feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) output = BatchFeature({"audio_features": feature}, tensor_type=return_tensors) else: output = BatchFeature({"audio_values": audio}, tensor_type=return_tensors) diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 6666dd7346ab..44eb96bfdc81 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -52,17 +52,9 @@ class BaseAudioProcessor(AudioProcessingMixin): sample_rate: int = None force_mono: bool = None - # Pipeline stage defaults - normalize_before_pad = True spectrogram_config = None do_extract_spectrogram = None - do_feature_normalize = None - feature_normalize_before_pad = True - do_pad_features = None - do_resample = False - add_channel_dim = False pad_to_multiple_of = None - transpose_features = False def __init__( self, @@ -236,7 +228,8 @@ def _pad_single(self, audio, max_length: int) -> AudioInput: """ raise NotImplementedError - def extract_spectrogram(self, audio, *, do_batch_spectrogram: bool = True, spectrogram_config: SpectrogramConfig | None = None, **kwargs): + def extract_spectrogram(self, audio, *, spectrogram_config: SpectrogramConfig | None = None, **kwargs): + # TODO: it might be a bit unclear to have extract_spectrogram and _extract_spectrogram methods. """ Both the numpy and torch backends implement this method in a batched/ sequential manner. Is is batched by default, but can be set to be sequential. 
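A self-contained, per-sample sketch of the three stages this method chains (_extract_spectrogram, then _apply_mel_scale, then _normalize_magnitude). The helper below is illustrative only; the real backends run these steps batched and read every parameter from SpectrogramConfig rather than from keyword defaults.

    import numpy as np

    def mel_pipeline(waveform: np.ndarray, mel_filters: np.ndarray, n_fft: int = 400, hop_length: int = 160) -> np.ndarray:
        # 1) STFT -> power spectrogram of shape (freq_bins, frames)
        frames = np.lib.stride_tricks.sliding_window_view(waveform, n_fft)[::hop_length]
        spec = (np.abs(np.fft.rfft(frames * np.hanning(n_fft), axis=-1)) ** 2).T
        # 2) Mel projection with (freq_bins, n_mels) filters -> (n_mels, frames)
        spec = mel_filters.T @ spec
        # 3) Magnitude normalization, here the log_mode="log" branch
        return np.log(np.maximum(spec, 1e-10))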
@@ -252,17 +245,12 @@ def extract_spectrogram(self, audio, *, do_batch_spectrogram: bool = True, spect overrides = {k: kwargs.pop(k) for k in list(kwargs) if k in config_field_names} if overrides: spectrogram_config = replace(spectrogram_config, **overrides) + + features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) + if spectrogram_config.mel_scale_config is not None: + features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs) + features = self._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) - if do_batch_spectrogram: - features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) - if spectrogram_config.mel_scale_config is not None: - features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs) - features = self._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) - else: - features = [self._extract_spectrogram(audio_el, spectrogram_config=spectrogram_config, **kwargs) for audio_el in audio] - if spectrogram_config.mel_scale_config is not None: - features = [self._apply_mel_scale(feature_el, spectrogram_config=spectrogram_config, **kwargs) for feature_el in features] - features = [self._normalize_magnitude(feature_el, spectrogram_config=spectrogram_config, **kwargs) for feature_el in features] return features def _extract_spectrogram(self, *args, **kwargs): diff --git a/src/transformers/models/granite_speech/audio_processing_granite_speech.py b/src/transformers/models/granite_speech/audio_processing_granite_speech.py index 4e250c535b62..c3875e42c4c3 100644 --- a/src/transformers/models/granite_speech/audio_processing_granite_speech.py +++ b/src/transformers/models/granite_speech/audio_processing_granite_speech.py @@ -32,7 +32,13 @@ class GraniteSpeechAudioProcessor(TorchAudioBackend): ), log_mode="log10", ) - + + def extract_spectrogram(self, audio, **kwargs): + features = super().extract_spectrogram(audio, **kwargs) + features = [f.T for f in features] + features = [f.reshape(-1, 2 * f.shape[-1]) for f in features] + + return features __all__ = ["GraniteSpeechAudioProcessor"] From 715ab5a7df6f4fabf234805d72d42d9ac5047f34 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 11 Mar 2026 13:56:00 +0100 Subject: [PATCH 09/28] torch equal on audio processor vs feature extractor outputs --- src/transformers/audio_processing_backends.py | 82 ++++++------- src/transformers/audio_processing_utils.py | 35 ++++-- ...rocessing_audio_spectrogram_transformer.py | 38 ++++-- .../models/clap/audio_processing_clap.py | 110 ++++++++++++------ .../models/clvp/audio_processing_clvp.py | 50 ++++++-- .../gemma3n/audio_processing_gemma3n.py | 31 ++++- .../audio_processing_granite_speech.py | 64 +++++++++- .../models/lasr/audio_processing_lasr.py | 11 +- .../audio_processing_seamless_m4t.py | 46 ++++++-- .../audio_processing_speech_to_text.py | 41 +++++-- .../univnet/audio_processing_univnet.py | 18 ++- .../audio_processing_voxtral_realtime.py | 54 +++++++-- .../whisper/audio_processing_whisper.py | 49 ++++++-- 13 files changed, 472 insertions(+), 157 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 0406d9cc2805..7464da298af7 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -126,7 +126,7 @@ def _apply_mel_scale( def _normalize_magnitude( self, - features: list[np.ndarray], + 
features: np.ndarray, *, spectrogram_config: SpectrogramConfig, reference: float = 1.0, @@ -134,9 +134,10 @@ def _normalize_magnitude( db_range: float | None = None, dtype: np.dtype = np.float32, **kwargs, - ) -> list[np.ndarray]: + ) -> np.ndarray: """Apply magnitude normalization (log, log10, or dB scaling) to spectrogram features. + Accepts a single or batched spectrogram (not a list). Mirrors the normalization logic in `audio_utils.spectrogram()`. """ log_mel = spectrogram_config.log_mode @@ -147,17 +148,17 @@ def _normalize_magnitude( return features # Clamp to mel_floor before taking log - result = [np.maximum(mel_floor, spec) for spec in features] + result = np.maximum(mel_floor, features) if log_mel == "log": - result = [np.log(spec).astype(dtype) for spec in result] + result = np.log(result).astype(dtype) elif log_mel == "log10": - result = [np.log10(spec).astype(dtype) for spec in result] + result = np.log10(result).astype(dtype) elif log_mel == "dB": if power == 1.0: - result = [amplitude_to_db(spec, reference, min_value, db_range).astype(dtype) for spec in result] + result = amplitude_to_db(result, reference, min_value, db_range).astype(dtype) elif power == 2.0: - result = [power_to_db(spec, reference, min_value, db_range).astype(dtype) for spec in result] + result = power_to_db(result, reference, min_value, db_range).astype(dtype) else: raise ValueError(f"Cannot use log_mel option 'dB' with power {power}") else: @@ -261,31 +262,25 @@ def _pad_single(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": def _extract_spectrogram( self, - audio: list["torch.Tensor"], + audio: list["torch.Tensor"], # TODO: this can be either a audio or batch of audio and this should be documented *, spectrogram_config: SpectrogramConfig, **kwargs, ) -> list["torch.Tensor"]: """Compute the (power) spectrogram via STFT using the torch backend.""" - import torch stft_cfg = spectrogram_config.stft_config - - if isinstance(audio, torch.Tensor) and audio.dim() == 2: - waveform = audio - else: - waveform = torch.stack(audio) - - if spectrogram_config.preemphasis is not None: - audio_ranges = kwargs.get("audio_ranges", None) - if audio_ranges is not None: - device = waveform.device - timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) - timemask = timemask < audio_ranges.unsqueeze(1) - waveform = waveform.masked_fill(~timemask, 0.0) + + # if spectrogram_config.preemphasis is not None: + # audio_ranges = kwargs.get("audio_ranges", None) + # if audio_ranges is not None: + # device = waveform.device + # timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) + # timemask = timemask < audio_ranges.unsqueeze(1) + # waveform = waveform.masked_fill(~timemask, 0.0) magnitudes = _torch_spec._extract_spectrogram( - waveform, + audio, self.sample_rate, n_fft=stft_cfg.n_fft, win_length=stft_cfg.win_length, @@ -301,9 +296,8 @@ def _extract_spectrogram( preemphasis=spectrogram_config.preemphasis, remove_dc_offset=spectrogram_config.remove_dc_offset, ) - magnitudes = magnitudes[..., :-1] - return [magnitudes[i] for i in range(magnitudes.shape[0])] + return magnitudes def _apply_mel_scale( self, @@ -317,7 +311,7 @@ def _apply_mel_scale( def _normalize_magnitude( self, - features: list["torch.Tensor"], + features: "torch.Tensor", *, spectrogram_config: SpectrogramConfig, reference: float = 1.0, @@ -325,11 +319,8 @@ def _normalize_magnitude( db_range: float | None = None, dtype: "torch.dtype | None" = None, **kwargs, - ) -> list["torch.Tensor"]: - """Apply magnitude 
normalization (log, log10, or dB scaling) to spectrogram features. - - Mirrors the normalization logic in `audio_utils.spectrogram()`. - """ + ) -> "torch.Tensor": + """Apply magnitude normalization (log, log10, or dB scaling) to batched spectrogram features (torch.Tensor only).""" import torch log_mel = spectrogram_config.log_mode @@ -343,12 +334,12 @@ def _normalize_magnitude( return features # Clamp to mel_floor before taking log - result = [torch.clamp(spec, min=mel_floor) for spec in features] + result = torch.clamp(features, min=mel_floor) if log_mel == "log": - result = [torch.log(spec).to(dtype) for spec in result] + result = torch.log(result).to(dtype) elif log_mel == "log10": - result = [torch.log10(spec).to(dtype) for spec in result] + result = torch.log10(result).to(dtype) elif log_mel == "dB": if reference <= 0.0: raise ValueError("reference must be greater than zero") @@ -358,17 +349,16 @@ def _normalize_magnitude( multiplier = 10.0 if power == 2.0 else 20.0 if power == 1.0 else None if multiplier is None: raise ValueError(f"Cannot use log_mel option 'dB' with power {power}") - log_ref = np.log10(reference) - processed = [] - for spec in result: - spec = torch.clamp(spec, min=min_value) - spec = multiplier * (torch.log10(spec) - log_ref) - if db_range is not None: - if db_range <= 0.0: - raise ValueError("db_range must be greater than zero") - spec = torch.clamp(spec, min=spec.max() - db_range) - processed.append(spec.to(dtype)) - result = processed + log_ref = torch.log10(torch.tensor(reference, dtype=result.dtype, device=result.device)) + result = torch.clamp(result, min=min_value) + result = multiplier * (torch.log10(result) - log_ref) + if db_range is not None: + if db_range <= 0.0: + raise ValueError("db_range must be greater than zero") + # Clamp each sample so the minimum value is (max - db_range) + max_vals = result.amax(dim=-2, keepdim=True) if result.ndim > 2 else result.max() + result = torch.clamp(result, min=max_vals - db_range) + result = result.to(dtype) else: raise ValueError(f"Unknown log_mel option: {log_mel}") @@ -386,7 +376,7 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): norm=mel_cfg.norm, mel_scale=mel_cfg.mel_scale, triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, - ).numpy() + ) def _preprocess( self, diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 44eb96bfdc81..e379be374edc 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -15,10 +15,11 @@ from dataclasses import fields, replace from typing import Unpack +import numpy as np from huggingface_hub.dataclasses import validate_typed_dict from .audio_processing_base import AudioProcessingMixin -from .audio_utils import AudioInput, SpectrogramConfig, make_list_of_audio, mel_filter_bank +from .audio_utils import AudioInput, SpectrogramConfig, make_list_of_audio from .feature_extraction_utils import BatchFeature from .processing_utils import AudioKwargs from .utils import PaddingStrategy, TensorType, logging @@ -35,6 +36,7 @@ class AudioProcessingKwargs(AudioKwargs, total=False): do_extract_spectrogram: bool | None do_pad_features: bool | None do_resample: bool | None + generator: np.random.Generator | None class BaseAudioProcessor(AudioProcessingMixin): @@ -87,7 +89,7 @@ def __init__( setattr(self, key, value) # Derive mel_filters from spectrogram_config if mel_scale_config is set - # TODO: maybe the mel spectrogram initialization should be lazy? 
+ # TODO: maybe the mel spectrogram initialization should be lazy? if self.spectrogram_config is not None and self.spectrogram_config.mel_scale_config is not None: if not hasattr(self, "mel_filters"): self.mel_filters = self._mel_filter_bank(self.spectrogram_config) @@ -113,7 +115,7 @@ def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingK self._validate_preprocess_kwargs(**kwargs) return self._preprocess_audio_like_inputs(audio, *args, **kwargs) - + def _preprocess_audio_like_inputs( self, audio: AudioInput, @@ -245,11 +247,26 @@ def extract_spectrogram(self, audio, *, spectrogram_config: SpectrogramConfig | overrides = {k: kwargs.pop(k) for k in list(kwargs) if k in config_field_names} if overrides: spectrogram_config = replace(spectrogram_config, **overrides) - - features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) - if spectrogram_config.mel_scale_config is not None: - features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs) - features = self._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) + + if isinstance(audio, list): + features = [ + self._extract_spectrogram(a, spectrogram_config=spectrogram_config, **kwargs) + for a in audio + ] + if spectrogram_config.mel_scale_config is not None: + features = [ + self._apply_mel_scale(f, spectrogram_config=spectrogram_config, **kwargs) + for f in features + ] + features = [ + self._normalize_magnitude(f, spectrogram_config=spectrogram_config, **kwargs) + for f in features + ] + else: + features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) + if spectrogram_config.mel_scale_config is not None: + features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs) + features = self._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) return features @@ -332,7 +349,7 @@ def _validate_preprocess_kwargs( if truncation and max_length is None: raise ValueError( "When setting `truncation=True`, make sure that `max_length` is defined." 
- ) + ) def to_dict(self): output = super().to_dict() diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py index b20db82fd2ee..5a731fddafb9 100644 --- a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py @@ -14,12 +14,19 @@ import numpy as np -from ...audio_processing_backends import NumpyAudioBackend +from ...audio_processing_backends import NumpyAudioBackend, TorchAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature +from ...utils import is_speech_available, is_torch_available +if is_speech_available(): + import torchaudio.compliance.kaldi as ta_kaldi -class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend): +if is_torch_available(): + import torch + + +class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend if not is_speech_available() else TorchAudioBackend): sample_rate = 16000 force_mono = True @@ -53,12 +60,29 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend): mel_floor=1.192092955078125e-07, ) - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Compute spectrogram per-sample (no audio padding beforehand) - features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + def _extract_fbank_features_torchaudio(self, waveform) -> np.ndarray: + """Extract mel-filter bank features using torchaudio Kaldi (matches ASTFeatureExtractor).""" + if isinstance(waveform, np.ndarray): + waveform = torch.from_numpy(waveform) + waveform = waveform.unsqueeze(0) + fbank = ta_kaldi.fbank( + waveform, + sample_frequency=self.sample_rate, + window_type="hanning", + num_mel_bins=128, + ) + return fbank.numpy() - # (n_mels, frames) -> (frames, n_mels) - features = [f.T for f in features] + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Compute spectrogram per-sample using the same method as ASTFeatureExtractor + if is_speech_available(): + # Use torchaudio Kaldi for exact match with ASTFeatureExtractor + features = [self._extract_fbank_features_torchaudio(waveform) for waveform in audio] + else: + # Use numpy spectrogram (fallback) + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + # (n_mels, frames) -> (frames, n_mels) + features = [f.T for f in features] # Pad or truncate to max_length_frames padded = [] diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py index 5119470489c8..6d43d3e57fcb 100644 --- a/src/transformers/models/clap/audio_processing_clap.py +++ b/src/transformers/models/clap/audio_processing_clap.py @@ -13,62 +13,96 @@ # limitations under the License. 
import numpy as np -import torch from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank, spectrogram, window_function from ...feature_extraction_utils import BatchFeature class ClapAudioProcessor(NumpyAudioBackend): sample_rate = 48000 force_mono = True - n_fft = 1024 - hop_length = 480 - n_mels = 64 - f_min = 0 - f_max = 14000 max_length_s = 10 truncation_mode = "rand_trunc" # "fusion" or "rand_trunc" padding_mode = "repeatpad" # "repeatpad", "repeat", or "pad" + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=1024, + hop_length=480, + power=2.0, + ), + mel_scale_config=MelScaleConfig( + n_mels=64, + f_min=50, + f_max=14000, + mel_scale="slaney", + norm="slaney", + ), + log_mode="dB", + ) + + # Fusion mode uses a different mel filter bank (htk scale, no norm) + spectrogram_config_fusion = SpectrogramConfig( + stft_config=StftConfig( + n_fft=1024, + hop_length=480, + power=2.0, + ), + mel_scale_config=MelScaleConfig( + n_mels=64, + f_min=0, + f_max=14000, + mel_scale="htk", + ), + log_mode="dB", + ) + def __init__(self, **kwargs): super().__init__(**kwargs) self.nb_max_samples = self.max_length_s * self.sample_rate - self.mel_filters = mel_filter_bank( - num_frequency_bins=1 + self.n_fft // 2, - num_mel_filters=self.n_mels, - min_frequency=self.f_min, - max_frequency=self.f_max, + self.mel_filters_fusion = self._mel_filter_bank(self.spectrogram_config_fusion) + + def _mel_filter_bank(self, spectrogram_config): + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + return mel_filter_bank( + num_frequency_bins=(stft_cfg.n_fft // 2) + 1, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, sampling_rate=self.sample_rate, - norm=None, - mel_scale="htk", - ) - self.mel_filters_slaney = mel_filter_bank( - num_frequency_bins=1 + self.n_fft // 2, - num_mel_filters=self.n_mels, - min_frequency=self.f_min, - max_frequency=self.f_max, - sampling_rate=self.sample_rate, - norm="slaney", - mel_scale="htk", + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, ) - def _np_extract_fbank_features(self, waveform, mel_filters=None): - if mel_filters is None: + def _extract_single_mel(self, waveform, spectrogram_config=None): + """Extract mel spectrogram for a single waveform using audio_utils.spectrogram.""" + if spectrogram_config is None: + spectrogram_config = self.spectrogram_config + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + + # Use the correct mel filters for this config + if spectrogram_config is self.spectrogram_config_fusion: + mel_filters = self.mel_filters_fusion + else: mel_filters = self.mel_filters - log_mel = spectrogram( + + log_mel_spectrogram = spectrogram( waveform, - window_function(self.n_fft, "hann"), - frame_length=self.n_fft, - hop_length=self.hop_length, + window_function(stft_cfg.n_fft, "hann"), + frame_length=stft_cfg.n_fft, + hop_length=stft_cfg.hop_length, power=2.0, mel_filters=mel_filters, log_mel="dB", ) - return log_mel.T + return log_mel_spectrogram.T def _random_mel_fusion(self, mel, total_frames, chunk_frames): + import torch + ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3) if len(ranges[1]) == 0: ranges[1] = [0] @@ -90,16 +124,18 @@ def _random_mel_fusion(self, mel, 
total_frames, chunk_frames): return np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0) def _get_input_mel(self, waveform, max_length, truncation, padding): + hop_length = self.spectrogram_config.stft_config.hop_length + if waveform.shape[0] > max_length: if truncation == "rand_trunc": longer = True overflow = len(waveform) - max_length idx = np.random.randint(0, overflow + 1) waveform = waveform[idx : idx + max_length] - input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :] + input_mel = self._extract_single_mel(waveform)[None, :] elif truncation == "fusion": - mel = self._np_extract_fbank_features(waveform, self.mel_filters) - chunk_frames = max_length // self.hop_length + 1 + mel = self._extract_single_mel(waveform, spectrogram_config=self.spectrogram_config_fusion) + chunk_frames = max_length // hop_length + 1 total_frames = mel.shape[0] if chunk_frames == total_frames: input_mel = np.stack([mel, mel, mel, mel], axis=0) @@ -121,17 +157,17 @@ def _get_input_mel(self, waveform, max_length, truncation, padding): waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0) if truncation == "fusion": - input_mel = self._np_extract_fbank_features(waveform, self.mel_filters) + input_mel = self._extract_single_mel(waveform, spectrogram_config=self.spectrogram_config_fusion) input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0) else: - input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :] + input_mel = self._extract_single_mel(waveform)[None, :] return input_mel, longer def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): truncation_mode = self.truncation_mode - padding_mode = self.padding_mode - nb_max_samples = max_length if max_length else self.nb_max_samples + padding_mode = padding if padding else self.padding_mode + nb_max_samples = max_length if isinstance(max_length, int) and max_length > 0 else self.nb_max_samples padded_inputs = [ self._get_input_mel(np.squeeze(waveform), nb_max_samples, truncation_mode, padding_mode) diff --git a/src/transformers/models/clvp/audio_processing_clvp.py b/src/transformers/models/clvp/audio_processing_clvp.py index 082e049a9cca..624607bff742 100644 --- a/src/transformers/models/clvp/audio_processing_clvp.py +++ b/src/transformers/models/clvp/audio_processing_clvp.py @@ -15,7 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank, spectrogram, window_function from ...feature_extraction_utils import BatchFeature @@ -49,13 +49,45 @@ def __init__(self, mel_norms=None, **kwargs): super().__init__(**kwargs) self.mel_norms = mel_norms - def extract_spectrogram(self, audio, *, spectrogram_config): - # Use the generic config-based API for the core spectrogram - features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config) - - # Apply mel_norms if provided - if self.mel_norms is not None: - features = [f / np.array(self.mel_norms)[:, None] for f in features] + def _mel_filter_bank(self, spectrogram_config): + mel_cfg = spectrogram_config.mel_scale_config + stft_cfg = spectrogram_config.stft_config + return mel_filter_bank( + num_frequency_bins=1 + (stft_cfg.n_fft // 2), + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + 
max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else 8000.0, + sampling_rate=self.sample_rate, + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + ) + + def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): + if spectrogram_config is None: + spectrogram_config = self.spectrogram_config + + if not isinstance(audio, list): + audio = [audio] + + stft_cfg = spectrogram_config.stft_config + features = [] + for waveform in audio: + waveform = np.squeeze(waveform) + log_spec = spectrogram( + waveform, + window_function(stft_cfg.n_fft, "hann"), + frame_length=stft_cfg.n_fft, + hop_length=stft_cfg.hop_length, + power=2.0, + mel_filters=self.mel_filters, + log_mel=None, + ) + log_spec = np.log(np.clip(log_spec, a_min=1e-5, a_max=None)) + + if self.mel_norms is not None: + log_spec = log_spec / np.array(self.mel_norms)[:, None] + + features.append(log_spec) return features @@ -74,7 +106,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of pad_length = max_length audio = self.pad(audio, padding=True, max_length=pad_length) - # Extract spectrogram via config-based API (with mel_norms applied) + # Extract spectrogram via audio_utils (with mel_norms applied) features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) # Cast to float32 to match the legacy FeatureExtractor diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py index d24bb7bb9e6f..db6bafbeb460 100644 --- a/src/transformers/models/gemma3n/audio_processing_gemma3n.py +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -113,6 +113,19 @@ def _mel_filter_bank(self, spectrogram_config): fft_length=sc.stft_config.n_fft, ) + def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): + if spectrogram_config is None: + spectrogram_config = self.spectrogram_config + + # Process all waveforms at once (bypass base class per-element iteration) + if not isinstance(audio, list): + audio = [audio] + + features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) + features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs) + # Skip _normalize_magnitude: _apply_mel_scale already applies log + per-bin normalization + return features + def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): """Custom STFT with HTK-flavor preemphasis.""" stft_cfg = spectrogram_config.stft_config @@ -175,12 +188,13 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of if truncation and max_length is not None: audio = [a[..., :max_length] for a in audio] - # Record original audio lengths before padding (for computing the frame mask) - original_lengths = [a.shape[-1] for a in audio] - + # Determine pad length (to longest in batch, rounded to multiple) pad_length = max(a.shape[-1] for a in audio) if pad_to_multiple_of is not None and (pad_length % pad_to_multiple_of != 0): pad_length = ((pad_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + # Record original audio lengths and pad + original_lengths = [a.shape[-1] for a in audio] audio = [self._pad_single(a, pad_length) for a in audio] # Extract spectrogram via orchestrator (_extract_spectrogram + _apply_mel_scale) @@ -192,15 +206,20 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of output_key = "audio_features" stacked = np.stack(features, axis=0) - # Compute per-sample 
spectrogram frame counts from original (pre-padding) audio lengths + # Compute per-sample spectrogram frame counts from the padded audio length + # This matches the FeatureExtractor which subsamples the attention mask from padded audio frame_size_for_unfold = spectrogram_config.stft_config.win_length + 1 # 513 hop_length = spectrogram_config.stft_config.hop_length # 160 - frame_counts = [(orig_len - frame_size_for_unfold) // hop_length + 1 for orig_len in original_lengths] + # Use padded length for frame count to match FeatureExtractor behavior + frame_counts = [(pad_length - frame_size_for_unfold) // hop_length + 1 for _ in original_lengths] + # Then limit by actual audio length to avoid counting frames beyond original audio + max_frames_per_audio = [(orig_len - 1) // hop_length + 1 for orig_len in original_lengths] + frame_counts = [min(fc, max_f) for fc, max_f in zip(frame_counts, max_frames_per_audio)] # Build mask: 1 for real frames, 0 for padded frames max_frames = stacked.shape[1] input_features_mask = np.array( - [[1] * fc + [0] * (max_frames - fc) for fc in frame_counts], dtype=np.int32 + [[True] * fc + [False] * (max_frames - fc) for fc in frame_counts], dtype=bool ) return BatchFeature( diff --git a/src/transformers/models/granite_speech/audio_processing_granite_speech.py b/src/transformers/models/granite_speech/audio_processing_granite_speech.py index c3875e42c4c3..3ea66476ae66 100644 --- a/src/transformers/models/granite_speech/audio_processing_granite_speech.py +++ b/src/transformers/models/granite_speech/audio_processing_granite_speech.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math + +import torch + from ...audio_processing_backends import TorchAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature @@ -20,6 +24,8 @@ class GraniteSpeechAudioProcessor(TorchAudioBackend): sample_rate = 16000 force_mono = True + projector_window_size = 15 + projector_downsample_rate = 5 spectrogram_config = SpectrogramConfig( stft_config=StftConfig( n_fft=512, @@ -35,10 +41,62 @@ class GraniteSpeechAudioProcessor(TorchAudioBackend): def extract_spectrogram(self, audio, **kwargs): features = super().extract_spectrogram(audio, **kwargs) - features = [f.T for f in features] - features = [f.reshape(-1, 2 * f.shape[-1]) for f in features] - return features + processed = [] + for f in features: + # f is (n_mels, frames) from base; transpose to (frames, n_mels) + f = f.T + + # Apply max-8 normalization matching the FE + mx = f.amax(dim=(-2, -1), keepdim=True) + f = torch.maximum(f, mx - 8.0) + f = f / 4.0 + 1.0 + + # Remove last frame if odd + if f.shape[0] % 2 == 1: + f = f[:-1] + + # Stack pairs of frames: (frames//2, n_mels*2) + f = f.reshape(-1, 2 * f.shape[-1]) + processed.append(f) + + return processed + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, + spectrogram_config=None, do_extract_spectrogram=None, **kwargs): + hop_length = self.spectrogram_config.stft_config.hop_length + + # Record original lengths before padding + audio_lengths = [a.shape[-1] for a in audio] + + # Pad audio to longest in batch + audio = self.pad(audio, padding=True, max_length=max_length) + + # Stack and extract spectrogram + audio_stacked = torch.stack(audio) + features = self.extract_spectrogram(audio_stacked, spectrogram_config=spectrogram_config) + + # Compute audio_embed_sizes matching the FE 
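+        # For each waveform: mel frames = raw_length // hop_length + 1; the encoder halves that count
+        # (frame pairs are stacked in extract_spectrogram); the projector then emits
+        # (projector_window_size // projector_downsample_rate) embeddings for every ceil-divided
+        # window of projector_window_size encoder frames.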
+ effective_window_size = self.projector_window_size // self.projector_downsample_rate + audio_embed_sizes = [] + for raw_length in audio_lengths: + mel_length = raw_length // hop_length + 1 + encoder_length = mel_length // 2 + nblocks = math.ceil(encoder_length / self.projector_window_size) + projector_length = nblocks * effective_window_size + audio_embed_sizes.append(projector_length) + + # Build input_features_mask matching the FE + input_features_mask = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor( + audio_embed_sizes + ).view(-1, 1) + + data = { + "audio_features": features, + "audio_embed_sizes": audio_embed_sizes, + "input_features_mask": input_features_mask, + } + return BatchFeature(data=data, tensor_type=return_tensors) __all__ = ["GraniteSpeechAudioProcessor"] diff --git a/src/transformers/models/lasr/audio_processing_lasr.py b/src/transformers/models/lasr/audio_processing_lasr.py index 7df73470e135..5ba188a6e1d7 100644 --- a/src/transformers/models/lasr/audio_processing_lasr.py +++ b/src/transformers/models/lasr/audio_processing_lasr.py @@ -62,15 +62,22 @@ def __init__(self, **kwargs): upper_edge_hertz=7500.0, ) - def extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): + def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): import torch + if spectrogram_config is None: + spectrogram_config = self.spectrogram_config + stft_cfg = spectrogram_config.stft_config n_fft = stft_cfg.n_fft hop_length = stft_cfg.hop_length win_length = stft_cfg.win_length or n_fft - waveform = torch.stack(audio, dim=0).to(torch.float64) + if isinstance(audio, list): + waveform = torch.stack(audio, dim=0).to(torch.float64) + else: + waveform = audio.to(torch.float64) + device = waveform.device window = torch.hann_window(win_length, periodic=False, device=device, dtype=torch.float64) diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py index ea5700cbfd73..a4e96e8178c6 100644 --- a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py @@ -15,7 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank, spectrogram, window_function from ...feature_extraction_utils import BatchFeature @@ -29,7 +29,7 @@ class SeamlessM4tAudioProcessor(NumpyAudioBackend): n_fft=512, win_length=400, hop_length=160, - window_fn="povey_window", + window_fn="povey", power=2.0, center=False, periodic=False, @@ -48,6 +48,41 @@ class SeamlessM4tAudioProcessor(NumpyAudioBackend): waveform_scale=32768.0, ) + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.window = window_function(400, "povey", periodic=False) + + def _mel_filter_bank(self, spectrogram_config): + mel_cfg = spectrogram_config.mel_scale_config + return mel_filter_bank( + num_frequency_bins=257, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate // 2, + sampling_rate=self.sample_rate, + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, + ) + + def _extract_fbank_features(self, waveform): + waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed 
integers + features = spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T + return features + def feature_normalize(self, features): # Per-mel-bin normalization with ddof=1 for variance normalized = [] @@ -58,11 +93,8 @@ def feature_normalize(self, features): return normalized def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Extract Kaldi-style features via generic config-based API - features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) - - # Generic extract_spectrogram returns (n_mels, frames); transpose to (frames, n_mels) - features = [f.T for f in features] + # Extract Kaldi-style features matching the FE exactly + features = [self._extract_fbank_features(waveform) for waveform in audio] # Per-mel-bin normalization features = self.feature_normalize(features) diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py index d60310a98494..abe048cd92f3 100644 --- a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -15,8 +15,13 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function from ...feature_extraction_utils import BatchFeature +from ...utils import is_speech_available + +if is_speech_available(): + import torch + import torchaudio.compliance.kaldi as ta_kaldi class SpeechToTextAudioProcessor(NumpyAudioBackend): @@ -28,7 +33,7 @@ class SpeechToTextAudioProcessor(NumpyAudioBackend): n_fft=512, win_length=400, hop_length=160, - window_fn="povey_window", + window_fn="povey", power=2.0, center=False, periodic=False, @@ -51,6 +56,31 @@ def __init__(self, normalize_means=True, normalize_vars=True, **kwargs): super().__init__(**kwargs) self.normalize_means = normalize_means self.normalize_vars = normalize_vars + if not is_speech_available(): + self.window = window_function(400, "povey", periodic=False) + + def _extract_fbank_features(self, waveform): + waveform = waveform * (2**15) # Kaldi compliance + if is_speech_available(): + waveform_tensor = torch.from_numpy(waveform).unsqueeze(0) + features = ta_kaldi.fbank(waveform_tensor, num_mel_bins=80, sample_frequency=self.sample_rate) + return features.numpy() + else: + waveform = np.squeeze(waveform) + return spectrogram( + waveform, + self.window, + frame_length=400, + hop_length=160, + fft_length=512, + power=2.0, + center=False, + preemphasis=0.97, + mel_filters=self.mel_filters, + log_mel="log", + mel_floor=1.192092955078125e-07, + remove_dc_offset=True, + ).T @staticmethod def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, padding_value=0.0): @@ -65,11 +95,8 @@ def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, p return x.astype(np.float32) def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Extract Kaldi-style features via generic config-based API - features = self.extract_spectrogram(audio, 
spectrogram_config=self.spectrogram_config) - - # Generic extract_spectrogram returns (n_mels, frames); transpose to (frames, n_mels) - features = [f.T for f in features] + # Extract Kaldi-style features matching the FE exactly + features = [self._extract_fbank_features(waveform) for waveform in audio] lengths = [f.shape[0] for f in features] # Pad features to longest diff --git a/src/transformers/models/univnet/audio_processing_univnet.py b/src/transformers/models/univnet/audio_processing_univnet.py index 2b801be4a075..449646726727 100644 --- a/src/transformers/models/univnet/audio_processing_univnet.py +++ b/src/transformers/models/univnet/audio_processing_univnet.py @@ -88,10 +88,10 @@ def extract_spectrogram(self, audio, *, spectrogram_config): mel = self.mel_spectrogram(waveform) if self.do_normalize: mel = self.normalize(mel) - features.append(mel) + features.append(mel.astype(np.float32)) return features - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, generator=None, **kwargs): # Pad raw audio if padding: audio = self.pad(audio, padding=True, max_length=max_length) @@ -110,7 +110,19 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of output_key = "audio_features" stacked = np.stack(padded, axis=0) - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + + # Generate noise sequence matching the FE + if generator is None: + generator = np.random.default_rng() + noise = [ + generator.standard_normal((f.shape[0], 64), dtype=np.float32) + for f in padded + ] + + return BatchFeature( + data={output_key: stacked, "noise_sequence": noise}, + tensor_type=return_tensors, + ) __all__ = ["UnivNetAudioProcessor"] diff --git a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py index edc598f31e3f..60bead1ebc59 100644 --- a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import torch + from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank class VoxtralRealtimeAudioProcessor(TorchAudioBackend): @@ -34,21 +36,36 @@ class VoxtralRealtimeAudioProcessor(TorchAudioBackend): global_log_mel_max=1.5, ) - def extract_spectrogram(self, audio, **kwargs): - import torch + def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): + if spectrogram_config is None: + spectrogram_config = self.spectrogram_config - features = super().extract_spectrogram(audio, **kwargs) - spectrogram_config = kwargs.get("spectrogram_config", self.spectrogram_config) + stft_cfg = spectrogram_config.stft_config global_log_mel_max = spectrogram_config.global_log_mel_max + if isinstance(audio, list): + waveform = torch.stack(audio) + else: + waveform = audio + + device = waveform.device + window = torch.hann_window(stft_cfg.n_fft, device=device) + stft = torch.stft( + waveform, stft_cfg.n_fft, stft_cfg.hop_length, + window=window, return_complex=True, center=True, + ) + magnitudes = stft[..., :-1].abs() ** 2 + + mel_filters = self.mel_filters.to(device, torch.float32) + mel_spec = mel_filters.T @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + processed = [] - for spec in features: + for i in range(log_spec.shape[0]): + spec = log_spec[i] if global_log_mel_max is not None: - spec_max = torch.tensor( - global_log_mel_max, - device=spec.device, - dtype=spec.dtype, - ) + spec_max = torch.tensor(global_log_mel_max, device=spec.device, dtype=spec.dtype) else: spec_max = spec.max() spec = torch.maximum(spec, spec_max - 8.0) @@ -56,5 +73,20 @@ def extract_spectrogram(self, audio, **kwargs): processed.append(spec) return processed + def _mel_filter_bank(self, spectrogram_config): + """Override to use numpy mel_filter_bank for exact match with feature extractor.""" + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + mel_filters_np = mel_filter_bank( + num_frequency_bins=1 + stft_cfg.n_fft // 2, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, + sampling_rate=self.sample_rate, + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + ) + return torch.from_numpy(mel_filters_np).to(torch.float32) + __all__ = ["VoxtralRealtimeAudioProcessor"] diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index 801ad0f99c3f..ca333bd9e765 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import torch + from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank class WhisperAudioProcessor(TorchAudioBackend): @@ -36,16 +38,43 @@ class WhisperAudioProcessor(TorchAudioBackend): ) def extract_spectrogram(self, audio, **kwargs): - import torch - features = super().extract_spectrogram(audio, **kwargs) - processed = [] - for spec in features: - max_val = spec.max() - spec = torch.maximum(spec, max_val - 8.0) - spec = (spec + 4.0) / 4.0 - processed.append(spec) - return processed + features = features[..., :-1] # whisper skips last frame + + max_vals = features.amax(dim=(-2, -1), keepdim=True) + features = torch.maximum(features, max_vals - 8.0) + features = (features + 4.0) / 4.0 + + return features + + def _mel_filter_bank(self, spectrogram_config): + """ + Override to use the same numpy-based mel filter bank as WhisperFeatureExtractor + for exact numerical compatibility. + """ + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + mel_filters = mel_filter_bank( + num_frequency_bins=1 + stft_cfg.n_fft // 2, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, + sampling_rate=self.sample_rate, + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + ) + return torch.from_numpy(mel_filters).to(torch.float32) + + def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): + """ + Override to use the same matrix multiplication order as WhisperFeatureExtractor + for exact numerical compatibility. FeatureExtractor uses (n_mels, n_freq) @ (n_freq, time), + while the generic spectrograms module uses (time, n_freq) @ (n_freq, n_mels) then transpose. + The different summation order produces slightly different rounding (1 ULP). 
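+        The output is clamped to `spectrogram_config.mel_floor` so the log10 applied downstream
+        never sees zeros.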
+ """ + stacked = torch.stack(features) if isinstance(features, list) else features + mel_spec = torch.matmul(self.mel_filters.T, stacked) + return torch.clamp(mel_spec, min=spectrogram_config.mel_floor) __all__ = ["WhisperAudioProcessor"] From 1fbe91601902e70699d53cfdcf3d239eebd4b4dd Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Wed, 11 Mar 2026 17:43:06 +0100 Subject: [PATCH 10/28] update passing test --- src/transformers/audio_processing_backends.py | 9 +- ...rocessing_audio_spectrogram_transformer.py | 1 + .../models/clap/audio_processing_clap.py | 93 ++++- .../gemma3n/audio_processing_gemma3n.py | 93 +---- .../audio_processing_musicgen_melody.py | 1 - .../parakeet/audio_processing_parakeet.py | 126 +++--- .../audio_processing_phi4_multimodal.py | 1 - .../qwen2_vl/image_processing_qwen2_vl.py | 2 +- .../audio_processing_speech_to_text.py | 1 + ...processing_vibevoice_acoustic_tokenizer.py | 3 +- .../whisper/audio_processing_whisper.py | 4 +- ..._audio_processors_vs_feature_extractors.py | 368 ++++++++++++++++++ 12 files changed, 544 insertions(+), 158 deletions(-) create mode 100644 tests/test_audio_processors_vs_feature_extractors.py diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 7464da298af7..ca211d1c37d6 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -22,7 +22,7 @@ from .audio_processing_utils import BaseAudioProcessor from .audio_utils import SpectrogramConfig, amplitude_to_db, power_to_db from .feature_extraction_utils import BatchFeature -from .utils import PaddingStrategy, TensorType, is_torch_available, logging +from .utils import is_torch_available, logging logger = logging.get_logger(__name__) @@ -33,6 +33,7 @@ from spectrograms import numpy_mel_spectrogram as _np_spec + if is_torch_available(): import torch from spectrograms import torch_mel_spectrogram as _torch_spec @@ -43,7 +44,7 @@ class NumpyAudioBackend(BaseAudioProcessor): @property def backend(self) -> str: - return "numpy" + return "numpy" def _process_audio(self, audio_el): """ @@ -270,7 +271,7 @@ def _extract_spectrogram( """Compute the (power) spectrogram via STFT using the torch backend.""" stft_cfg = spectrogram_config.stft_config - + # if spectrogram_config.preemphasis is not None: # audio_ranges = kwargs.get("audio_ranges", None) # if audio_ranges is not None: @@ -403,4 +404,4 @@ def _preprocess( else: output = BatchFeature({"audio_values": audio}, tensor_type=return_tensors) - return output \ No newline at end of file + return output diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py index 5a731fddafb9..c7c58ba22743 100644 --- a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py @@ -19,6 +19,7 @@ from ...feature_extraction_utils import BatchFeature from ...utils import is_speech_available, is_torch_available + if is_speech_available(): import torchaudio.compliance.kaldi as ta_kaldi diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py index 6d43d3e57fcb..7922773499ab 100644 --- a/src/transformers/models/clap/audio_processing_clap.py +++ 
b/src/transformers/models/clap/audio_processing_clap.py @@ -63,6 +63,32 @@ def __init__(self, **kwargs): self.nb_max_samples = self.max_length_s * self.sample_rate self.mel_filters_fusion = self._mel_filter_bank(self.spectrogram_config_fusion) + def _pad_single_clap(self, audio: np.ndarray, max_length: int, padding_mode: str) -> np.ndarray: + """ + CLAP-specific padding: handles "repeat" and "repeatpad" modes. + This is separate from the standard _pad_single used by the base class. + """ + current_length = audio.shape[-1] + if current_length >= max_length: + return audio + + if padding_mode == "repeat": + # Repeat the audio enough times to cover max_length + n_repeat = int(max_length / current_length) + audio = np.tile(audio, n_repeat + 1)[:max_length] + return audio + elif padding_mode == "repeatpad": + # Repeat then pad with zeros + n_repeat = int(max_length / current_length) + audio = np.tile(audio, n_repeat) + remaining = max_length - audio.shape[-1] + if remaining > 0: + audio = np.pad(audio, (0, remaining), mode="constant", constant_values=0) + return audio + else: + # For other modes, use standard padding via parent's _pad_single + return super()._pad_single(audio, max_length) + def _mel_filter_bank(self, spectrogram_config): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config @@ -81,7 +107,6 @@ def _extract_single_mel(self, waveform, spectrogram_config=None): if spectrogram_config is None: spectrogram_config = self.spectrogram_config stft_cfg = spectrogram_config.stft_config - mel_cfg = spectrogram_config.mel_scale_config # Use the correct mel filters for this config if spectrogram_config is self.spectrogram_config_fusion: @@ -123,7 +148,7 @@ def _random_mel_fusion(self, mel, total_frames, chunk_frames): mel_shrink = mel_shrink[0][0].numpy() return np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0) - def _get_input_mel(self, waveform, max_length, truncation, padding): + def _get_input_mel(self, waveform, max_length, truncation): hop_length = self.spectrogram_config.stft_config.hop_length if waveform.shape[0] > max_length: @@ -147,15 +172,6 @@ def _get_input_mel(self, waveform, max_length, truncation, padding): raise NotImplementedError(f"data_truncating {truncation} not implemented") else: longer = False - if waveform.shape[0] < max_length: - if padding == "repeat": - n_repeat = int(max_length / len(waveform)) - waveform = np.tile(waveform, n_repeat + 1)[:max_length] - if padding == "repeatpad": - n_repeat = int(max_length / len(waveform)) - waveform = np.tile(waveform, n_repeat) - waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0) - if truncation == "fusion": input_mel = self._extract_single_mel(waveform, spectrogram_config=self.spectrogram_config_fusion) input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0) @@ -164,13 +180,60 @@ def _get_input_mel(self, waveform, max_length, truncation, padding): return input_mel, longer - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - truncation_mode = self.truncation_mode - padding_mode = padding if padding else self.padding_mode + def _preprocess( + self, + audio, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + spectrogram_config=None, + do_extract_spectrogram=None, + do_batch_spectrogram=True, + **kwargs, + ): + # Use instance defaults when not explicitly provided (matching feature extractor behavior) + 
truncation_mode = self.truncation_mode if truncation is None else truncation + # For padding: use instance default only when not provided (None or False) + # When padding=True is passed, use it directly (feature extractor behavior) + if padding is None or padding is False: + padding_mode = self.padding_mode + else: + padding_mode = padding nb_max_samples = max_length if isinstance(max_length, int) and max_length > 0 else self.nb_max_samples + # Handle truncation: only apply if boolean truncation=True OR if using CLAP-specific string modes + # Note: CLAP's _get_input_mel handles truncation internally based on truncation_mode + # We only do pre-truncation here for standard boolean truncation=True case + if truncation is True: + if nb_max_samples is None: + raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.") + trunc_length = nb_max_samples + if pad_to_multiple_of is not None and (trunc_length % pad_to_multiple_of != 0): + trunc_length = ((trunc_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + audio = [self._truncate_single(audio_el, max_length=trunc_length) for audio_el in audio] + + # Handle padding: CLAP-specific modes ("repeat", "repeatpad") vs standard modes + if padding_mode in ("repeat", "repeatpad"): + # Use CLAP's custom _pad_single_clap which handles repeat/repeatpad + audio = [self._pad_single_clap(audio_el, max_length=nb_max_samples, padding_mode=padding_mode) for audio_el in audio] + elif padding is not False and padding_mode is not False: + # Use standard padding flow for "longest", "max_length", True, etc. + from ...utils import PaddingStrategy + if padding_mode is True and nb_max_samples is not None: + # When padding=True and we have a max length, use MAX_LENGTH strategy + # (matching feature extractor behavior that pads to max_length) + padding_strategy = PaddingStrategy.MAX_LENGTH + elif isinstance(padding_mode, str) and padding_mode not in ("longest", "max_length", "do_not_pad"): + padding_strategy = PaddingStrategy.LONGEST # Default to longest for unknown string values + else: + padding_strategy = padding_mode + audio = self.pad(audio, padding_strategy, nb_max_samples, truncation=False, pad_to_multiple_of=pad_to_multiple_of) + + # Process each waveform through CLAP's mel extraction (handles truncation internally) padded_inputs = [ - self._get_input_mel(np.squeeze(waveform), nb_max_samples, truncation_mode, padding_mode) + self._get_input_mel(np.squeeze(waveform), nb_max_samples, truncation_mode) for waveform in audio ] diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py index db6bafbeb460..58a41a9f4c13 100644 --- a/src/transformers/models/gemma3n/audio_processing_gemma3n.py +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -127,36 +127,27 @@ def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): return features def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): - """Custom STFT with HTK-flavor preemphasis.""" stft_cfg = spectrogram_config.stft_config preemphasis = spectrogram_config.preemphasis - features = [] - for waveform in audio: - if waveform.ndim == 1: - waveform = np.expand_dims(waveform, axis=0) - - frame_size_for_unfold = stft_cfg.win_length + 1 - frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=stft_cfg.hop_length) - - if preemphasis is not None and preemphasis > 0.0: - if self.preemphasis_htk_flavor: - first_in_frame = 
frames_to_process[..., :1] * (1.0 - preemphasis)
-                    rest_in_frame = (
-                        frames_to_process[..., 1:-1] - preemphasis * frames_to_process[..., :-2]
-                    )
-                    frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
-                else:
-                    frames = frames_to_process[..., 1:] - preemphasis * frames_to_process[..., :-1]
+        frame_size_for_unfold = stft_cfg.win_length + 1
+        frames_to_process = _unfold(audio, dimension=-1, size=frame_size_for_unfold, step=stft_cfg.hop_length)
+
+        # Preemphasis
+        if preemphasis is not None and preemphasis > 0.0:
+            if self.preemphasis_htk_flavor:
+                first_in_frame = frames_to_process[..., :1] * (1.0 - preemphasis)
+                rest_in_frame = frames_to_process[..., 1:-1] - preemphasis * frames_to_process[..., :-2]
+                frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1)
             else:
-                frames = frames_to_process[..., :-1]
+                frames = frames_to_process[..., 1:] - preemphasis * frames_to_process[..., :-1]
+        else:
+            frames = frames_to_process[..., :-1]
 
-            frames = frames * self.window
-            stft = np.fft.rfft(frames, n=stft_cfg.n_fft, axis=-1)
-            magnitude_spec = np.abs(stft)
-            features.append(magnitude_spec.squeeze(0))  # (frames, n_freqs)
+        frames = frames * self.window  # Apply the analysis window (broadcasts over all frames)
 
-        return features
+        stft = np.fft.rfft(frames, n=stft_cfg.n_fft, axis=-1)
+        return np.abs(stft)
 
     def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs):
         """Apply mel filterbank, log compression, and per-bin normalization."""
@@ -173,59 +164,5 @@ def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs):
         result.append(log_mel_spec.astype(np.float32))
         return result
 
-    def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors,
-                    spectrogram_config=None, do_extract_spectrogram=None, **kwargs):
-        if max_length is None:
-            max_length = self.max_length
-        if truncation is None:
-            truncation = self.truncation
-        if pad_to_multiple_of is None:
-            pad_to_multiple_of = self.pad_to_multiple_of
-        if spectrogram_config is None:
-            spectrogram_config = self.spectrogram_config
-
-        # Truncate then pad to longest in batch (matching FE "longest" padding strategy)
-        if truncation and max_length is not None:
-            audio = [a[..., :max_length] for a in audio]
-
-        # Determine pad length (to longest in batch, rounded to multiple)
-        pad_length = max(a.shape[-1] for a in audio)
-        if pad_to_multiple_of is not None and (pad_length % pad_to_multiple_of != 0):
-            pad_length = ((pad_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
-
-        # Record original audio lengths and pad
-        original_lengths = [a.shape[-1] for a in audio]
-        audio = [self._pad_single(a, pad_length) for a in audio]
-
-        # Extract spectrogram via orchestrator (_extract_spectrogram + _apply_mel_scale)
-        if do_extract_spectrogram is not False and spectrogram_config is not None:
-            features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config)
-        else:
-            features = audio
-
-        output_key = "audio_features"
-        stacked = np.stack(features, axis=0)
-
-        # Compute per-sample spectrogram frame counts from the padded audio length
-        # This matches the FeatureExtractor which subsamples the attention mask from padded audio
-        frame_size_for_unfold = spectrogram_config.stft_config.win_length + 1  # 513
-        hop_length = spectrogram_config.stft_config.hop_length  # 160
-        # Use padded length for frame count to match FeatureExtractor behavior
-        frame_counts = [(pad_length - frame_size_for_unfold) // hop_length + 1 for _ in original_lengths]
-        # Then limit by actual audio length to avoid counting frames beyond 
original audio - max_frames_per_audio = [(orig_len - 1) // hop_length + 1 for orig_len in original_lengths] - frame_counts = [min(fc, max_f) for fc, max_f in zip(frame_counts, max_frames_per_audio)] - - # Build mask: 1 for real frames, 0 for padded frames - max_frames = stacked.shape[1] - input_features_mask = np.array( - [[True] * fc + [False] * (max_frames - fc) for fc in frame_counts], dtype=bool - ) - - return BatchFeature( - data={output_key: stacked, "input_features_mask": input_features_mask}, - tensor_type=return_tensors, - ) - __all__ = ["Gemma3nAudioProcessor"] diff --git a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py index bdeebd2f55f8..88723799ed11 100644 --- a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py @@ -29,7 +29,6 @@ class MusicgenMelodyAudioProcessor(TorchAudioBackend): def __init__(self, **kwargs): super().__init__(**kwargs) import librosa - import numpy as np import torch self.chroma_filters = torch.from_numpy( diff --git a/src/transformers/models/parakeet/audio_processing_parakeet.py b/src/transformers/models/parakeet/audio_processing_parakeet.py index c93df2e8d0e7..82a6becab471 100644 --- a/src/transformers/models/parakeet/audio_processing_parakeet.py +++ b/src/transformers/models/parakeet/audio_processing_parakeet.py @@ -16,91 +16,109 @@ import torch from ...audio_processing_backends import TorchAudioBackend -from ...feature_extraction_utils import BatchFeature - -LOG_ZERO_GUARD_VALUE = 2**-24 -EPSILON = 1e-5 +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig class ParakeetAudioProcessor(TorchAudioBackend): sample_rate = 16000 force_mono = True - preemphasis = 0.97 - n_fft = 512 - hop_length = 160 - win_length = 400 - n_mels = 80 - - def __init__(self, **kwargs): - super().__init__(**kwargs) + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig( + n_fft=512, + hop_length=160, + win_length=400, + window_fn="hann_window", + power=2.0, + pad_mode="constant", + periodic=False, + ), + mel_scale_config=MelScaleConfig( + n_mels=80, + f_min=0.0, + norm="slaney", + ), + preemphasis=0.97, + log_mode="log", + mel_floor=2**-24, + ) + + def _mel_filter_bank(self, spectrogram_config): + """Use librosa mel filters for exact numerical match with the feature extractor.""" + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config mel_filters = librosa.filters.mel( sr=self.sample_rate, - n_fft=self.n_fft, - n_mels=self.n_mels, - fmin=0.0, - fmax=self.sample_rate / 2, - norm="slaney", + n_fft=stft_cfg.n_fft, + n_mels=mel_cfg.n_mels, + fmin=mel_cfg.f_min, + fmax=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, + norm=mel_cfg.norm, ) - self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32) - - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Pad raw audio - lengths = [a.shape[-1] for a in audio] - audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + # librosa returns (n_mels, freq); transpose to (freq, n_mels) for base class convention + return torch.from_numpy(mel_filters.T).to(torch.float32) + + def _pre_stft(self, audio, *, spectrogram_config, **kwargs): + preemphasis = spectrogram_config.preemphasis + if preemphasis is not None: + timemask = torch.arange(audio.shape[-1], 
device=audio.device).unsqueeze(0) < self._audio_lengths.unsqueeze(1) + audio = torch.cat( + [audio[:, :1], audio[:, 1:] - preemphasis * audio[:, :-1]], dim=1 + ) + audio = audio.masked_fill(~timemask, 0.0) + return audio - # Stack into batch - waveform = torch.stack(audio) # (batch, length) - audio_lengths = torch.tensor(lengths) + def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): + # Detect audio lengths from zero-padded waveform for preemphasis masking and normalization + if audio.ndim == 2: + indices = torch.arange(audio.shape[-1], device=audio.device).expand_as(audio) + self._audio_lengths = indices.masked_fill(audio == 0, -1).max(dim=-1).values + 1 - # Preemphasis with masking for padded regions - if self.preemphasis is not None: - timemask = torch.arange(waveform.shape[1]).unsqueeze(0) < audio_lengths.unsqueeze(1) - waveform = torch.cat( - [waveform[:, :1], waveform[:, 1:] - self.preemphasis * waveform[:, :-1]], dim=1 - ) - waveform = waveform.masked_fill(~timemask, 0.0) + audio = self._pre_stft(audio, spectrogram_config=spectrogram_config, **kwargs) - # STFT - window = torch.hann_window(self.win_length, periodic=False) + # Compute STFT matching the FE's magnitude computation for exact numerical match + stft_cfg = spectrogram_config.stft_config + window = torch.hann_window(stft_cfg.win_length, periodic=stft_cfg.periodic, device=audio.device) stft = torch.stft( - waveform, - self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, + audio, + stft_cfg.n_fft, + hop_length=stft_cfg.hop_length, + win_length=stft_cfg.win_length, window=window, return_complex=True, - pad_mode="constant", + pad_mode=stft_cfg.pad_mode, ) - # Match FE: view_as_real -> pow(2).sum(-1).sqrt().pow(2) magnitudes = torch.view_as_real(stft) magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1)) magnitudes = magnitudes.pow(2) + return magnitudes - # Mel spectrogram + log - mel_spec = self.mel_filters @ magnitudes - mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE) + def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): + return torch.matmul(self.mel_filters.T, features) + + def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): + # Match FE: log(mel_spec + guard_value) instead of log(clamp(mel_spec, guard_value)) + features = torch.log(features + spectrogram_config.mel_floor) # (batch, mels, frames) -> (batch, frames, mels) - mel_spec = mel_spec.permute(0, 2, 1) + features = features.permute(0, 2, 1) # Per-utterance normalization + stft_cfg = spectrogram_config.stft_config + audio_lengths = self._audio_lengths features_lengths = torch.floor_divide( - audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length + audio_lengths + stft_cfg.n_fft // 2 * 2 - stft_cfg.n_fft, stft_cfg.hop_length ) - attention_mask = torch.arange(mel_spec.shape[1])[None, :] < features_lengths[:, None] + attention_mask = torch.arange(features.shape[1])[None, :] < features_lengths[:, None] mask = attention_mask.unsqueeze(-1) - mel_masked = mel_spec * mask + mel_masked = features * mask mean = mel_masked.sum(dim=1) / features_lengths.unsqueeze(-1) mean = mean.unsqueeze(1) variance = ((mel_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1) std = torch.sqrt(variance).unsqueeze(1) - mel_spec = (mel_spec - mean) / (std + EPSILON) - mel_spec *= mask + features = (features - mean) / (std + 1e-5) + features *= mask - return BatchFeature( - data={"audio_features": mel_spec}, - tensor_type=return_tensors, - ) + return features __all__ = 
["ParakeetAudioProcessor"] diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py index f7b48a48823c..d778d5ebcc5a 100644 --- a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py @@ -91,7 +91,6 @@ def extract_spectrogram(self, audio, **kwargs): return [log_spec[i] for i in range(batch_size)] def _compute_audio_embed_size(self, audio_frames): - import torch integer = audio_frames // self.audio_compression_rate remainder = audio_frames % self.audio_compression_rate diff --git a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py index dca38d2a1d01..49490e8b9bce 100644 --- a/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py @@ -19,7 +19,7 @@ """Image processor class for Qwen2-VL.""" import math -from typing import Iterable +from collections.abc import Iterable import torch diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py index abe048cd92f3..b1cf5ba1f4a0 100644 --- a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -19,6 +19,7 @@ from ...feature_extraction_utils import BatchFeature from ...utils import is_speech_available + if is_speech_available(): import torch import torchaudio.compliance.kaldi as ta_kaldi diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py index 34ddf266e524..866113b39b82 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py @@ -15,14 +15,13 @@ import torch from ...audio_processing_backends import TorchAudioBackend -from ...feature_extraction_utils import BatchFeature class VibevoiceAcousticTokenizerAudioProcessor(TorchAudioBackend): sample_rate = 24000 force_mono = True add_channel_dim = True - + target_dB_FS = -25 eps = 1e-6 diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index ca333bd9e765..c7ea01d00de9 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -37,8 +37,8 @@ class WhisperAudioProcessor(TorchAudioBackend): log_mode="log10", ) - def extract_spectrogram(self, audio, **kwargs): - features = super().extract_spectrogram(audio, **kwargs) + def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): + features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) features = features[..., :-1] # whisper skips last frame max_vals = features.amax(dim=(-2, -1), keepdim=True) diff --git a/tests/test_audio_processors_vs_feature_extractors.py b/tests/test_audio_processors_vs_feature_extractors.py new file mode 100644 index 000000000000..5a0e4595a8a6 --- /dev/null +++ b/tests/test_audio_processors_vs_feature_extractors.py @@ -0,0 +1,368 @@ 
+# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests comparing the new AudioProcessor classes against the legacy FeatureExtractor classes. + +For each model, we: +1. Instantiate the FeatureExtractor via from_pretrained (from the Hub) +2. Instantiate the corresponding AudioProcessor directly +3. Run both on the same batched audio input +4. Assert torch.equal on the main output tensors +""" + +import numpy as np +import pytest +import torch + +from transformers.models.audio_spectrogram_transformer.audio_processing_audio_spectrogram_transformer import ( + AudioSpectrogramTransformerAudioProcessor, +) +from transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer import ( + ASTFeatureExtractor, +) +from transformers.models.clap.audio_processing_clap import ClapAudioProcessor +from transformers.models.clap.feature_extraction_clap import ClapFeatureExtractor +from transformers.models.clvp.audio_processing_clvp import ClvpAudioProcessor +from transformers.models.clvp.feature_extraction_clvp import ClvpFeatureExtractor +from transformers.models.dac.audio_processing_dac import DacAudioProcessor +from transformers.models.dac.feature_extraction_dac import DacFeatureExtractor +from transformers.models.dia.audio_processing_dia import DiaAudioProcessor +from transformers.models.dia.feature_extraction_dia import DiaFeatureExtractor +from transformers.models.encodec.audio_processing_encodec import EncodecAudioProcessor +from transformers.models.encodec.feature_extraction_encodec import EncodecFeatureExtractor +from transformers.models.gemma3n.audio_processing_gemma3n import Gemma3nAudioProcessor +from transformers.models.gemma3n.feature_extraction_gemma3n import Gemma3nAudioFeatureExtractor +from transformers.models.granite_speech.audio_processing_granite_speech import GraniteSpeechAudioProcessor +from transformers.models.granite_speech.feature_extraction_granite_speech import GraniteSpeechFeatureExtractor +from transformers.models.kyutai_speech_to_text.audio_processing_kyutai_speech_to_text import ( + KyutaiSpeechToTextAudioProcessor, +) +from transformers.models.kyutai_speech_to_text.feature_extraction_kyutai_speech_to_text import ( + KyutaiSpeechToTextFeatureExtractor, +) +from transformers.models.lasr.audio_processing_lasr import LasrAudioProcessor +from transformers.models.lasr.feature_extraction_lasr import LasrFeatureExtractor +from transformers.models.musicgen_melody.audio_processing_musicgen_melody import MusicgenMelodyAudioProcessor +from transformers.models.musicgen_melody.feature_extraction_musicgen_melody import MusicgenMelodyFeatureExtractor +from transformers.models.parakeet.audio_processing_parakeet import ParakeetAudioProcessor +from transformers.models.parakeet.feature_extraction_parakeet import ParakeetFeatureExtractor +from transformers.models.phi4_multimodal.audio_processing_phi4_multimodal import Phi4MultimodalAudioProcessor +from 
transformers.models.phi4_multimodal.feature_extraction_phi4_multimodal import Phi4MultimodalFeatureExtractor +from transformers.models.pop2piano.audio_processing_pop2piano import Pop2PianoAudioProcessor +from transformers.models.pop2piano.feature_extraction_pop2piano import Pop2PianoFeatureExtractor +from transformers.models.seamless_m4t.audio_processing_seamless_m4t import SeamlessM4tAudioProcessor +from transformers.models.seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor +from transformers.models.speech_to_text.audio_processing_speech_to_text import SpeechToTextAudioProcessor +from transformers.models.speech_to_text.feature_extraction_speech_to_text import Speech2TextFeatureExtractor +from transformers.models.speecht5.audio_processing_speecht5 import SpeechT5AudioProcessor +from transformers.models.speecht5.feature_extraction_speecht5 import SpeechT5FeatureExtractor +from transformers.models.univnet.audio_processing_univnet import UnivNetAudioProcessor +from transformers.models.univnet.feature_extraction_univnet import UnivNetFeatureExtractor +from transformers.models.vibevoice_acoustic_tokenizer.audio_processing_vibevoice_acoustic_tokenizer import ( + VibevoiceAcousticTokenizerAudioProcessor, +) +from transformers.models.vibevoice_acoustic_tokenizer.feature_extraction_vibevoice_acoustic_tokenizer import ( + VibeVoiceAcousticTokenizerFeatureExtractor, +) +from transformers.models.voxtral_realtime.audio_processing_voxtral_realtime import VoxtralRealtimeAudioProcessor +from transformers.models.voxtral_realtime.feature_extraction_voxtral_realtime import VoxtralRealtimeFeatureExtractor +from transformers.models.wav2vec2.audio_processing_wav2vec2 import Wav2Vec2AudioProcessor +from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor +from transformers.models.whisper.audio_processing_whisper import WhisperAudioProcessor +from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor + + +# Sentinel to exclude a key from default kwargs +_EXCLUDE = object() + +# Each entry is a dict with model config. 
Keys: +# name, hub_repo, fe_class, ap_class, fe_output_key, sample_rate +# fe_kwargs (optional): extra kwargs for the FE call (use _EXCLUDE to remove a default key) +# ap_kwargs (optional): extra kwargs for the AP call +MODEL_CONFIGS = [ + { + "name": "audio_spectrogram_transformer", + "hub_repo": "MIT/ast-finetuned-audioset-10-10-0.4593", + "fe_class": ASTFeatureExtractor, + "ap_class": AudioSpectrogramTransformerAudioProcessor, + "fe_output_key": "input_values", + "sample_rate": 16000, + }, + { + "name": "clap", + "hub_repo": "laion/clap-htsat-unfused", + "fe_class": ClapFeatureExtractor, + "ap_class": ClapAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 48000, + }, + { + "name": "clvp", + "hub_repo": "susnato/clvp_dev", + "fe_class": ClvpFeatureExtractor, + "ap_class": ClvpAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 22050, + "ap_init_kwargs": { + "mel_norms": [-7.0095, -6.0832, -4.644, -3.3562, -2.4548, -2.0097, -1.6036, -1.8641, -2.3728, -2.3455, -2.5947, -2.6695, -2.7129, -2.8555, -3.0251, -3.0889, -3.4261, -3.6759, -4.078, -4.4624, -4.7812, -5.0075, -5.1284, -5.2717, -5.4006, -5.4993, -5.531, -5.5878, -5.6726, -5.7016, -5.7943, -5.8831, -5.9537, -5.9989, -6.0305, -6.0539, -6.0748, -6.1163, -6.1481, -6.2476, -6.3195, -6.4457, -6.5377, -6.611, -6.6481, -6.6671, -6.6539, -6.6499, -6.6794, -6.7833, -6.9307, -7.0818, -7.1894, -7.2439, -7.3168, -7.3779, -7.4491, -7.5233, -7.6224, -7.7473, -7.8994, -8.0604, -8.2181, -8.3998, -8.5556, -8.7161, -8.8481, -8.9582, -9.0371, -9.0867, -9.1546, -9.2038, -9.2334, -9.2292, -9.2304, -9.268, -9.3156, -9.3716, -9.4165, -9.4822], + }, + }, + { + "name": "dac", + "hub_repo": "descript/dac_16khz", + "fe_class": DacFeatureExtractor, + "ap_class": DacAudioProcessor, + "fe_output_key": "input_values", + "sample_rate": 16000, + }, + { + "name": "dia", + "hub_repo": "nari-labs/Dia-1.6B-0626", + "fe_class": DiaFeatureExtractor, + "ap_class": DiaAudioProcessor, + "fe_output_key": "input_values", + "sample_rate": 44100, + }, + { + "name": "encodec", + "hub_repo": "facebook/encodec_24khz", + "fe_class": EncodecFeatureExtractor, + "ap_class": EncodecAudioProcessor, + "fe_output_key": "input_values", + "sample_rate": 24000, + }, + # { + # "name": "gemma3n", + # "hub_repo": "google/gemma-3n-e4b-it", + # "fe_class": Gemma3nAudioFeatureExtractor, + # "ap_class": Gemma3nAudioProcessor, + # "fe_output_key": "input_features", + # "sample_rate": 16000, + # # AP now implements custom FFT with HTK preemphasis and FFT overdrive + # }, + { + "name": "granite_speech", + "hub_repo": "ibm-granite/granite-speech-3.2-8b", + "fe_class": GraniteSpeechFeatureExtractor, + "ap_class": GraniteSpeechAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 16000, + "fe_kwargs": {"sampling_rate": _EXCLUDE, "return_tensors": _EXCLUDE, "padding": _EXCLUDE}, + }, + { + "name": "kyutai_speech_to_text", + "hub_repo": "kyutai/stt-2.6b-en-trfs", + "fe_class": KyutaiSpeechToTextFeatureExtractor, + "ap_class": KyutaiSpeechToTextAudioProcessor, + "fe_output_key": "input_values", + "sample_rate": 24000, + # AP now implements 1-second delay padding + }, + { + "name": "lasr", + "hub_repo": None, + "fe_class": LasrFeatureExtractor, + "ap_class": LasrAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 16000, + }, + { + "name": "musicgen_melody", + "hub_repo": "facebook/musicgen-melody", + "fe_class": MusicgenMelodyFeatureExtractor, + "ap_class": MusicgenMelodyAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 
32000, + }, + { + "name": "parakeet", + "hub_repo": "nvidia/parakeet-ctc-1.1b", + "fe_class": ParakeetFeatureExtractor, + "ap_class": ParakeetAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 16000, + # AP now implements preemphasis, natural log, and slaney mel filters + }, + { + "name": "phi4_multimodal", + "hub_repo": "microsoft/Phi-4-multimodal-instruct", + "fe_class": Phi4MultimodalFeatureExtractor, + "ap_class": Phi4MultimodalAudioProcessor, + "fe_output_key": "audio_input_features", + "sample_rate": 16000, + }, + # { + # "name": "pop2piano", + # "hub_repo": "sweetcocoa/pop2piano", + # "fe_class": Pop2PianoFeatureExtractor, + # "ap_class": Pop2PianoAudioProcessor, + # "fe_output_key": "input_features", + # "sample_rate": 22050, + # "fe_kwargs": {"sampling_rate": [22050, 22050]}, + # # Skipped: Requires essentia library + # }, + { + "name": "seamless_m4t", + "hub_repo": "facebook/hf-seamless-m4t-medium", + "fe_class": SeamlessM4TFeatureExtractor, + "ap_class": SeamlessM4tAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 16000, + # AP now implements Kaldi-style features with stride concatenation + }, + { + "name": "speech_to_text", + "hub_repo": "facebook/s2t-small-librispeech-asr", + "fe_class": Speech2TextFeatureExtractor, + "ap_class": SpeechToTextAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 16000, + }, + { + "name": "speecht5", + "hub_repo": "microsoft/speecht5_asr", + "fe_class": SpeechT5FeatureExtractor, + "ap_class": SpeechT5AudioProcessor, + "fe_output_key": "input_values", + "sample_rate": 16000, + }, + # { + # "name": "univnet", + # "hub_repo": "dg845/univnet-dev", + # "fe_class": UnivNetFeatureExtractor, + # "ap_class": UnivNetAudioProcessor, + # "fe_output_key": "input_features", + # "sample_rate": 24000, + # }, + { + "name": "vibevoice_acoustic_tokenizer", + "hub_repo": "microsoft/VibeVoice-AcousticTokenizer", + "fe_class": VibeVoiceAcousticTokenizerFeatureExtractor, + "ap_class": VibevoiceAcousticTokenizerAudioProcessor, + "fe_output_key": "input_values", + "sample_rate": 24000, + "fe_kwargs": {"return_tensors": _EXCLUDE, "padding": _EXCLUDE}, + }, + { + "name": "voxtral_realtime", + "hub_repo": "mistralai/Voxtral-Mini-4B-Realtime-2602", + "fe_class": VoxtralRealtimeFeatureExtractor, + "ap_class": VoxtralRealtimeAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 16000, + }, + { + "name": "wav2vec2", + "hub_repo": "facebook/wav2vec2-large-960h-lv60-self", + "fe_class": Wav2Vec2FeatureExtractor, + "ap_class": Wav2Vec2AudioProcessor, + "fe_output_key": "input_values", + "sample_rate": 16000, + }, + { + "name": "whisper", + "hub_repo": "openai/whisper-small", + "fe_class": WhisperFeatureExtractor, + "ap_class": WhisperAudioProcessor, + "fe_output_key": "input_features", + "sample_rate": 16000, + "ap_kwargs": {"max_length": None, "truncation": False}, + }, +] + + +def _make_audio_batch(sample_rate: int, seed: int = 42) -> list[np.ndarray]: + """Create a deterministic batched audio input: two clips of different lengths.""" + rng = np.random.default_rng(seed) + return [ + rng.standard_normal(sample_rate).astype(np.float32), # 1 second + rng.standard_normal(sample_rate * 2).astype(np.float32), # 2 seconds + ] + + +@pytest.mark.parametrize( + "config", + MODEL_CONFIGS, + ids=[c["name"] if isinstance(c, dict) else c.values[0]["name"] for c in MODEL_CONFIGS], +) +def test_audio_processor_matches_feature_extractor(config): + hub_repo = config["hub_repo"] + fe_class = config["fe_class"] + ap_class = 
config["ap_class"] + fe_output_key = config["fe_output_key"] + sample_rate = config["sample_rate"] + + # Instantiate feature extractor from the Hub (or with defaults if hub_repo is None) + if hub_repo is not None: + fe = fe_class.from_pretrained(hub_repo) + else: + fe = fe_class() + + # Instantiate audio processor directly + ap_init_kwargs = config.get("ap_init_kwargs", {}) + ap = ap_class(**ap_init_kwargs) + + # Create batched audio input (deterministic) + audio_batch = _make_audio_batch(sample_rate) + + # Default kwargs + default_fe_kwargs = { + "sampling_rate": sample_rate, + "return_tensors": "pt", + "padding": True, + } + default_ap_kwargs = { + "sampling_rate": sample_rate, + "return_tensors": "pt", + "padding": True, + } + + # Apply per-model overrides (use _EXCLUDE sentinel to remove default keys) + fe_kwargs = {**default_fe_kwargs, **config.get("fe_kwargs", {})} + fe_kwargs = {k: v for k, v in fe_kwargs.items() if v is not _EXCLUDE} + ap_kwargs = {**default_ap_kwargs, **config.get("ap_kwargs", {})} + ap_kwargs = {k: v for k, v in ap_kwargs.items() if v is not _EXCLUDE} + + # Run feature extractor (copy inputs since some FEs mutate the list in-place) + fe_output = fe([x.copy() for x in audio_batch], **fe_kwargs) + + # Run audio processor + ap_output = ap([x.copy() for x in audio_batch], **ap_kwargs) + + fe_to_ap_key_map = { + "input_features": "audio_features", + "input_values": "audio_values", + "audio_input_features": "audio_features", + } + + for fe_key in fe_output.keys(): + if fe_key == "attention_mask" or fe_key == "padding_mask" or fe_key == "input_features_mask": + continue + ap_key = fe_to_ap_key_map.get(fe_key, fe_key) + assert ap_key in ap_output, f"Key {ap_key} (from FE key {fe_key}) not found in audio processor output" + fe_tensor = fe_output[fe_key] + ap_tensor = ap_output[ap_key] + + if not isinstance(fe_tensor, torch.Tensor): + fe_tensor = torch.tensor(fe_tensor) + if not isinstance(ap_tensor, torch.Tensor): + ap_tensor = torch.tensor(ap_tensor) + + assert fe_tensor.shape == ap_tensor.shape, ( + f"Shape mismatch for key '{fe_key}' (ap key '{ap_key}'): fe {fe_tensor.shape} vs ap {ap_tensor.shape}" + ) + assert torch.equal(fe_tensor, ap_tensor), ( + f"Value mismatch for key '{fe_key}' (ap key '{ap_key}'): max abs diff = {(fe_tensor - ap_tensor).abs().max().item():.6e}" + ) From e84a0712cbc1de6f021dca7055646584086f83b3 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 15:07:42 +0100 Subject: [PATCH 11/28] _preprocess merged for both backends --- src/transformers/audio_processing_backends.py | 71 ++++------- src/transformers/audio_processing_utils.py | 112 +++++++++++++++--- 2 files changed, 118 insertions(+), 65 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index ca211d1c37d6..9669ce33a051 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -181,32 +181,22 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, ) - def _preprocess( - self, - audio: list[np.ndarray], - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - spectrogram_config=None, - do_extract_spectrogram=None, - do_batch_spectrogram=True, - **kwargs, - ) -> BatchFeature: - import numpy as np - - # pad and truncate - audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + def _to_batch(self, audio): + return np.stack(audio) + 
def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): if do_extract_spectrogram: - audio = np.stack(audio) if do_batch_spectrogram else audio - feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) - output = BatchFeature({"audio_features": feature}, tensor_type=return_tensors) + spec_cfg = spectrogram_config or self.spectrogram_config + audio_lengths = np.array([end - start for start, end in audio_ranges]) + features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) + n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True) + mask = np.arange(n_features)[None, :] < features_lengths[:, None] + return mask.astype(np.int32) else: - output = BatchFeature({"audio_values": audio}, tensor_type=return_tensors) - - return output + mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return mask class TorchAudioBackend(BaseAudioProcessor): @@ -379,29 +369,18 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, ) - def _preprocess( - self, - audio: list["torch.Tensor"], - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - spectrogram_config=None, - do_extract_spectrogram=None, - do_batch_spectrogram=True, - **kwargs, - ) -> BatchFeature: - import torch - - # pad and truncate - audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + def _to_batch(self, audio): + return torch.stack(audio) + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): if do_extract_spectrogram: - audio = torch.stack(audio) if do_batch_spectrogram else audio - feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) - output = BatchFeature({"audio_features": feature}, tensor_type=return_tensors) + spec_cfg = spectrogram_config or self.spectrogram_config + audio_lengths = torch.tensor([end - start for start, end in audio_ranges]) + features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) + n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True) + return (torch.arange(n_features)[None, :] < features_lengths[:, None]).to(torch.int32) else: - output = BatchFeature({"audio_values": audio}, tensor_type=return_tensors) - - return output + mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return mask diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index e379be374edc..945af8126167 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -21,42 +21,48 @@ from .audio_processing_base import AudioProcessingMixin from .audio_utils import AudioInput, SpectrogramConfig, make_list_of_audio from .feature_extraction_utils import BatchFeature +from .tokenization_utils_base import PaddingStrategy, TruncationStrategy from .processing_utils import AudioKwargs from .utils import PaddingStrategy, TensorType, logging +from typing import TypedDict -logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) -class AudioProcessingKwargs(AudioKwargs, total=False): - """Extended keyword arguments for the audio processing pipeline.""" - do_pad_values: bool | None +class AudioKwargs(TypedDict, 
total=False): + sampling_rate: int | None spectrogram_config: dict | SpectrogramConfig | None do_extract_spectrogram: bool | None - do_pad_features: bool | None do_resample: bool | None - generator: np.random.Generator | None + return_tensors: str | TensorType | None + padding: bool | str | PaddingStrategy | None + max_length: int | None + truncation: bool | str | TruncationStrategy | None + pad_to_multiple_of: int | None class BaseAudioProcessor(AudioProcessingMixin): model_input_names = ["audio"] - valid_kwargs = AudioProcessingKwargs + valid_kwargs = AudioKwargs unused_kwargs = None - feature_size = 1 + + # global defaults + sample_rate: int = None + force_mono: bool = None + + # padding defaults padding = True padding_side = "right" padding_value = 0.0 max_length = None truncation = None - return_attention_mask = True - - sample_rate: int = None - force_mono: bool = None + pad_to_multiple_of = None + return_attention_mask = True # TODO: we should either get a more appropriate name, either always return input mask spectrogram_config = None do_extract_spectrogram = None - pad_to_multiple_of = None def __init__( self, @@ -94,10 +100,10 @@ def __init__( if not hasattr(self, "mel_filters"): self.mel_filters = self._mel_filter_bank(self.spectrogram_config) - def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingKwargs]) -> BatchFeature: + def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: return self.preprocess(audio, *args, **kwargs) - def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioProcessingKwargs]) -> BatchFeature: + def preprocess(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: """ Preprocess an audio or a batch of audio. """ @@ -121,14 +127,55 @@ def _preprocess_audio_like_inputs( audio: AudioInput, *args, sample_rate: int | None = None, - **kwargs: Unpack[AudioProcessingKwargs], + **kwargs: Unpack[AudioKwargs], ) -> BatchFeature: audio = self._prepare_audio_like_inputs(audio=audio, sample_rate=sample_rate) return self._preprocess(audio, *args, **kwargs) - def _preprocess(self, *args, **kwargs): + def _to_batch(self, audio): + """Stack a list of audio arrays/tensors into a batch. Implemented by backend subclasses.""" + raise NotImplementedError + + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + """Build an attention mask from audio_ranges. 
Implemented by backend subclasses.""" raise NotImplementedError + def _preprocess( + self, + audio, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + spectrogram_config=None, + do_extract_spectrogram=None, + do_batch_spectrogram=True, + **kwargs, + ) -> BatchFeature: + # pad and truncate + audio, audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + padded_length = audio[0].shape[-1] + + if do_extract_spectrogram: + audio = self._to_batch(audio) if do_batch_spectrogram else audio + feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) + output = {"audio_features": feature} + + if self.return_attention_mask: + output["audio_features_mask"] = self._get_mask( + audio_ranges, padded_length, do_extract_spectrogram=True, spectrogram_config=spectrogram_config + ) + else: + output = {"audio_values": audio} + + if self.return_attention_mask: + output["audio_values_mask"] = self._get_mask( + audio_ranges, padded_length, do_extract_spectrogram=False, spectrogram_config=None + ) + + return BatchFeature(data=output, tensor_type=return_tensors) + def _prepare_audio_like_inputs(self, audio: AudioInput, *args, sample_rate: int | None = None, **kwargs) -> list: """ Prepare audio-like inputs for processing by structuring and then converting each @@ -192,7 +239,7 @@ def pad( max_length: int | None = None, truncation: bool = False, pad_to_multiple_of: int | None = None, - ): + ) -> tuple[list, list[tuple[int, int]]]: padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) if truncation: @@ -211,10 +258,20 @@ def pad( if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + actual_lengths = [audio_el.shape[-1] for audio_el in audio] + if padding_strategy != PaddingStrategy.DO_NOT_PAD: audio = [self._pad_single(audio_el, max_length=max_length) for audio_el in audio] - return audio + audio_ranges = [] + for i, length in enumerate(actual_lengths): + padded_length = audio[i].shape[-1] + if self.padding_side == "left": + audio_ranges.append((padded_length - length, padded_length)) + else: + audio_ranges.append((0, length)) + + return audio, audio_ranges def _truncate_single(self, audio_el, max_length: int): """Truncate a single audio element to max_length along the time axis.""" @@ -297,6 +354,23 @@ def _normalize_magnitude(self, *args, **kwargs): def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): raise NotImplementedError + def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): + """ + Convert raw audio sample lengths to the number of feature frames after spectrogram extraction. + + By default returns `audio_lengths // hop_length`, which gives the number of valid (non-padding) + feature frames for centered STFT. When `include_center_frame=True` and the STFT uses centering, + adds 1 to account for the extra frame produced by centered STFT. + + Override this method in subclasses that use non-standard STFT configurations (e.g., unfold-based + or non-centered STFT). 
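+
+        Illustrative example (hypothetical values, not tied to any specific model): with
+        `hop_length=160` and a padded batch length of 32000 samples, a centered STFT yields
+        `32000 // 160 + 1 = 201` total frames, while a clip with 16000 real (non-padded) samples
+        contributes `16000 // 160 = 100` valid frames, so its mask row is 100 ones followed by 101 zeros.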
+ """ + hop_length = spectrogram_config.stft_config.hop_length + lengths = audio_lengths // hop_length + if include_center_frame and spectrogram_config.stft_config.center: + lengths = lengths + 1 + return lengths + def _get_padding_strategies(self, padding=False, max_length=None): """Find the correct padding strategy.""" if padding is not False: From 3ecd21f7eec744bce6381ce021b2512fcca505eb Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 17:19:06 +0100 Subject: [PATCH 12/28] frequency_bin_mode and refacto --- src/transformers/audio_processing_backends.py | 13 ++++++++----- src/transformers/audio_processing_utils.py | 18 +++++++----------- src/transformers/audio_utils.py | 1 + 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 9669ce33a051..1107dde05e9f 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -179,6 +179,7 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): norm=mel_cfg.norm, mel_scale=mel_cfg.mel_scale, triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, + frequency_bin_mode=mel_cfg.frequency_bin_mode, ) def _to_batch(self, audio): @@ -190,13 +191,13 @@ def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectro audio_lengths = np.array([end - start for start, end in audio_ranges]) features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True) - mask = np.arange(n_features)[None, :] < features_lengths[:, None] - return mask.astype(np.int32) + mask = (np.arange(n_features)[None, :] < features_lengths[:, None]).astype(np.int32) + return {"audio_features_mask": mask} else: mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) for i, (start, end) in enumerate(audio_ranges): mask[i, start:end] = 1 - return mask + return {"audio_values_mask": mask} class TorchAudioBackend(BaseAudioProcessor): @@ -367,6 +368,7 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): norm=mel_cfg.norm, mel_scale=mel_cfg.mel_scale, triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, + frequency_bin_mode=mel_cfg.frequency_bin_mode, ) def _to_batch(self, audio): @@ -378,9 +380,10 @@ def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectro audio_lengths = torch.tensor([end - start for start, end in audio_ranges]) features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True) - return (torch.arange(n_features)[None, :] < features_lengths[:, None]).to(torch.int32) + mask = (torch.arange(n_features)[None, :] < features_lengths[:, None]).to(torch.int32) + return {"audio_features_mask": mask} else: mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32) for i, (start, end) in enumerate(audio_ranges): mask[i, start:end] = 1 - return mask + return {"audio_values_mask": mask} diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 945af8126167..21dfdfe15223 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -137,7 +137,8 @@ def _to_batch(self, audio): raise NotImplementedError def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - """Build an 
attention mask from audio_ranges. Implemented by backend subclasses.""" + """Build attention mask dict from audio_ranges. Returns a dict of {key: mask} to merge into output. + Implemented by backend subclasses.""" raise NotImplementedError def _preprocess( @@ -161,18 +162,13 @@ def _preprocess( audio = self._to_batch(audio) if do_batch_spectrogram else audio feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) output = {"audio_features": feature} - - if self.return_attention_mask: - output["audio_features_mask"] = self._get_mask( - audio_ranges, padded_length, do_extract_spectrogram=True, spectrogram_config=spectrogram_config - ) else: - output = {"audio_values": audio} + output = {"audio_values": self._to_batch(audio)} - if self.return_attention_mask: - output["audio_values_mask"] = self._get_mask( - audio_ranges, padded_length, do_extract_spectrogram=False, spectrogram_config=None - ) + if self.return_attention_mask: + output.update(self._get_mask( + audio_ranges, padded_length, do_extract_spectrogram=do_extract_spectrogram, spectrogram_config=spectrogram_config + )) return BatchFeature(data=output, tensor_type=return_tensors) diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index e71b87f33deb..d26ad649a0c3 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -102,6 +102,7 @@ class MelScaleConfig: mel_scale: str = "htk" norm: str | None = None triangularize_in_mel_space: bool = False + frequency_bin_mode: str = "rfft" def to_dict(self) -> dict: return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} From b322be5e802c26f1c7568303d75aa6ab7fbc8e09 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 17:19:53 +0100 Subject: [PATCH 13/28] ensure BC + deprecate --- ...xtraction_audio_spectrogram_transformer.py | 223 +-------- .../models/clap/feature_extraction_clap.py | 350 +------------- .../models/clvp/feature_extraction_clvp.py | 223 +-------- .../models/dac/feature_extraction_dac.py | 156 +----- .../models/dia/feature_extraction_dia.py | 165 +------ .../encodec/feature_extraction_encodec.py | 191 +------- .../gemma3n/feature_extraction_gemma3n.py | 319 +------------ .../feature_extraction_granite_speech.py | 172 +------ ...eature_extraction_kyutai_speech_to_text.py | 223 +-------- .../models/lasr/feature_extraction_lasr.py | 261 +--------- .../feature_extraction_musicgen_melody.py | 322 +------------ .../parakeet/feature_extraction_parakeet.py | 271 +---------- .../pe_audio/feature_extraction_pe_audio.py | 146 +----- .../feature_extraction_phi4_multimodal.py | 269 +---------- .../pop2piano/feature_extraction_pop2piano.py | 438 +---------------- .../feature_extraction_seamless_m4t.py | 291 +----------- .../feature_extraction_speech_to_text.py | 297 +----------- .../speecht5/feature_extraction_speecht5.py | 360 +------------- .../univnet/feature_extraction_univnet.py | 444 +----------------- ...extraction_vibevoice_acoustic_tokenizer.py | 132 +----- .../feature_extraction_voxtral_realtime.py | 234 +-------- .../wav2vec2/feature_extraction_wav2vec2.py | 225 +-------- .../whisper/feature_extraction_whisper.py | 331 +------------ src/transformers/utils/deprecation.py | 35 ++ ..._audio_processors_vs_feature_extractors.py | 283 +++++++++-- 25 files changed, 369 insertions(+), 5992 deletions(-) diff --git a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py 
b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py index ee69d1d0b991..80faf5663dec 100644 --- a/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py @@ -11,225 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Feature extractor class for Audio Spectrogram Transformer. -""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_audio_spectrogram_transformer import AudioSpectrogramTransformerAudioProcessor -import numpy as np - -from ...audio_utils import mel_filter_bank, spectrogram, window_function -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, is_speech_available, is_torch_available, logging - - -if is_speech_available(): - import torchaudio.compliance.kaldi as ta_kaldi - -if is_torch_available(): - import torch - - -logger = logging.get_logger(__name__) - - -class ASTFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a Audio Spectrogram Transformer (AST) feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy - otherwise, pads/truncates them to a fixed length and normalizes them using a mean and standard deviation. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - num_mel_bins (`int`, *optional*, defaults to 128): - Number of Mel-frequency bins. - max_length (`int`, *optional*, defaults to 1024): - Maximum length to which to pad/truncate the extracted features. - do_normalize (`bool`, *optional*, defaults to `True`): - Whether or not to normalize the log-Mel features using `mean` and `std`. - mean (`float`, *optional*, defaults to -4.2677393): - The mean value used to normalize the log-Mel features. Uses the AudioSet mean by default. - std (`float`, *optional*, defaults to 4.5689974): - The standard deviation value used to normalize the log-Mel features. Uses the AudioSet standard deviation - by default. - return_attention_mask (`bool`, *optional*, defaults to `False`): - Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`. 
- """ - - model_input_names = ["input_values", "attention_mask"] - - def __init__( - self, - feature_size=1, - sampling_rate=16000, - num_mel_bins=128, - max_length=1024, - padding_value=0.0, - do_normalize=True, - mean=-4.2677393, - std=4.5689974, - return_attention_mask=False, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.num_mel_bins = num_mel_bins - self.max_length = max_length - self.do_normalize = do_normalize - self.mean = mean - self.std = std - self.return_attention_mask = return_attention_mask - - if not is_speech_available(): - mel_filters = mel_filter_bank( - num_frequency_bins=257, - num_mel_filters=self.num_mel_bins, - min_frequency=20, - max_frequency=sampling_rate // 2, - sampling_rate=sampling_rate, - norm=None, - mel_scale="kaldi", - triangularize_in_mel_space=True, - ) - - self.mel_filters = mel_filters - self.window = window_function(400, "hann", periodic=False) - - def _extract_fbank_features( - self, - waveform: np.ndarray, - max_length: int, - ) -> np.ndarray: - """ - Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs - and hence the waveform should not be normalized before feature extraction. - """ - # waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers - if is_speech_available(): - waveform = torch.from_numpy(waveform).unsqueeze(0) - fbank = ta_kaldi.fbank( - waveform, - sample_frequency=self.sampling_rate, - window_type="hanning", - num_mel_bins=self.num_mel_bins, - ) - else: - waveform = np.squeeze(waveform) - fbank = spectrogram( - waveform, - self.window, - frame_length=400, - hop_length=160, - fft_length=512, - power=2.0, - center=False, - preemphasis=0.97, - mel_filters=self.mel_filters, - log_mel="log", - mel_floor=1.192092955078125e-07, - remove_dc_offset=True, - ).T - - fbank = torch.from_numpy(fbank) - - n_frames = fbank.shape[0] - difference = max_length - n_frames - - # pad or truncate, depending on difference - if difference > 0: - pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference)) - fbank = pad_module(fbank) - elif difference < 0: - fbank = fbank[0:max_length, :] - - fbank = fbank.numpy() - - return fbank - - def normalize(self, input_values: np.ndarray) -> np.ndarray: - return (input_values - (self.mean)) / (self.std * 2) - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - sampling_rate: int | None = None, - return_tensors: str | TensorType | None = None, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. 
- """ - - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - if is_batched_numpy and len(raw_speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") - is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech, dtype=np.float32) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): - raw_speech = raw_speech.astype(np.float32) - - # always return batch - if not is_batched: - raw_speech = [raw_speech] - - # extract fbank features and pad/truncate to max_length - features = [self._extract_fbank_features(waveform, max_length=self.max_length) for waveform in raw_speech] - - # convert into BatchFeature - padded_inputs = BatchFeature({"input_values": features}) - - # make sure list is in array format - input_values = padded_inputs.get("input_values") - if isinstance(input_values[0], list): - padded_inputs["input_values"] = [np.asarray(feature, dtype=np.float32) for feature in input_values] - - # normalization - if self.do_normalize: - padded_inputs["input_values"] = [self.normalize(feature) for feature in input_values] - - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs +ASTFeatureExtractor = deprecated_feature_extractor( + AudioSpectrogramTransformerAudioProcessor, "ASTFeatureExtractor" +) __all__ = ["ASTFeatureExtractor"] diff --git a/src/transformers/models/clap/feature_extraction_clap.py b/src/transformers/models/clap/feature_extraction_clap.py index 8f0a34d2cf4e..79c3c9353825 100644 --- a/src/transformers/models/clap/feature_extraction_clap.py +++ b/src/transformers/models/clap/feature_extraction_clap.py @@ -11,354 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Feature extractor class for CLAP.""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_clap import ClapAudioProcessor -import copy -from typing import Any - -import numpy as np -import torch - -from ...audio_utils import mel_filter_bank, spectrogram, window_function -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, logging -from ...utils.import_utils import requires - - -logger = logging.get_logger(__name__) - - -@requires(backends=("torch",)) -class ClapFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a CLAP feature extractor. 
- - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the *Short Time - Fourier Transform* (STFT) which should match pytorch's `torch.stft` equivalent. - - Args: - feature_size (`int`, *optional*, defaults to 64): - The feature dimension of the extracted Mel spectrograms. This corresponds to the number of mel filters - (`n_mels`). - sampling_rate (`int`, *optional*, defaults to 48000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). This only serves - to warn users if the audio fed to the feature extractor does not have the same sampling rate. - hop_length (`int`,*optional*, defaults to 480): - Length of the overlapping windows for the STFT used to obtain the Mel Spectrogram. The audio will be split - in smaller `frames` with a step of `hop_length` between each frame. - max_length_s (`int`, *optional*, defaults to 10): - The maximum input length of the model in seconds. This is used to pad the audio. - fft_window_size (`int`, *optional*, defaults to 1024): - Size of the window (in samples) on which the Fourier transform is applied. This controls the frequency - resolution of the spectrogram. 400 means that the fourier transform is computed on windows of 400 samples. - padding_value (`float`, *optional*, defaults to 0.0): - Padding value used to pad the audio. Should correspond to silences. - return_attention_mask (`bool`, *optional*, defaults to `False`): - Whether or not the model should return the attention masks corresponding to the input. - frequency_min (`float`, *optional*, defaults to 0): - The lowest frequency of interest. The STFT will not be computed for values below this. - frequency_max (`float`, *optional*, defaults to 14000): - The highest frequency of interest. The STFT will not be computed for values above this. - top_db (`float`, *optional*): - The highest decibel value used to convert the mel spectrogram to the log scale. For more details see the - `audio_utils.power_to_db` function - truncation (`str`, *optional*, defaults to `"fusion"`): - Truncation pattern for long audio inputs. Two patterns are available: - - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a - downsampled version of the entire mel spectrogram. - If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a copy - of the original mel obtained from the padded audio. - - `rand_trunc` will select a random crop of the mel spectrogram. - padding (`str`, *optional*, defaults to `"repeatpad"`): - Padding pattern for shorter audio inputs. Three patterns were originally implemented: - - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`. - - `repeat`: the audio is repeated and then cut to fit the `max_length` - - `pad`: the audio is padded. 
- """ - - model_input_names = ["input_features", "is_longer"] - - def __init__( - self, - feature_size=64, - sampling_rate=48_000, - hop_length=480, - max_length_s=10, - fft_window_size=1024, - padding_value=0.0, - return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask - frequency_min: float = 0, - frequency_max: float = 14_000, - top_db: int | None = None, - truncation: str = "fusion", - padding: str = "repeatpad", - **kwargs, - ): - super().__init__( - feature_size=feature_size, - sampling_rate=sampling_rate, - padding_value=padding_value, - return_attention_mask=return_attention_mask, - **kwargs, - ) - self.top_db = top_db - self.truncation = truncation - self.padding = padding - self.fft_window_size = fft_window_size - self.nb_frequency_bins = (fft_window_size >> 1) + 1 - self.hop_length = hop_length - self.max_length_s = max_length_s - self.nb_max_samples = max_length_s * sampling_rate - self.sampling_rate = sampling_rate - self.frequency_min = frequency_min - self.frequency_max = frequency_max - self.mel_filters = mel_filter_bank( - num_frequency_bins=self.nb_frequency_bins, - num_mel_filters=feature_size, - min_frequency=frequency_min, - max_frequency=frequency_max, - sampling_rate=sampling_rate, - norm=None, - mel_scale="htk", - ) - self.mel_filters_slaney = mel_filter_bank( - num_frequency_bins=self.nb_frequency_bins, - num_mel_filters=feature_size, - min_frequency=frequency_min, - max_frequency=frequency_max, - sampling_rate=sampling_rate, - norm="slaney", - mel_scale="slaney", - ) - - def to_dict(self) -> dict[str, Any]: - """ - Serializes this instance to a Python dictionary. - - Returns: - `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, except for the - mel filter banks, which do not need to be saved or printed as they are too long. - """ - output = copy.deepcopy(self.__dict__) - output["feature_extractor_type"] = self.__class__.__name__ - if "mel_filters" in output: - del output["mel_filters"] - if "mel_filters_slaney" in output: - del output["mel_filters_slaney"] - return output - - def _np_extract_fbank_features(self, waveform: np.ndarray, mel_filters: np.ndarray | None = None) -> np.ndarray: - """ - Compute the log-mel spectrogram of the provided `waveform` using the Hann window. In CLAP, two different filter - banks are used depending on the truncation pattern: - - `self.mel_filters`: they correspond to the default parameters of `torchaudio` which can be obtained from - calling `torchaudio.transforms.MelSpectrogram().mel_scale.fb`. These filters are used when `truncation` - is set to `"fusion"`. - - `self.mel_filteres_slaney` : they correspond to the default parameters of `librosa` which used - `librosa.filters.mel` when computing the mel spectrogram. These filters were only used in the original - implementation when the truncation mode is not `"fusion"`. 
-        """
-        log_mel_spectrogram = spectrogram(
-            waveform,
-            window_function(self.fft_window_size, "hann"),
-            frame_length=self.fft_window_size,
-            hop_length=self.hop_length,
-            power=2.0,
-            mel_filters=mel_filters,
-            log_mel="dB",
-        )
-        return log_mel_spectrogram.T
-
-    def _random_mel_fusion(self, mel, total_frames, chunk_frames):
-        ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
-        if len(ranges[1]) == 0:
-            # if the audio is too short, we just use the first chunk
-            ranges[1] = [0]
-        if len(ranges[2]) == 0:
-            # if the audio is too short, we just use the first chunk
-            ranges[2] = [0]
-        # randomly choose index for each part
-        idx_front = np.random.choice(ranges[0])
-        idx_middle = np.random.choice(ranges[1])
-        idx_back = np.random.choice(ranges[2])
-
-        mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :]
-        mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :]
-        mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :]
-
-        mel = torch.tensor(mel[None, None, :])
-        mel_shrink = torch.nn.functional.interpolate(
-            mel, size=[chunk_frames, 64], mode="bilinear", align_corners=False
-        )
-        mel_shrink = mel_shrink[0][0].numpy()
-        mel_fusion = np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0)
-        return mel_fusion
-
-    def _get_input_mel(self, waveform: np.ndarray, max_length, truncation, padding) -> np.ndarray:
-        """
-        Extracts the mel spectrogram and prepares it for the mode based on the `truncation` and `padding` arguments.
-        Four different paths are possible:
-            - `truncation="fusion"` and the length of the waveform is greater than the max length: the mel spectrogram
-              will be computed on the entire audio. 3 random crops and a downsampled version of the full mel spectrogram
-              are then stacked together. They will later be used for `feature_fusion`.
-            - `truncation="rand_trunc"` and the length of the waveform is smaller than the max length: the audio is
-              padded based on `padding`.
-            - `truncation="fusion"` and the length of the waveform is smaller than the max length: the audio is padded
-              based on `padding`, and is repeated `4` times.
-            - `truncation="rand_trunc"` and the length of the waveform is greater than the max length: the mel
-              spectrogram will be computed on a random crop of the waveform.
-
-        """
-        if waveform.shape[0] > max_length:
-            if truncation == "rand_trunc":
-                longer = True
-                # random crop to max_length (for compatibility) -> this should be handled by self.pad
-                overflow = len(waveform) - max_length
-                idx = np.random.randint(0, overflow + 1)
-                waveform = waveform[idx : idx + max_length]
-                input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
-            elif truncation == "fusion":
-                mel = self._np_extract_fbank_features(waveform, self.mel_filters)
-                chunk_frames = max_length // self.hop_length + 1  # the +1 is related to how the spectrogram is computed
-                total_frames = mel.shape[0]
-                if chunk_frames == total_frames:
-                    # there is a corner case where the audio length is larger than max_length but smaller than max_length+hop_length.
-                    # In this case, we just use the whole audio.
-                    input_mel = np.stack([mel, mel, mel, mel], axis=0)
-                    longer = False
-                else:
-                    input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames)
-                    longer = True
-            else:
-                raise NotImplementedError(f"data_truncating {truncation} not implemented")
-
-        else:
-            longer = False
-            # only use repeat as a new possible value for padding. you repeat the audio before applying the usual max_length padding
-            if waveform.shape[0] < max_length:
-                if padding == "repeat":
-                    n_repeat = int(max_length / len(waveform))
-                    waveform = np.tile(waveform, n_repeat + 1)[:max_length]
-                if padding == "repeatpad":
-                    n_repeat = int(max_length / len(waveform))
-                    waveform = np.tile(waveform, n_repeat)
-                waveform = np.pad(waveform, (0, max_length - waveform.shape[0]), mode="constant", constant_values=0)
-
-            if truncation == "fusion":
-                input_mel = self._np_extract_fbank_features(waveform, self.mel_filters)
-                input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0)
-            else:
-                input_mel = self._np_extract_fbank_features(waveform, self.mel_filters_slaney)[None, :]
-
-        return input_mel, longer
-
-    def __call__(
-        self,
-        raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]],
-        truncation: str | None = None,
-        padding: str | None = None,
-        max_length: int | None = None,
-        sampling_rate: int | None = None,
-        return_tensors: str | TensorType | None = None,
-        **kwargs,
-    ) -> BatchFeature:
-        """
-        Main method to featurize and prepare for the model one or several sequence(s).
-
-        Args:
-            raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
-                The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
-                values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
-                stereo, i.e. single float per timestep.
-            truncation (`str`, *optional*):
-                Truncation pattern for long audio inputs. Two patterns are available:
-                    - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
-                      a downsampled version of the entire mel spectrogram.
-                      If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a
-                      copy of the original mel obtained from the padded audio.
-                    - `rand_trunc` will select a random crop of the mel spectrogram.
-            padding (`str`, *optional*):
-                Padding pattern for shorter audio inputs. Three patterns were originally implemented:
-                    - `repeatpad`: the audio is repeated, and then padded to fit the `max_length`.
-                    - `repeat`: the audio is repeated and then cut to fit the `max_length`.
-                    - `pad`: the audio is padded.
-            return_tensors (`str` or [`~utils.TensorType`], *optional*):
-                If set, will return tensors instead of list of python integers. Acceptable values are:
-                    - `'pt'`: Return PyTorch `torch.Tensor` objects.
-                    - `'np'`: Return Numpy `np.ndarray` objects.
-            sampling_rate (`int`, *optional*):
-                The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
-                `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
-                pipeline.
-        """
-        truncation = truncation if truncation is not None else self.truncation
-        padding = padding if padding else self.padding
-
-        if sampling_rate is not None:
-            if sampling_rate != self.sampling_rate:
-                raise ValueError(
-                    f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a"
-                    f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input"
-                    f" was sampled with {self.sampling_rate} and not {sampling_rate}."
-                )
-        else:
-            logger.warning(
-                f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. "
-                "Failing to do so can result in silent errors that might be hard to debug."
- ) - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - if is_batched_numpy and len(raw_speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") - is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_speech = [np.asarray(speech, dtype=np.float64) for speech in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech, dtype=np.float64) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): - raw_speech = raw_speech.astype(np.float64) - - # always return batch - if not is_batched: - raw_speech = [np.asarray(raw_speech)] - - # convert to mel spectrogram, truncate and pad if needed. - padded_inputs = [ - self._get_input_mel(waveform, max_length if max_length else self.nb_max_samples, truncation, padding) - for waveform in raw_speech - ] - - input_mel = [] - is_longer = [] - for mel, longer in padded_inputs: - input_mel.append(mel) - is_longer.append(longer) - - if truncation == "fusion" and sum(is_longer) == 0: - # if no audio is longer than 10s, then randomly select one audio to be longer - rand_idx = np.random.randint(0, len(input_mel)) - is_longer[rand_idx] = True - - if isinstance(input_mel[0], list): - input_mel = [np.asarray(feature, dtype=np.float64) for feature in input_mel] - - # is_longer is a list of bool - is_longer = [[longer] for longer in is_longer] - - input_features = {"input_features": input_mel, "is_longer": is_longer} - input_features = BatchFeature(input_features) - - if return_tensors is not None: - input_features = input_features.convert_to_tensors(return_tensors) - - return input_features +ClapFeatureExtractor = deprecated_feature_extractor(ClapAudioProcessor, "ClapFeatureExtractor") __all__ = ["ClapFeatureExtractor"] diff --git a/src/transformers/models/clvp/feature_extraction_clvp.py b/src/transformers/models/clvp/feature_extraction_clvp.py index cc39e6aca677..e5966a9b2f02 100644 --- a/src/transformers/models/clvp/feature_extraction_clvp.py +++ b/src/transformers/models/clvp/feature_extraction_clvp.py @@ -11,227 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_clvp import ClvpAudioProcessor -""" -Feature extractor class for CLVP -""" - -import numpy as np - -from ...audio_utils import mel_filter_bank, spectrogram, window_function -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, logging - - -logger = logging.get_logger(__name__) - - -class ClvpFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a CLVP feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts log-mel-spectrogram features from raw speech using a custom numpy implementation of the `Short - Time Fourier Transform` which should match pytorch's `torch.stft` equivalent. 
- - Args: - feature_size (`int`, *optional*, defaults to 80): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 22050): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - default_audio_length (`int`, *optional*, defaults to 6): - The default length of raw audio in seconds. If `max_length` is not set during `__call__` then it will - automatically be set to default_audio_length * `self.sampling_rate`. - hop_length (`int`, *optional*, defaults to 256): - Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. - chunk_length (`int`, *optional*, defaults to 30): - The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio - sequences. - n_fft (`int`, *optional*, defaults to 1024): - Size of the Fourier transform. - padding_value (`float`, *optional*, defaults to 0.0): - Padding value used to pad the audio. Should correspond to silences. - mel_norms (`list` of length `feature_size`, *optional*): - If `mel_norms` is provided then it will be used to normalize the log-mel spectrograms along each - mel-filter. - return_attention_mask (`bool`, *optional*, defaults to `False`): - Whether to return the attention mask. If left to the default, it will return the attention mask. - - [What are attention masks?](../glossary#attention-mask) - """ - - model_input_names = ["input_features", "attention_mask"] - - def __init__( - self, - feature_size=80, - sampling_rate=22050, - default_audio_length=6, - hop_length=256, - chunk_length=30, - n_fft=1024, - padding_value=0.0, - mel_norms=None, - return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask - **kwargs, - ): - super().__init__( - feature_size=feature_size, - sampling_rate=sampling_rate, - padding_value=padding_value, - return_attention_mask=return_attention_mask, - **kwargs, - ) - self.n_fft = n_fft - self.hop_length = hop_length - self.chunk_length = chunk_length - self.n_samples = chunk_length * sampling_rate - self.nb_max_frames = self.n_samples // hop_length - self.sampling_rate = sampling_rate - self.default_audio_length = default_audio_length - self.mel_norms = mel_norms - self.mel_filters = mel_filter_bank( - num_frequency_bins=1 + (n_fft // 2), - num_mel_filters=feature_size, - min_frequency=0.0, - max_frequency=8000.0, - sampling_rate=sampling_rate, - norm="slaney", - mel_scale="htk", - ) - - def _np_extract_fbank_features(self, waveform: np.ndarray) -> np.ndarray: - """ - This method first computes the log-mel spectrogram of the provided audio then applies normalization along the - each mel-filterbank, if `mel_norms` is provided. 
- """ - log_spec = spectrogram( - waveform, - window_function(self.n_fft, "hann"), - frame_length=self.n_fft, - hop_length=self.hop_length, - power=2.0, - mel_filters=self.mel_filters, - log_mel=None, - ) - - log_spec = np.log(np.clip(log_spec, a_min=1e-5, a_max=None)) - - if self.mel_norms is not None: - log_spec = log_spec / np.array(self.mel_norms)[:, None] - - return log_spec - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - sampling_rate: int | None = None, - truncation: bool = True, - pad_to_multiple_of: int | None = None, - return_tensors: str | TensorType | None = None, - return_attention_mask: bool | None = True, - padding: str | None = "max_length", - max_length: int | None = None, - **kwargs, - ) -> BatchFeature: - """ - `ClvpFeatureExtractor` is used to extract various voice specific properties such as the pitch and tone of the - voice, speaking speed, and even speaking defects like a lisp or stuttering from a sample voice or `raw_speech`. - - First the voice is padded or truncated in a way such that it becomes a waveform of `self.default_audio_length` - seconds long and then the log-mel spectrogram is extracted from it. - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition - pipeline. - truncation (`bool`, *optional*, default to `True`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - return_attention_mask (`bool`, *optional*, defaults to `True`): - Whether to return the attention mask. If left to the default, it will return the attention mask. - - [What are attention masks?](../glossary#attention-mask) - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values / vectors. - max_length (`int`, *optional*): - The maximum input length of the inputs. - """ - - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" - f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" - f" was sampled with {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." 
- ) - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - if is_batched_numpy and len(raw_speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") - is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech, dtype=np.float32) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): - raw_speech = raw_speech.astype(np.float32) - - # always return batch - if not is_batched: - raw_speech = [np.asarray([raw_speech]).T] - - batched_speech = BatchFeature({"input_features": raw_speech}) - - max_length = self.default_audio_length * self.sampling_rate if max_length is None else max_length - - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - ) - - # make sure list is in array format - input_features = padded_inputs.get("input_features").transpose(2, 0, 1) - - input_features = [ - self._np_extract_fbank_features(waveform).astype(np.float32) for waveform in input_features[0] - ] - - if isinstance(input_features[0], list): - padded_inputs["input_features"] = [np.asarray(feature) for feature in input_features] - else: - padded_inputs["input_features"] = input_features - - return padded_inputs.convert_to_tensors(return_tensors) +ClvpFeatureExtractor = deprecated_feature_extractor(ClvpAudioProcessor, "ClvpFeatureExtractor") __all__ = ["ClvpFeatureExtractor"] diff --git a/src/transformers/models/dac/feature_extraction_dac.py b/src/transformers/models/dac/feature_extraction_dac.py index 7f910f57f09f..f255d22ebba5 100644 --- a/src/transformers/models/dac/feature_extraction_dac.py +++ b/src/transformers/models/dac/feature_extraction_dac.py @@ -11,160 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Feature extractor class for DAC""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_dac import DacAudioProcessor -import numpy as np - -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class DacFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs an Dac feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used for padding. - hop_length (`int`, *optional*, defaults to 512): - Overlap length between successive windows. 
- """ - - model_input_names = ["input_values", "n_quantizers"] - - def __init__( - self, - feature_size: int = 1, - sampling_rate: int = 16000, - padding_value: float = 0.0, - hop_length: int = 512, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.hop_length = hop_length - - def __call__( - self, - raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - padding: bool | str | PaddingStrategy | None = None, - truncation: bool | None = False, - max_length: int | None = None, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape - `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio - (`feature_size = 2`). - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - truncation (`bool`, *optional*, defaults to `False`): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - if padding and truncation: - raise ValueError("Both padding and truncation were set. 
Make sure you only set one.") - elif padding is None: - # by default let's pad the inputs - padding = True - - is_batched = bool( - isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] - elif not is_batched and not isinstance(raw_audio, np.ndarray): - raw_audio = np.asarray(raw_audio, dtype=np.float32) - elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): - raw_audio = raw_audio.astype(np.float32) - - # always return batch - if not is_batched: - raw_audio = [np.asarray(raw_audio).T] - - # verify inputs are valid - for idx, example in enumerate(raw_audio): - if example.ndim > 2: - raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") - if self.feature_size == 1 and example.ndim != 1: - raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") - if self.feature_size == 2: - raise ValueError("Stereo audio isn't supported for now") - - input_values = BatchFeature({"input_values": raw_audio}) - - # normal padding on batch - padded_inputs = self.pad( - input_values, - max_length=max_length, - truncation=truncation, - padding=padding, - return_attention_mask=padding, - pad_to_multiple_of=self.hop_length, - ) - if padding: - padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") - if padding: - padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :] - - input_values = [] - for example in padded_inputs.pop("input_values"): - if self.feature_size == 1: - example = example[..., None] - input_values.append(example.T) - - padded_inputs["input_values"] = input_values - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs +DacFeatureExtractor = deprecated_feature_extractor(DacAudioProcessor, "DacFeatureExtractor") __all__ = ["DacFeatureExtractor"] diff --git a/src/transformers/models/dia/feature_extraction_dia.py b/src/transformers/models/dia/feature_extraction_dia.py index eda1ead6e014..d358589b4282 100644 --- a/src/transformers/models/dia/feature_extraction_dia.py +++ b/src/transformers/models/dia/feature_extraction_dia.py @@ -11,169 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Feature extractor class for Dia""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_dia import DiaAudioProcessor -import numpy as np - -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class DiaFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs an Dia feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz). 
- padding_value (`float`, *optional*, defaults to 0.0): - The value that is used for padding. - hop_length (`int`, *optional*, defaults to 512): - Overlap length between successive windows. - """ - - model_input_names = ["input_values", "n_quantizers"] - - def __init__( - self, - feature_size: int = 1, - sampling_rate: int = 16000, - padding_value: float = 0.0, - hop_length: int = 512, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.hop_length = hop_length - - def __call__( - self, - raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - padding: bool | str | PaddingStrategy | None = None, - truncation: bool | None = False, - max_length: int | None = None, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape - `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio - (`feature_size = 2`). - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - truncation (`bool`, *optional*, defaults to `False`): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - return_tensors (`str` or [`~utils.TensorType`], *optional*, default to 'pt'): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - if padding and truncation: - raise ValueError("Both padding and truncation were set. 
Make sure you only set one.") - elif padding is None: - # by default let's pad the inputs - padding = True - - is_batched = bool( - isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] - elif not is_batched and not isinstance(raw_audio, np.ndarray): - raw_audio = np.asarray(raw_audio, dtype=np.float32) - elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): - raw_audio = raw_audio.astype(np.float32) - - # always return batch - if not is_batched: - raw_audio = [np.asarray(raw_audio).T] - - # convert stereo to mono if necessary, unique to Dia - for idx, example in enumerate(raw_audio): - if self.feature_size == 2 and example.ndim == 2: - raw_audio[idx] = np.mean(example, -1) - - # verify inputs are valid - for idx, example in enumerate(raw_audio): - if example.ndim > 2: - raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") - if self.feature_size == 1 and example.ndim != 1: - raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") - if self.feature_size == 2 and example.ndim != 1: # note the conversion before - raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") - - input_values = BatchFeature({"input_values": raw_audio}) - - # temporarily treat it as if we were mono as we also convert stereo to mono - original_feature_size = self.feature_size - self.feature_size = 1 - - # normal padding on batch - padded_inputs = self.pad( - input_values, - max_length=max_length, - truncation=truncation, - padding=padding, - return_attention_mask=True, - pad_to_multiple_of=self.hop_length, - ) - padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") - - input_values = [] - for example in padded_inputs.pop("input_values"): - if self.feature_size == 1: - example = example[..., None] - input_values.append(example.T) - - padded_inputs["input_values"] = input_values - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - # rewrite back to original feature size - self.feature_size = original_feature_size - - return padded_inputs +DiaFeatureExtractor = deprecated_feature_extractor(DiaAudioProcessor, "DiaFeatureExtractor") __all__ = ["DiaFeatureExtractor"] diff --git a/src/transformers/models/encodec/feature_extraction_encodec.py b/src/transformers/models/encodec/feature_extraction_encodec.py index 383936000243..2f1644ac912a 100644 --- a/src/transformers/models/encodec/feature_extraction_encodec.py +++ b/src/transformers/models/encodec/feature_extraction_encodec.py @@ -11,195 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Feature extractor class for EnCodec.""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_encodec import EncodecAudioProcessor -import numpy as np - -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class EncodecFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs an EnCodec feature extractor. 
- - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Instantiating a feature extractor with the defaults will yield a similar configuration to that of the - [facebook/encodec_24khz](https://huggingface.co/facebook/encodec_24khz) architecture. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. - sampling_rate (`int`, *optional*, defaults to 24000): - The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values. - chunk_length_s (`float`, *optional*): - If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded. - overlap (`float`, *optional*): - Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following - formulae : `int((1.0 - self.overlap) * self.chunk_length)`. - """ - - model_input_names = ["input_values", "padding_mask"] - - def __init__( - self, - feature_size: int = 1, - sampling_rate: int = 24000, - padding_value: float = 0.0, - chunk_length_s: float | None = None, - overlap: float | None = None, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.chunk_length_s = chunk_length_s - self.overlap = overlap - - # This is a property because you might want to change the chunk_length_s on the fly - @property - def chunk_length(self) -> int | None: - if self.chunk_length_s is None: - return None - else: - return int(self.chunk_length_s * self.sampling_rate) - - # This is a property because you might want to change the chunk_length_s on the fly - @property - def chunk_stride(self) -> int | None: - if self.chunk_length_s is None or self.overlap is None: - return None - else: - return max(1, int((1.0 - self.overlap) * self.chunk_length)) - - def __call__( - self, - raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - padding: bool | str | PaddingStrategy | None = None, - truncation: bool | None = False, - max_length: int | None = None, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape - `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio - (`feature_size = 2`). - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. 
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - truncation (`bool`, *optional*, defaults to `False`): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - if padding and truncation: - raise ValueError("Both padding and truncation were set. Make sure you only set one.") - elif padding is None: - # by default let's pad the inputs - padding = True - - is_batched = bool( - isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] - elif not is_batched and not isinstance(raw_audio, np.ndarray): - raw_audio = np.asarray(raw_audio, dtype=np.float32) - elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): - raw_audio = raw_audio.astype(np.float32) - - # always return batch - if not is_batched: - raw_audio = [np.asarray(raw_audio).T] - - # verify inputs are valid - for idx, example in enumerate(raw_audio): - if example.ndim > 2: - raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") - if self.feature_size == 1 and example.ndim != 1: - raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") - if self.feature_size == 2 and example.shape[-1] != 2: - raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") - - padded_inputs = None - input_values = BatchFeature({"input_values": raw_audio}) - if self.chunk_stride is not None and self.chunk_length is not None and max_length is None: - if truncation: - max_length = min(array.shape[0] for array in raw_audio) - nb_step = int(np.floor(max_length / self.chunk_stride)) - max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length - elif padding: - max_length = max(array.shape[0] for array in raw_audio) - nb_step = int(np.ceil(max_length / self.chunk_stride)) - max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length - padding = "max_length" - else: - padded_inputs = input_values - - # normal padding on batch - if padded_inputs is None: - padded_inputs = self.pad( - input_values, - max_length=max_length, - truncation=truncation, - padding=padding, - return_attention_mask=padding, - ) - if padding: - padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") - - 
input_values = [] - for example in padded_inputs.pop("input_values"): - if self.feature_size == 1: - example = example[..., None] - input_values.append(example.T) - - padded_inputs["input_values"] = input_values - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs +EncodecFeatureExtractor = deprecated_feature_extractor(EncodecAudioProcessor, "EncodecFeatureExtractor") __all__ = ["EncodecFeatureExtractor"] diff --git a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py index e2b24fb1f19f..1b111b76b49d 100644 --- a/src/transformers/models/gemma3n/feature_extraction_gemma3n.py +++ b/src/transformers/models/gemma3n/feature_extraction_gemma3n.py @@ -11,323 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_gemma3n import Gemma3nAudioProcessor -import math -from collections.abc import Sequence - -import numpy as np - -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -def create_fb_matrix( - n_freqs: int, - f_min: float, - f_max: float, - n_mels: int, - sample_rate: int, - fft_length: int, - norm: str | None = None, -) -> np.ndarray: - r"""Create a frequency bin conversion matrix (NumPy version). - - Args: - n_freqs (int): Number of frequencies to highlight/apply - f_min (float): Minimum frequency (Hz) - f_max (float): Maximum frequency (Hz) - n_mels (int): Number of mel filterbanks - sample_rate (int): Sample rate of the audio waveform - fft_length (int): FFT length - norm (Optional[str]): If 'slaney', divide the triangular mel weights by - the width of the mel band (area normalization). (Default: ``None``) - - Returns: - np.ndarray: Triangular filter banks (fb matrix) of size (``n_freqs``, - ``n_mels``) - meaning number of frequencies to highlight/apply to x the number of - filterbanks. - Each column is a filterbank so that assuming there is a matrix A of - size (..., ``n_freqs``), the applied result would be - ``A @ create_fb_matrix_numpy(A.shape[-1], ...)``. - """ - - if norm is not None and norm != "slaney": - raise ValueError("norm must be one of None or 'slaney'") - - # freq bins - all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length) - - # calculate mel freq bins - # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.)) - m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0)) - m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0)) - m_pts = np.linspace(m_min, m_max, n_mels + 2) - # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.) 
- f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0) - # calculate difference between each mel point and each stft freq point in Hz - f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1) - slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) # (n_freqs, n_mels + 2) - # create overlapping triangles - zero = np.zeros(1, dtype=np.float32) - down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels) - up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels) - fb = np.maximum(zero, np.minimum(down_slopes, up_slopes)) - - if norm is not None and norm == "slaney": - # Slaney-style mel is scaled to be approx constant energy per channel - enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) - fb *= np.expand_dims(enorm, 0) - - return fb - - -def _unfold(array: np.ndarray, dimension: int, size: int, step: int) -> np.ndarray: - """A basic NumPy equivalent of PyTorch's unfold for 2D arrays along the last dim.""" - if array.ndim != 2: - raise ValueError("This unfold implementation currently supports 2D arrays (batch, time).") - if dimension != -1 and dimension != array.ndim - 1: - raise ValueError("This unfold implementation only supports unfolding the last dimension.") - - batch_size, original_length = array.shape - num_frames = (original_length - size) // step + 1 - - if num_frames <= 0: - return np.zeros((batch_size, 0, size), dtype=array.dtype) - - output_shape = (batch_size, num_frames, size) - output_strides = (array.strides[0], array.strides[1] * step, array.strides[1]) - - return np.lib.stride_tricks.as_strided(array, shape=output_shape, strides=output_strides) - - -class Gemma3nAudioFeatureExtractor(SequenceFeatureExtractor): - """An audio feature extractor Universal Speech Models https://huggingface.co/papers/2303.01037. - - Args: - feature_size (`int`, *optional*, defaults to 128): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - Padding value used to pad the audio. Should correspond to silences. - return_attention_mask (`bool`, *optional*, defaults to `True`): - Whether to return the attention mask for the generated MEL spectrograms. - frame_length_ms (`float`, *optional*, defaults to 32.0): - The length of a frame in milliseconds. - hop_length_ms (`float`, *optional*, defaults to 10.0): - Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. - min_frequency (`float`, *optional*, defaults to 125.0): - The minimum frequency (in Hz) for the Mel filterbank. - max_frequency (`float`, *optional*, defaults to 7600.0): - The maximum frequency (in Hz) for the Mel filterbank. - preemphasis (`float`, *optional*, defaults to 0.97): - The preemphasis coefficient. - preemphasis_htk_flavor (`bool`, *optional*, defaults to `True`): - Whether to use HTK-style preemphasis. - fft_overdrive (`bool`, *optional*, defaults to `True`): - Whether to use FFT overdrive. - dither (`float`, *optional*, defaults to 0.0): - Adds dithering. In other words, adds a small Gaussian noise to each frame. - E.g. use 0.0001 to add dithering with a normal distribution centered - around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech). - The value 0.0 means no dithering. - Dithering has similar effect as `spectrogram(mel_floor=...)`. 
It reduces - the high log_mel_fbank values for signals with hard-zero sections, - when VAD cutoff is present in the signal. - input_scale_factor (`float`, *optional*, defaults to 1.0): - Scaling factor applied to the input waveform. - mel_floor (`float`, *optional*, defaults to 1e-05): - Minimum value for Mel spectrograms to avoid log(0). - per_bin_mean (`Optional[Sequence[float]]`, *optional*): - Mean values for per-bin normalization. - per_bin_stddev (`Optional[Sequence[float]]`, *optional*): - Standard deviation values for per-bin normalization. - """ - - model_input_names = ["input_features", "input_features_mask"] - - def __init__( - self, - feature_size: int = 128, - sampling_rate: int = 16_000, - padding_value: float = 0.0, - return_attention_mask: bool = True, - frame_length_ms: float = 32.0, - hop_length_ms: float = 10.0, - min_frequency: float = 125.0, - max_frequency: float = 7600.0, - preemphasis: float = 0.97, - preemphasis_htk_flavor: bool = True, - fft_overdrive: bool = True, - dither: float = 0.0, - input_scale_factor: float = 1.0, - mel_floor: float = 1e-5, - per_bin_mean: Sequence[float] | None = None, - per_bin_stddev: Sequence[float] | None = None, - **kwargs, - ): - super().__init__( - feature_size=feature_size, - sampling_rate=sampling_rate, - padding_value=padding_value, - return_attention_mask=return_attention_mask, - **kwargs, - ) - - self.min_frequency = min_frequency - self.max_frequency = max_frequency - self.preemphasis = preemphasis - self.preemphasis_htk_flavor = preemphasis_htk_flavor - self.fft_overdrive = fft_overdrive - self.dither = dither - self.input_scale_factor = input_scale_factor - self.frame_length = int(round(sampling_rate * frame_length_ms / 1000.0)) - self.hop_length = int(round(sampling_rate * hop_length_ms / 1000.0)) - self.mel_floor = np.array(mel_floor, dtype=np.float64) - - fft_length = 2 ** math.ceil(math.log2(self.frame_length)) - if self.fft_overdrive: - fft_length *= 2 - self.fft_length = fft_length - - hann_arange = np.arange(self.frame_length, dtype=np.float32) - window = 0.5 * (1 - np.cos(2 * np.pi * hann_arange / self.frame_length)) - self.window = window.astype(np.float32) - - self.mel_filters = create_fb_matrix( - n_freqs=self.fft_length // 2 + 1, - f_min=min_frequency, - f_max=max_frequency, - n_mels=feature_size, - sample_rate=self.sampling_rate, - norm=None, - fft_length=fft_length, - ) - - if per_bin_mean is not None: - self.per_bin_mean = np.array(per_bin_mean).reshape(1, 1, feature_size) - else: - self.per_bin_mean = None - - if per_bin_stddev is not None: - self.per_bin_stddev = np.array(per_bin_stddev).reshape(1, 1, feature_size) - else: - self.per_bin_stddev = None - - def _extract_spectrogram(self, waveform: np.ndarray, attention_mask: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - """""" - if waveform.ndim == 1: # If single waveform, add batch dimension - waveform = np.expand_dims(waveform, axis=0) - - if self.dither > 0.0: - waveform = waveform + self.dither * np.random.randn(*waveform.shape).astype(waveform.dtype) - - if self.input_scale_factor != 1.0: - waveform = waveform * self.input_scale_factor - - frame_size_for_unfold = self.frame_length + 1 - - # NumPy equivalent of unfold for [B, NumFrames, frame_size_for_unfold] - frames_to_process = _unfold(waveform, dimension=-1, size=frame_size_for_unfold, step=self.hop_length) - - if self.preemphasis > 0.0: - if self.preemphasis_htk_flavor: - first_in_frame = frames_to_process[..., :1] * (1.0 - self.preemphasis) - rest_in_frame = frames_to_process[..., 1:-1] - 
self.preemphasis * frames_to_process[..., :-2] - frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1) - else: - frames = frames_to_process[..., 1:] - self.preemphasis * frames_to_process[..., :-1] - else: - frames = frames_to_process[..., :-1] - - frames = frames * self.window # Broadcasting window - stft = np.fft.rfft(frames, n=self.fft_length, axis=-1) - - magnitude_spec = np.abs(stft) - - mel_spec = np.matmul(magnitude_spec, self.mel_filters) - log_mel_spec = np.log(np.maximum(mel_spec, self.mel_floor)) - - if self.per_bin_mean is not None: - log_mel_spec = log_mel_spec - self.per_bin_mean # Broadcasting - - if self.per_bin_stddev is not None: - log_mel_spec = log_mel_spec / self.per_bin_stddev # Broadcasting - - mel_spectrogram = log_mel_spec.squeeze(0) - mask = attention_mask[:: self.hop_length].astype(bool) - # TODO: The filtered mask is always exactly 3 elements longer than the mel_spectrogram. Why??? - return mel_spectrogram, mask[: mel_spectrogram.shape[0]] - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - padding: bool | str | PaddingStrategy = "longest", - max_length: int | None = 480_000, - truncation: bool = True, - pad_to_multiple_of: int | None = 128, - return_tensors: str | TensorType | None = None, - return_attention_mask: bool | None = True, - **kwargs, - ) -> BatchFeature: - """Creates a batch of MEL spectrograms from the provided raw speech. - - This implementation uses a different algorithm for windowing and preemphasis compared to the built-in - `transformers.audio_utils.spectrogram()` function that _will_ result in different outputs. Consider this - carefully when selecting an audio feature extractor, especially with pre-trained models. - - Args: - raw_speech: - The audio for which MEL spectrograms are created. - padding (`Union[bool, str, PaddingStrategy]`, *optional*, defaults to `"longest"`): - The padding strategy to use for batches of audio with different lengths. - max_length (`int`, *optional*, defaults to 480000): - If provided, defines the maximum length of the audio to allow. Audio longer than this will be - truncated if `truncation=True`. - truncation (`bool`, *optional*, defaults to `True`): - Whether or not to truncate audio above `max_length`. - pad_to_multiple_of (`int`, *optional*, defaults to 128): - When padding, pad to a multiple of this value. The default value is defined for optimal TPU support. - return_tensors (`Union[str, TensorType]`, *optional*, defaults to `None`): - The type of tensors to return (e.g., NumPy, or Torch). - return_attention_mask (`bool`, *optional*, defaults to `True`): - Whether to return the attention mask for the generated MEL spectrograms. 
- """ - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - is_batched_sequence = isinstance(raw_speech, Sequence) and isinstance(raw_speech[0], (np.ndarray, Sequence)) - is_batched = is_batched_numpy or is_batched_sequence - - # Always return a batch - if not is_batched: - raw_speech = [raw_speech] - raw_speech = [np.asarray([rs]).T for rs in raw_speech] - - batched_speech = self.pad( - BatchFeature({"input_features": raw_speech}), - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - ) - - prepared_speech = [] - prepared_speech_mask = [] - for speech, mask in zip(batched_speech.input_features, batched_speech.attention_mask): - speech, mask = self._extract_spectrogram(speech.T, mask) - prepared_speech.append(speech.astype(np.float32)) - prepared_speech_mask.append(mask) - - return BatchFeature( - {"input_features": prepared_speech, "input_features_mask": prepared_speech_mask}, - tensor_type=return_tensors, - ) +Gemma3nAudioFeatureExtractor = deprecated_feature_extractor(Gemma3nAudioProcessor, "Gemma3nAudioFeatureExtractor") __all__ = ["Gemma3nAudioFeatureExtractor"] diff --git a/src/transformers/models/granite_speech/feature_extraction_granite_speech.py b/src/transformers/models/granite_speech/feature_extraction_granite_speech.py index cd32d0433bae..15bab8e6466f 100644 --- a/src/transformers/models/granite_speech/feature_extraction_granite_speech.py +++ b/src/transformers/models/granite_speech/feature_extraction_granite_speech.py @@ -11,174 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Feature extractor class for Granite Speech.""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_granite_speech import GraniteSpeechAudioProcessor -import math -from collections.abc import Sequence - -import numpy as np - -from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin -from ...tokenization_utils_base import AudioInput -from ...utils import is_torch_available, is_torchaudio_available, logging -from ...utils.import_utils import requires_backends - - -logger = logging.get_logger(__name__) - -if is_torch_available(): - import torch - -if is_torchaudio_available(): - import torchaudio - - -class GraniteSpeechFeatureExtractor(FeatureExtractionMixin): - model_input_names = ["input_features"] - - def __init__( - self, - sampling_rate: int = 16000, - n_fft: int = 512, - win_length: int = 400, - hop_length: int = 160, - n_mels: int = 80, - projector_window_size: int = 15, - projector_downsample_rate: int = 5, - **kwargs, - ): - super().__init__(**kwargs) - self.sampling_rate = sampling_rate - self.melspec_kwargs = { - "sample_rate": sampling_rate, - "n_fft": n_fft, - "win_length": win_length, - "hop_length": hop_length, - "n_mels": n_mels, - } - requires_backends(self, ["torchaudio"]) - self.mel_filters = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs) - self.projector_window_size = projector_window_size - self.projector_downsample_rate = projector_downsample_rate - - def __call__( - self, - audios: AudioInput, - device: str | None = "cpu", - ) -> BatchFeature: - requires_backends(self, ["torchaudio"]) - - speech_inputs = {} - batched_audio, audio_lengths = self._get_audios_and_audio_lengths(audios) - speech_inputs["input_features"] = self._extract_mel_spectrograms( - batched_audio, - device=device, - ) - audio_embed_sizes = self._get_num_audio_features(audio_lengths) - speech_inputs["audio_embed_sizes"] = audio_embed_sizes - # TODO (@alex-jw-brooks): Currently input_features_mask is not - # a great name, because input_features and input_features_mask - # have different shapes (before/after the projector). - # - # We should align this with other multimodal models, e.g,. llava - # and qwen2audio and refactor this to ensure input_feature_mask - # has the same dimensionality as input_features, or compute it in - # the model based on the audio embedding sizes (since we do not - # have an attention mask for the audio features to infer padding from). - speech_inputs["input_features_mask"] = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor( - audio_embed_sizes - ).view(-1, 1) - return BatchFeature(data=speech_inputs) - - def _extract_mel_spectrograms(self, audio: "torch.Tensor", device="cpu"): - """ - Compute the Mel features to be passed to the conformer encoder. 
- """ - requires_backends(self, ["torchaudio"]) - if device is not None: - melspec = self.mel_filters.to(device) - audio = audio.to(device) - else: - melspec = self.mel_filters - - bsz = audio.shape[0] - with torch.no_grad(): - # Compute mel features - mel = melspec(audio.float()) - logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_() - mx = logmel.amax(dim=(-2, -1), keepdim=True) - logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1) - # remove last frame if odd - if logmel.shape[1] % 2 == 1: - logmel = logmel[:, :-1] - - # stacking and skipping by 2 - audio = logmel.reshape(bsz, -1, 2 * logmel.shape[-1]) - - return audio - - def _get_num_audio_features(self, audio_lengths: Sequence[int]) -> Sequence[int]: - """ - Gets the (variable length) number of features (i.e., projector output) for the sequences - being considered. - - Args: - audio_lengths (`Sequence[int]`): - Sequence of one or more raw audio lengths. - """ - hop_length = self.melspec_kwargs["hop_length"] - effective_window_size = self.projector_window_size // self.projector_downsample_rate - - projector_lengths = [] - for raw_length in audio_lengths: - # mel sequence length computation - mel_length = raw_length // hop_length + 1 - # encoder frame takes two mel features - encoder_length = mel_length // 2 - nblocks = math.ceil(encoder_length / self.projector_window_size) - # projector output length - projector_length = nblocks * effective_window_size - projector_lengths.append(projector_length) - - return projector_lengths - - def _get_audios_and_audio_lengths(self, audios: AudioInput) -> Sequence["torch.Tensor", Sequence[int]]: - """ - Coerces audio inputs to torch tensors and extracts audio lengths prior to stacking. - - Args: - audios (`AudioInput`): - Audio sequence, numpy array, or torch tensor. - """ - requires_backends(self, ["torch"]) - - # Coerce to PyTorch tensors if we have numpy arrays, since - # currently we have a dependency on torch/torchaudio anyway - if isinstance(audios, np.ndarray): - audios = torch.from_numpy(audios) - elif isinstance(audios, Sequence) and isinstance(audios[0], np.ndarray): - audios = [torch.from_numpy(arr) for arr in audios] - - if isinstance(audios, torch.Tensor): - if audios.ndim == 1: - audios = audios.unsqueeze(0) - if not torch.is_floating_point(audios): - raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1") - - if audios.shape[0] > 1: - logger.warning("Audio samples are already collated; assuming they all have the same length") - lengths = [audios.shape[-1]] * audios.shape[0] - return audios, lengths - - elif isinstance(audios, Sequence) and isinstance(audios[0], torch.Tensor): - if not torch.is_floating_point(audios[0]): - raise ValueError("Invalid audio provided. Audio should be a floating point between 0 and 1") - lengths = [audio.shape[-1] for audio in audios] - audios = [audio.squeeze(0) for audio in audios] - audios = torch.nn.utils.rnn.pad_sequence(audios, batch_first=True, padding_value=0.0) - return audios, lengths - - raise TypeError("Invalid audio provided. 
Audio should be a one or more torch tensors or numpy arrays") +GraniteSpeechFeatureExtractor = deprecated_feature_extractor( + GraniteSpeechAudioProcessor, "GraniteSpeechFeatureExtractor" +) __all__ = ["GraniteSpeechFeatureExtractor"] diff --git a/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py index b472473a19e5..5abc645f3f8a 100644 --- a/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py @@ -1,10 +1,7 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_kyutai_speech_to_text.py file directly. One of our CI enforces this. +# This file is now a thin backward-compatibility wrapper. The original was auto-generated from modular_kyutai_speech_to_text.py. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# Copyright 2025 Kyutai and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,218 +14,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_kyutai_speech_to_text import KyutaiSpeechToTextAudioProcessor -import numpy as np - -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class KyutaiSpeechToTextFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs an KyutaiSpeechToText feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. - sampling_rate (`int`, *optional*, defaults to 24000): - The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values. - chunk_length_s (`float`, *optional*): - If defined the audio is pre-processed into chunks of lengths `chunk_length_s` and then encoded. - overlap (`float`, *optional*): - Defines the overlap between each chunk. It is used to compute the `chunk_stride` using the following - formulae : `int((1.0 - self.overlap) * self.chunk_length)`. - audio_delay_seconds (`float`, *optional*, defaults to 0.0): - The delay in seconds to add after the audio (right padding). - audio_silence_prefix_seconds (`float`, *optional*, defaults to 0.0): - The silence prefix in seconds to add before the audio (left padding). 
- """ - - model_input_names = ["input_values", "padding_mask"] - - def __init__( - self, - feature_size: int = 1, - sampling_rate: int = 24000, - padding_value: float = 0.0, - chunk_length_s: float | None = None, - overlap: float | None = None, - audio_delay_seconds: float | None = 0.0, - audio_silence_prefix_seconds: float | None = 0.0, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.chunk_length_s = chunk_length_s - self.overlap = overlap - self.audio_delay_seconds = audio_delay_seconds - self.audio_silence_prefix_seconds = audio_silence_prefix_seconds - - # This is a property because you might want to change the chunk_length_s on the fly - @property - def chunk_length(self) -> int | None: - if self.chunk_length_s is None: - return None - else: - return int(self.chunk_length_s * self.sampling_rate) - - # This is a property because you might want to change the chunk_length_s on the fly - @property - def chunk_stride(self) -> int | None: - if self.chunk_length_s is None or self.overlap is None: - return None - else: - return max(1, int((1.0 - self.overlap) * self.chunk_length)) - - def __call__( - self, - raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - padding: bool | str | PaddingStrategy | None = None, - truncation: bool | None = False, - max_length: int | None = None, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. The numpy array must be of shape - `(num_samples,)` for mono audio (`feature_size = 1`), or `(2, num_samples)` for stereo audio - (`feature_size = 2`). - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - truncation (`bool`, *optional*, defaults to `False`): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. 
- """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - if padding and truncation: - raise ValueError("Both padding and truncation were set. Make sure you only set one.") - elif padding is None: - # by default let's pad the inputs - padding = True - - is_batched = bool( - isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] - elif not is_batched and not isinstance(raw_audio, np.ndarray): - raw_audio = np.asarray(raw_audio, dtype=np.float32) - elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): - raw_audio = raw_audio.astype(np.float32) - - # always return batch - if not is_batched: - raw_audio = [np.asarray(raw_audio).T] - - # verify inputs are valid - for idx, example in enumerate(raw_audio): - if example.ndim > 2: - raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") - if self.feature_size == 1 and example.ndim != 1: - raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") - if self.feature_size == 2 and example.shape[-1] != 2: - raise ValueError(f"Expected stereo audio but example has {example.shape[-1]} channels") - - padded_inputs = None - input_values = BatchFeature({"input_values": raw_audio}) - if self.chunk_stride is not None and self.chunk_length is not None and max_length is None: - if truncation: - max_length = min(array.shape[0] for array in raw_audio) - nb_step = int(np.floor(max_length / self.chunk_stride)) - max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length - elif padding: - max_length = max(array.shape[0] for array in raw_audio) - nb_step = int(np.ceil(max_length / self.chunk_stride)) - max_length = (nb_step - 1) * self.chunk_stride + self.chunk_length - padding = "max_length" - else: - padded_inputs = input_values - - # normal padding on batch - if padded_inputs is None: - padded_inputs = self.pad( - input_values, - max_length=max_length, - truncation=truncation, - padding=padding, - return_attention_mask=padding, - ) - - if padding: - padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") - - # now let's pad left and right - pad_left = int(self.audio_silence_prefix_seconds * self.sampling_rate) - pad_right = int((self.audio_delay_seconds + 1.0) * self.sampling_rate) - padded_inputs["input_values"] = np.pad( - padded_inputs["input_values"], - ((0, 0), (pad_left, pad_right)), - mode="constant", - constant_values=0.0, - ) - if padding: - padded_inputs["padding_mask"] = np.pad( - padded_inputs["padding_mask"], - ((0, 0), (pad_left, pad_right)), - mode="constant", - constant_values=0, - ) - - input_values = [] - for example in padded_inputs.pop("input_values"): - if self.feature_size == 1: - example = example[..., None] - input_values.append(example.T) - - padded_inputs["input_values"] = input_values - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return 
padded_inputs - +KyutaiSpeechToTextFeatureExtractor = deprecated_feature_extractor( + KyutaiSpeechToTextAudioProcessor, "KyutaiSpeechToTextFeatureExtractor" +) __all__ = ["KyutaiSpeechToTextFeatureExtractor"] diff --git a/src/transformers/models/lasr/feature_extraction_lasr.py b/src/transformers/models/lasr/feature_extraction_lasr.py index 7cf1822ee40d..90b1954ec5f2 100644 --- a/src/transformers/models/lasr/feature_extraction_lasr.py +++ b/src/transformers/models/lasr/feature_extraction_lasr.py @@ -11,265 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_lasr import LasrAudioProcessor -import numpy as np -import torch - -from ...audio_utils import hertz_to_mel -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, logging -from ...utils.import_utils import requires - - -logger = logging.get_logger(__name__) - - -# TODO: @eustlb, we should be able to remove this and use mel_filter_bank from audio_utils -def linear_to_mel_weight_matrix( - num_mel_bins: int, - num_spectrogram_bins: int, - sample_rate: float, - lower_edge_hertz: float, - upper_edge_hertz: float, - dtype, -) -> np.ndarray: - """NumPy-port of the JAX mel weight matrix logic.""" - # We use float64 for precision, matching the JAX implementation. - internal_dtype = np.float64 - - # HTK excludes the spectrogram DC bin. - bands_to_zero = 1 - nyquist_hertz = sample_rate / 2.0 - linear_frequencies = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins, dtype=internal_dtype)[bands_to_zero:] - spectrogram_bins_mel = hertz_to_mel(linear_frequencies, mel_scale="kaldi")[:, np.newaxis] - - edges = np.linspace( - hertz_to_mel(lower_edge_hertz, mel_scale="kaldi"), - hertz_to_mel(upper_edge_hertz, mel_scale="kaldi"), - num_mel_bins + 2, - dtype=internal_dtype, - ) - - lower_edge_mel, center_mel, upper_edge_mel = ( - edges[:-2][np.newaxis, :], - edges[1:-1][np.newaxis, :], - edges[2:][np.newaxis, :], - ) - - lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel) - upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel) - mel_weights_matrix = np.maximum(0.0, np.minimum(lower_slopes, upper_slopes)) - return np.pad(mel_weights_matrix, [[bands_to_zero, 0], [0, 0]]).astype(dtype) - - -@requires(backends=("torch",)) -class LasrFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a LASR feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time - Fourier Transform` which should match pytorch's `torch.stft` equivalent. - - Args: - feature_size (`int`, *optional*, defaults to 128): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - hop_length (`int`, *optional*, defaults to 160): - Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. 
- n_fft (`int`, *optional*, defaults to 512): - Size of the Fourier transform. - win_length (`int`, *optional*, defaults to 400): - The window length for the STFT computation. - padding_value (`float`, *optional*, defaults to 0.0): - Padding value used to pad the audio. Should correspond to silences. - """ - - model_input_names = ["input_features", "attention_mask"] - - def __init__( - self, - feature_size=128, - sampling_rate=16000, - hop_length=160, - n_fft=512, - win_length=400, - padding_value=0.0, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - - self.hop_length = hop_length - self.n_fft = n_fft - self.win_length = win_length - self.mel_filters = torch.from_numpy( - linear_to_mel_weight_matrix( - num_mel_bins=feature_size, - num_spectrogram_bins=n_fft // 2 + 1, - sample_rate=sampling_rate, - lower_edge_hertz=125.0, - upper_edge_hertz=7500.0, - dtype=np.float64, - ) - ) - - def _torch_extract_fbank_features(self, waveform, device="cpu"): - # spectrogram - window = torch.hann_window(self.win_length, periodic=False, device=device, dtype=torch.float64) - waveform = waveform.to(torch.float64) - - # TODO: @eustlb, to be standardized - # here we cannot use directly torch.stft because every fft frame is padded with zeros - # due to unfold then rfft, while torch.stft unfolds with the number of fft points - frames = waveform.unfold(-1, self.win_length, self.hop_length) - stft = torch.fft.rfft(window * frames, n=self.n_fft) - power_spec = torch.abs(stft) ** 2 - - # log mel spectrogram - mel_filters = self.mel_filters.to(device) - mel_spec = torch.clamp(power_spec @ mel_filters, min=1e-5) - mel_spec = torch.log(mel_spec) - - return mel_spec - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - truncation: bool = False, - pad_to_multiple_of: int | None = None, - return_tensors: str | TensorType | None = None, - return_attention_mask: bool | None = None, - padding: str | None = "longest", - max_length: int | None = None, - sampling_rate: int | None = None, - do_normalize: bool | None = None, - device: str | None = "cpu", - return_token_timestamps: bool | None = None, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for - the STFT computation if available, otherwise a slower NumPy based one. - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - truncation (`bool`, *optional*, default to `True`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*, defaults to None): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. 
- - [What are attention masks?](../glossary#attention-mask) - - - - For Parakeet models, `attention_mask` should always be passed for batched inference, to avoid subtle - bugs. - - - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition - pipeline. - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values / vectors. - do_normalize (`bool`, *optional*, defaults to `False`): - Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly - improve the performance of the model. - device (`str`, *optional*, defaults to `'cpu'`): - Specifies the device for computation of the log-mel spectrogram of audio signals in the - `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda") - return_token_timestamps (`bool`, *optional*, defaults to `None`): - Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred. - - Whether or not to return the number of frames of the input raw_speech. - These num_frames can be used by the model to compute word level timestamps. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" - f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" - f" was sampled with {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - # Convert to torch tensor - if isinstance(raw_speech, np.ndarray): - raw_speech = torch.tensor(raw_speech) - elif isinstance(raw_speech, (list, tuple)): - if isinstance(raw_speech[0], (list, np.ndarray)): - raw_speech = [torch.tensor(speech) for speech in raw_speech] - else: # list[float] - raw_speech = torch.tensor(raw_speech) - - is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1 - if is_batched_torch and len(raw_speech.shape) > 2: - logger.warning( - f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " - "We will take the mean of the channels to convert to mono." - ) - raw_speech = raw_speech.mean(-1) - - is_batched_sequence = isinstance(raw_speech, (list, tuple)) - if is_batched_sequence: - for speech in raw_speech: - if len(speech.shape) > 1: - logger.warning( - f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " - "We will take the mean of the channels to convert to mono." 
- ) - speech = speech.mean(-1) - - if is_batched_torch or is_batched_sequence: - raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech] - else: - raw_speech = [raw_speech[:, None].to(torch.float32)] - - batched_speech = BatchFeature({"input_features": raw_speech}) - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_tensors="pt", - ) - input_features = padded_inputs.input_features.squeeze(-1) - input_features = self._torch_extract_fbank_features(input_features, device) - data = { - "input_features": input_features.to(torch.float32), - } - - if return_attention_mask: - attention_mask = padded_inputs.attention_mask[:, self.win_length - 1 :: self.hop_length] - data["attention_mask"] = attention_mask.to(torch.bool) - - return BatchFeature(data=data, tensor_type=return_tensors) +LasrFeatureExtractor = deprecated_feature_extractor(LasrAudioProcessor, "LasrFeatureExtractor") __all__ = ["LasrFeatureExtractor"] diff --git a/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py b/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py index 1811fa11e630..c41ea0666292 100644 --- a/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py @@ -11,324 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Feature extractor class for Musicgen Melody -""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_musicgen_melody import MusicgenMelodyAudioProcessor -import copy -from typing import Any - -import numpy as np - -from ...audio_utils import chroma_filter_bank -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, is_torch_available, is_torchaudio_available, logging -from ...utils.import_utils import requires - - -if is_torch_available(): - import torch - -if is_torchaudio_available(): - import torchaudio - -logger = logging.get_logger(__name__) - - -@requires(backends=("torchaudio",)) -class MusicgenMelodyFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a MusicgenMelody feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts chroma features from audio processed by [Demucs](https://github.com/adefossez/demucs/tree/main) or - directly from raw audio waveform. - - Args: - feature_size (`int`, *optional*, defaults to 12): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 32000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - hop_length (`int`, *optional*, defaults to 4096): - Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. - chunk_length (`int`, *optional*, defaults to 30): - The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio - sequences. 
- n_fft (`int`, *optional*, defaults to 16384): - Size of the Fourier transform. - num_chroma (`int`, *optional*, defaults to 12): - Number of chroma bins to use. - padding_value (`float`, *optional*, defaults to 0.0): - Padding value used to pad the audio. - return_attention_mask (`bool`, *optional*, defaults to `False`): - Whether to return the attention mask. Can be overwritten when calling the feature extractor. - - [What are attention masks?](../glossary#attention-mask) - - - - For Whisper models, `attention_mask` should always be passed for batched inference, to avoid subtle - bugs. - - - stem_indices (`list[int]`, *optional*, defaults to `[3, 2]`): - Stem channels to extract if demucs outputs are passed. - """ - - model_input_names = ["input_features"] - - def __init__( - self, - feature_size=12, - sampling_rate=32000, - hop_length=4096, - chunk_length=30, - n_fft=16384, - num_chroma=12, - padding_value=0.0, - return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask - stem_indices=[3, 2], - **kwargs, - ): - super().__init__( - feature_size=feature_size, - sampling_rate=sampling_rate, - padding_value=padding_value, - return_attention_mask=return_attention_mask, - **kwargs, - ) - self.n_fft = n_fft - self.hop_length = hop_length - self.chunk_length = chunk_length - self.n_samples = chunk_length * sampling_rate - self.sampling_rate = sampling_rate - self.chroma_filters = torch.from_numpy( - chroma_filter_bank(sampling_rate=sampling_rate, num_frequency_bins=n_fft, tuning=0, num_chroma=num_chroma) - ).float() - self.spectrogram = torchaudio.transforms.Spectrogram( - n_fft=n_fft, win_length=n_fft, hop_length=hop_length, power=2, center=True, pad=0, normalized=True - ) - self.stem_indices = stem_indices - - def _torch_extract_fbank_features(self, waveform: torch.Tensor) -> torch.Tensor: - """ - Compute the chroma spectrogram of the provided audio using the torchaudio spectrogram implementation and the librosa chroma features. - """ - - # if wav length is not long enough, pad it - wav_length = waveform.shape[-1] - if wav_length < self.n_fft: - pad = self.n_fft - wav_length - rest = 0 if pad % 2 == 0 else 1 - waveform = torch.nn.functional.pad(waveform, (pad // 2, pad // 2 + rest), "constant", 0) - - # squeeze alongside channel dimension - spec = self.spectrogram(waveform).squeeze(1) - - # sum along the frequency dimension - raw_chroma = torch.einsum("cf, ...ft->...ct", self.chroma_filters, spec) - - # normalise with max value - norm_chroma = torch.nn.functional.normalize(raw_chroma, p=float("inf"), dim=-2, eps=1e-6) - - # transpose time and chroma dimension -> (batch, time, chroma) - norm_chroma = norm_chroma.transpose(1, 2) - - # replace max value alongside chroma dimension with 1 and replace the rest with 0 - idx = norm_chroma.argmax(-1, keepdim=True) - norm_chroma[:] = 0 - norm_chroma.scatter_(dim=-1, index=idx, value=1) - - return norm_chroma - - def _extract_stem_indices(self, audio, sampling_rate=None): - """ - Extracts stems from the output of the [Demucs](https://github.com/adefossez/demucs/tree/main) audio separation model, - then converts to mono-channel and resample to the feature extractor sampling rate. - - Args: - audio (`torch.Tensor` of shape `(batch_size, num_stems, channel_size, audio_length)`): - The output of the Demucs model to be processed. - sampling_rate (`int`, *optional*): - Demucs sampling rate. If not specified, defaults to `44000`. 
- """ - sampling_rate = 44000 if sampling_rate is None else sampling_rate - - # extract "vocals" and "others" sources from audio encoder (demucs) output - # [batch_size, num_stems, channel_size, audio_length] - wav = audio[:, torch.tensor(self.stem_indices)] - - # merge extracted stems to single waveform - wav = wav.sum(1) - - # convert to mono-channel waveform - wav = wav.mean(dim=1, keepdim=True) - - # resample to model sampling rate - # not equivalent to julius.resample - if sampling_rate != self.sampling_rate: - wav = torchaudio.functional.resample( - wav, sampling_rate, self.sampling_rate, rolloff=0.945, lowpass_filter_width=24 - ) - - # [batch_size, 1, audio_length] -> [batch_size, audio_length] - wav = wav.squeeze(1) - - return wav - - def __call__( - self, - audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - truncation: bool = True, - pad_to_multiple_of: int | None = None, - return_tensors: str | TensorType | None = None, - return_attention_mask: bool | None = None, - padding: str | None = True, - max_length: int | None = None, - sampling_rate: int | None = None, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - audio (`torch.Tensor`, `np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[torch.Tensor]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a torch tensor, a numpy array, a list of float - values, a list of numpy arrays, a list of torch tensors, or a list of list of float values. - If `audio` is the output of Demucs, it has to be a torch tensor of shape `(batch_size, num_stems, channel_size, audio_length)`. - Otherwise, it must be mono or stereo channel audio. - truncation (`bool`, *optional*, default to `True`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*, defaults to None): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. - - [What are attention masks?](../glossary#attention-mask) - - - For Musicgen Melody models, audio `attention_mask` is not necessary. - - - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). 
- max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - sampling_rate (`int`, *optional*): - The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - Note that if `audio` is the output of Demucs, `sampling_rate` must be the sampling rate at which Demucs operates. - """ - - if sampling_rate is None: - logger.warning_once( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - if isinstance(audio, torch.Tensor) and len(audio.shape) == 4: - logger.warning_once( - "`audio` is a 4-dimensional torch tensor and has thus been recognized as the output of `Demucs`. " - "If this is not the case, make sure to read Musicgen Melody docstrings and " - "to correct `audio` to get the right behaviour." - "Link to the docstrings: https://huggingface.co/docs/transformers/main/en/model_doc/musicgen_melody" - ) - audio = self._extract_stem_indices(audio, sampling_rate=sampling_rate) - elif sampling_rate is not None and sampling_rate != self.sampling_rate: - audio = torchaudio.functional.resample( - audio, sampling_rate, self.sampling_rate, rolloff=0.945, lowpass_filter_width=24 - ) - - is_batched = isinstance(audio, (np.ndarray, torch.Tensor)) and len(audio.shape) > 1 - is_batched = is_batched or ( - isinstance(audio, (list, tuple)) and (isinstance(audio[0], (torch.Tensor, np.ndarray, tuple, list))) - ) - - if is_batched and not isinstance(audio[0], torch.Tensor): - audio = [torch.tensor(speech, dtype=torch.float32).unsqueeze(-1) for speech in audio] - elif is_batched: - audio = [speech.unsqueeze(-1) for speech in audio] - elif not is_batched and not isinstance(audio, torch.Tensor): - audio = torch.tensor(audio, dtype=torch.float32).unsqueeze(-1) - - if isinstance(audio[0], torch.Tensor) and audio[0].dtype is torch.float64: - audio = [speech.to(torch.float32) for speech in audio] - - # always return batch - if not is_batched: - audio = [audio] - - if len(audio[0].shape) == 3: - logger.warning_once( - "`audio` has been detected as a batch of stereo signals. Will be convert to mono signals. " - "If this is an undesired behaviour, make sure to read Musicgen Melody docstrings and " - "to correct `audio` to get the right behaviour." - "Link to the docstrings: https://huggingface.co/docs/transformers/main/en/model_doc/musicgen_melody" - ) - # convert to mono-channel waveform - audio = [stereo.mean(dim=0) for stereo in audio] - - batched_speech = BatchFeature({"input_features": audio}) - - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length if max_length else self.n_samples, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_tensors="pt", - ) - - input_features = self._torch_extract_fbank_features(padded_inputs["input_features"].squeeze(-1)) - - padded_inputs["input_features"] = input_features - - if return_attention_mask: - # rescale from raw audio length to spectrogram length - padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length] - - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs - - def to_dict(self) -> dict[str, Any]: - """ - Serializes this instance to a Python dictionary. 
Returns: - `dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. - """ - output = copy.deepcopy(self.__dict__) - output["feature_extractor_type"] = self.__class__.__name__ - if "mel_filters" in output: - del output["mel_filters"] - if "window" in output: - del output["window"] - if "chroma_filters" in output: - del output["chroma_filters"] - if "spectrogram" in output: - del output["spectrogram"] - return output +MusicgenMelodyFeatureExtractor = deprecated_feature_extractor( + MusicgenMelodyAudioProcessor, "MusicgenMelodyFeatureExtractor" +) __all__ = ["MusicgenMelodyFeatureExtractor"] diff --git a/src/transformers/models/parakeet/feature_extraction_parakeet.py b/src/transformers/models/parakeet/feature_extraction_parakeet.py index c745d02c9629..92f02cd0a9f4 100644 --- a/src/transformers/models/parakeet/feature_extraction_parakeet.py +++ b/src/transformers/models/parakeet/feature_extraction_parakeet.py @@ -11,275 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_parakeet import ParakeetAudioProcessor -import numpy as np -import torch - -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, is_librosa_available, logging -from ...utils.import_utils import requires - - -if is_librosa_available(): - import librosa - - -EPSILON = 1e-5 -LOG_ZERO_GUARD_VALUE = 2**-24 - - -logger = logging.get_logger(__name__) - - -@requires(backends=("torch", "librosa")) -class ParakeetFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a Parakeet feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time - Fourier Transform` which should match pytorch's `torch.stft` equivalent. - - Args: - feature_size (`int`, *optional*, defaults to 80): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - hop_length (`int`, *optional*, defaults to 160): - Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. - n_fft (`int`, *optional*, defaults to 512): - Size of the Fourier transform. - win_length (`int`, *optional*, defaults to 400): - The window length for the STFT computation. - preemphasis (`float`, *optional*, defaults to 0.97): - A preemphasis filter coefficient. 0.0 means no preemphasis filter. - padding_value (`float`, *optional*, defaults to 0.0): - Padding value used to pad the audio. Should correspond to silences. 
- """ - - model_input_names = ["input_features", "attention_mask"] - - def __init__( - self, - feature_size=80, - sampling_rate=16000, - hop_length=160, - n_fft=512, - win_length=400, - preemphasis=0.97, - padding_value=0.0, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - - self.hop_length = hop_length - self.n_fft = n_fft - self.win_length = win_length - self.preemphasis = preemphasis - - # TODO: @eustlb, for now we use librosa to compute the mel filters - # indeed mel_filter_bank uses np.float64 (while librosa uses np.float32), giving numerical differences - # self.mel_filters = mel_filter_bank( - # num_frequency_bins=n_fft // 2 + 1, - # num_mel_filters=feature_size, - # min_frequency=0.0, - # max_frequency=sampling_rate / 2, - # sampling_rate=sampling_rate, - # norm="slaney", - # mel_scale="slaney", - # ) - mel_filters = librosa.filters.mel( - sr=sampling_rate, n_fft=n_fft, n_mels=feature_size, fmin=0.0, fmax=sampling_rate / 2, norm="slaney" - ) - self.mel_filters = torch.from_numpy(mel_filters).to(torch.float32) - - def _torch_extract_fbank_features(self, waveform, device="cpu"): - # spectrogram - window = torch.hann_window(self.win_length, periodic=False, device=device) - stft = torch.stft( - waveform, - self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=window, - return_complex=True, - pad_mode="constant", - ) - # Let's math original implementation - # magnitudes = torch.abs(stft) ** 2 - magnitudes = torch.view_as_real(stft) - magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1)) - magnitudes = magnitudes.pow(2) - - # log mel spectrogram - mel_filters = self.mel_filters.to(device) - mel_spec = mel_filters @ magnitudes - mel_spec = torch.log(mel_spec + LOG_ZERO_GUARD_VALUE) - - # (batch_size, num_mel_filters, num_frames) -> (batch_size, num_frames, num_mel_filters) - mel_spec = mel_spec.permute(0, 2, 1) - - return mel_spec - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - truncation: bool = False, - pad_to_multiple_of: int | None = None, - return_tensors: str | TensorType | None = None, - return_attention_mask: bool | None = None, - padding: str | None = "longest", - max_length: int | None = None, - sampling_rate: int | None = None, - do_normalize: bool | None = None, - device: str | None = "cpu", - return_token_timestamps: bool | None = None, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for - the STFT computation if available, otherwise a slower NumPy based one. - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - truncation (`bool`, *optional*, default to `True`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*, defaults to None): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. 
- return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. - - [What are attention masks?](../glossary#attention-mask) - - - - For Parakeet models, `attention_mask` should always be passed for batched inference, to avoid subtle - bugs. - - - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition - pipeline. - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values / vectors. - do_normalize (`bool`, *optional*, defaults to `False`): - Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly - improve the performance of the model. - device (`str`, *optional*, defaults to `'cpu'`): - Specifies the device for computation of the log-mel spectrogram of audio signals in the - `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda") - return_token_timestamps (`bool`, *optional*, defaults to `None`): - Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred. - - Whether or not to return the number of frames of the input raw_speech. - These num_frames can be used by the model to compute word level timestamps. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" - f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" - f" was sampled with {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - # Convert to torch tensor - if isinstance(raw_speech, np.ndarray): - raw_speech = torch.tensor(raw_speech) - elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray): - raw_speech = [torch.tensor(speech) for speech in raw_speech] - - is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1 - if is_batched_torch and len(raw_speech.shape) > 2: - logger.warning( - f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " - "We will take the mean of the channels to convert to mono." - ) - raw_speech = raw_speech.mean(-1) - - is_batched_sequence = isinstance(raw_speech, (list, tuple)) - if is_batched_sequence: - for speech in raw_speech: - if len(speech.shape) > 1: - logger.warning( - f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " - "We will take the mean of the channels to convert to mono." 
- ) - speech = speech.mean(-1) - - if is_batched_torch or is_batched_sequence: - raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech] - else: - raw_speech = [raw_speech[:, None].to(torch.float32)] - - audio_lengths = [len(speech) for speech in raw_speech] - batched_speech = BatchFeature({"input_features": raw_speech, "audio_lengths": audio_lengths}) - - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors="pt", - ) - input_features = padded_inputs.input_features.squeeze(-1) - - # preemphasis - if self.preemphasis is not None: - timemask = torch.arange(input_features.shape[1], device=input_features.device).unsqueeze( - 0 - ) < padded_inputs.audio_lengths.unsqueeze(1) - input_features = torch.cat( - [input_features[:, :1], input_features[:, 1:] - self.preemphasis * input_features[:, :-1]], dim=1 - ) - input_features = input_features.masked_fill(~timemask, 0.0) - - input_features = self._torch_extract_fbank_features(input_features, device) - features_lengths = torch.floor_divide( - padded_inputs.audio_lengths + self.n_fft // 2 * 2 - self.n_fft, self.hop_length - ) - attention_mask = torch.arange(input_features.shape[1], device=device)[None, :] < features_lengths[:, None] - - # normalize mel features, ignoring padding - mask = attention_mask.unsqueeze(-1) - input_features_masked = input_features * mask - mean = input_features_masked.sum(dim=1) / features_lengths.unsqueeze(-1) - mean = mean.unsqueeze(1) - variance = ((input_features_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1) - std = torch.sqrt(variance).unsqueeze(1) - input_features = (input_features - mean) / (std + EPSILON) - input_features *= mask - - return BatchFeature( - data={ - "input_features": input_features, - "attention_mask": attention_mask, - }, - tensor_type=return_tensors, - ) +ParakeetFeatureExtractor = deprecated_feature_extractor(ParakeetAudioProcessor, "ParakeetFeatureExtractor") __all__ = ["ParakeetFeatureExtractor"] diff --git a/src/transformers/models/pe_audio/feature_extraction_pe_audio.py b/src/transformers/models/pe_audio/feature_extraction_pe_audio.py index a7738d3089ac..da1f7d34a86f 100644 --- a/src/transformers/models/pe_audio/feature_extraction_pe_audio.py +++ b/src/transformers/models/pe_audio/feature_extraction_pe_audio.py @@ -11,150 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_pe_audio import PeAudioAudioProcessor -import numpy as np - -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...processing_utils import load_audio -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class PeAudioFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a PeAudioFeatureExtractor feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. Use 1 for mono, 2 for stereo. 
- sampling_rate (`int`, *optional*, defaults to 48000): - The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used for padding. - hop_length (`int`, *optional*, defaults to 1920): - Overlap length between successive windows. - """ - - model_input_names = ["input_values"] - - def __init__( - self, - feature_size: int = 1, - sampling_rate: int = 48_000, - padding_value: float = 0.0, - hop_length: int = 1920, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.hop_length = hop_length - - def _reflect_pad(self, wav): - if len(wav) % self.hop_length == 0: - return wav - p1d = (0, self.hop_length - (len(wav) % self.hop_length)) - return np.pad(wav, p1d, "reflect") - - def __call__( - self, - raw_audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]] | str | list[str], - padding: bool | str | PaddingStrategy | None = None, - truncation: bool | None = False, - max_length: int | None = None, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - ) -> BatchFeature: - from_file = False - if isinstance(raw_audio, str): - raw_audio = [raw_audio] - - if isinstance(raw_audio, (list, tuple)) and isinstance(raw_audio[0], str): - loaded = [] - for audio_file in raw_audio: - loaded.append(load_audio(audio_file, self.sampling_rate)) - raw_audio = loaded - from_file = True - - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - elif not from_file: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - if padding and truncation: - raise ValueError("Both padding and truncation were set. 
Make sure you only set one.") - elif padding is None: - # by default let's pad the inputs - padding = True - - is_batched = bool( - isinstance(raw_audio, (list, tuple)) and (isinstance(raw_audio[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_audio = [np.asarray(audio, dtype=np.float32).T for audio in raw_audio] - elif not is_batched and not isinstance(raw_audio, np.ndarray): - raw_audio = np.asarray(raw_audio, dtype=np.float32) - elif isinstance(raw_audio, np.ndarray) and raw_audio.dtype is np.dtype(np.float64): - raw_audio = raw_audio.astype(np.float32) - - # always return batch - if not is_batched: - raw_audio = [np.asarray(raw_audio).T] - - if isinstance(raw_audio, list): - raw_audio = [self._reflect_pad(x) for x in raw_audio] - else: - raw_audio = self._reflect_pad(raw_audio) - - # verify inputs are valid - for example in raw_audio: - if example.ndim > 2: - raise ValueError(f"Expected input shape (channels, length) but got shape {example.shape}") - if self.feature_size == 1 and example.ndim != 1: - raise ValueError(f"Expected mono audio but example has {example.shape[-1]} channels") - if self.feature_size == 2: - raise ValueError("Stereo audio isn't supported for now") - - input_values = BatchFeature({"input_values": raw_audio}) - - # normal padding on batch - padded_inputs = self.pad( - input_values, - max_length=max_length, - truncation=truncation, - padding=padding, - return_attention_mask=padding, - pad_to_multiple_of=self.hop_length, - ) - if padding: - padded_inputs["padding_mask"] = padded_inputs.pop("attention_mask") - if padding: - padded_inputs.input_values = padded_inputs.input_values[:, np.newaxis, :] - - input_values = [] - for example in padded_inputs.pop("input_values"): - if self.feature_size == 1: - example = example[..., None] - input_values.append(example.T) - - padded_inputs["input_values"] = input_values - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs +PeAudioFeatureExtractor = deprecated_feature_extractor(PeAudioAudioProcessor, "PeAudioFeatureExtractor") __all__ = ["PeAudioFeatureExtractor"] diff --git a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py index 9ce98251e50e..78d4727cbccd 100644 --- a/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py @@ -11,271 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_phi4_multimodal import Phi4MultimodalAudioProcessor -""" -Processor class for Phi4Multimodal -""" - -import numpy as np - -from ...audio_utils import AudioInput, mel_filter_bank -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...image_processing_utils import BatchFeature -from ...utils import TensorType, is_torch_available, logging - - -if is_torch_available(): - import torch - - -logger = logging.get_logger(__name__) - - -class Phi4MultimodalFeatureExtractor(SequenceFeatureExtractor): - model_input_names = ["audio_input_features", "audio_embed_sizes", "audio_attention_mask"] - - def __init__( - self, - feature_size: int = 80, - sampling_rate: int = 16000, - hop_length: int = 160, - n_fft: int = 512, - win_length: int = 400, - preemphasis: float = 0.97, - padding_value: float = 0.0, - audio_compression_rate: int = 8, - audio_downsample_rate: int = 1, - audio_feat_stride: int = 1, - mel_min_frequency: float = 0, - mel_max_frequency: float = 7690, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - - self.hop_length = hop_length - self.n_fft = n_fft - self.win_length = win_length - self.preemphasis = preemphasis - self.padding_value = padding_value - self.audio_compression_rate = audio_compression_rate - self.audio_downsample_rate = audio_downsample_rate - self.audio_feat_stride = audio_feat_stride - - self.mel_filters = mel_filter_bank( - num_frequency_bins=self.n_fft // 2 + 1, - num_mel_filters=self.feature_size, - min_frequency=mel_min_frequency, - max_frequency=mel_max_frequency, - sampling_rate=self.sampling_rate, - triangularize_in_mel_space=True, - mel_scale="kaldi", - ) - - def __call__( - self, - raw_speech: AudioInput, - sampling_rate: int | None = None, - pad_to_multiple_of: int | None = None, - padding: str | None = "longest", - max_length: int | None = None, - truncation: bool = False, - return_tensors: str | TensorType | None = None, - return_attention_mask: bool | None = True, - device: str | None = "cpu", - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several audio sequence(s). Implementation uses PyTorch for - the STFT computation if available, otherwise a slower NumPy based one. - - Args: - raw_speech (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): - The sequence or batch of sequences to be processed. Each sequence can be a numpy array or PyTorch tensor. - For batched inputs, sequences can be a list of numpy arrays or PyTorch tensors, or a single numpy array or - PyTorch tensor with first dimension being the batch size. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - pad_to_multiple_of (`int`, *optional*, defaults to None): - If set will pad the sequence to a multiple of the provided value. - padding (`str`, *optional*, defaults to "longest"): - Padding strategy. Can be "longest" to pad to the longest sequence in the batch, or a specific length. - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length. - truncation (`bool`, *optional*, defaults to False): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. 
- return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of numpy arrays. Acceptable values are: - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - return_attention_mask (`bool`, *optional*, defaults to `True`): - Whether to return the extracted audio input features' attention mask. - device (`str`, *optional*, defaults to "cpu"): - Specifies the device for computation of the audio features. (e.g., "cpu", "cuda") - - Returns: - [`BatchFeature`]: A [`BatchFeature`] with the following fields: - - **audio_input_features** -- Audio features extracted from the raw audio input, shape (batch_size, max_feature_length, feature_size). - - **audio_lengths** -- Length of each audio sample in the batch, shape (batch_size,). - - **audio_attention_mask** -- Attention mask for the audio input, shape (batch_size, max_feature_length). - If `return_tensors` is not specified, the fields will be PyTorch tensors if PyTorch is available, otherwise NumPy arrays. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" - f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" - f" was sampled with {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - # Convert to torch tensor - if isinstance(raw_speech, np.ndarray): - raw_speech = torch.tensor(raw_speech) - elif isinstance(raw_speech, (list, tuple)) and isinstance(raw_speech[0], np.ndarray): - raw_speech = [torch.tensor(speech) for speech in raw_speech] - - is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1 - if is_batched_torch and len(raw_speech.shape) > 2: - logger.warning( - f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " - "We will take the mean of the channels to convert to mono." - ) - raw_speech = raw_speech.mean(-1) - - is_batched_sequence = isinstance(raw_speech, (list, tuple)) - if is_batched_sequence: - for speech in raw_speech: - if len(speech.shape) > 1: - logger.warning( - f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " - "We will take the mean of the channels to convert to mono." 
- ) - speech = speech.mean(-1) - - if is_batched_torch or is_batched_sequence: - raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech] - else: - raw_speech = [raw_speech[:, None].to(torch.float32)] - - audio_lengths = [len(speech) for speech in raw_speech] - - # convert into correct format for padding - batched_speech = BatchFeature(data={"audio_input_features": raw_speech, "audio_lengths": audio_lengths}) - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_tensors="pt", - ) - input_features = padded_inputs.audio_input_features.squeeze(-1) - audio_lengths = padded_inputs.audio_lengths - - input_features = self._torch_extract_fbank_features(input_features, audio_lengths, device) - - feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1 - feature_lengths = feature_lengths * self.audio_feat_stride - audio_embed_sizes = self._compute_audio_embed_size(feature_lengths) - - feature_attention_mask = ( - torch.arange(0, feature_lengths.max()) if is_torch_available() else np.arange(0, feature_lengths.max()) - ) - feature_attention_mask = ( - feature_attention_mask[None, :] < feature_lengths[:, None] if len(feature_lengths) > 1 else None - ) - - data = { - "audio_input_features": input_features, - "audio_embed_sizes": audio_embed_sizes, - } - if feature_attention_mask is not None and return_attention_mask: - data["audio_attention_mask"] = feature_attention_mask - - return BatchFeature(data=data, tensor_type=return_tensors) - - # TODO; @eustlb, move this to audio_utils in a general spectogram_batch function that handles torch and numpy - def _torch_extract_fbank_features( - self, waveform: "torch.FloatTensor", audio_lengths: "torch.Tensor", device: str = "cpu" - ) -> "torch.FloatTensor": - """ - Compute the log mel-scaled spectrogram of batched waveforms using PyTorch's FFT implementation. - - Args: - waveform (torch.FloatTensor` of shape `(batch_size, max_audio_length)`): - The batched waveforms. - audio_lengths (`torch.Tensor` of shape `(batch_size,)`): - The lengths of the waveforms along the max_audio_length dimension. - device (`str`, *optional*, defaults to "cpu"): - The device to run the computation on. (e.g., "cpu", "cuda") - - Returns: - `torch.FloatTensor` of shape `(batch_size, max_feature_length, feature_size)`: - The log mel-scaled spectrogram of the batched waveforms. 
- """ - fft_window = torch.hamming_window(self.win_length, periodic=False, device=device, dtype=torch.float64) - - # batched implementation - batch_size = waveform.shape[0] - frames = waveform.unfold(-1, self.win_length, self.hop_length) - - # --- - # the unbatched (and unpaded) original implementation skips last few audio values that can't be included in a frame - # we need to ensure that the corresponding frames for the padded input also mask these values - if batch_size > 1: - frames = frames.clone() - # concerned batch indices - to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()] - if to_mask_batch_idxs.numel() > 0: - batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1 - batch_idxs_up = (audio_lengths[to_mask_batch_idxs] // self.hop_length) - 1 - offset_idx = batch_idxs_down.min() - max_idx = batch_idxs_up.max() - - mask = torch.arange(max_idx - offset_idx, device=device).expand(to_mask_batch_idxs.shape[0], -1) - mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & ( - mask < (batch_idxs_up - offset_idx).unsqueeze(1) - ) - mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length) - masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0) - frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames - # --- - - # apply pre-emphasis first order filter on fft windows - frames_prev = torch.roll(frames, 1, dims=-1) - frames_prev[:, :, 0] = frames_prev[:, :, 1] - frames = (frames - self.preemphasis * frames_prev) * 32768 - - # apply fft - S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1) - S = S.view(frames.shape[0], -1, S.shape[-1]) - S = S.to(torch.complex64) - - spec = torch.abs(S) - spec_power = spec**2 - - # apply triangular mel filter bank - mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) - log_spec = torch.clamp(spec_power @ mel_filters, min=1.0) - log_spec = torch.log(log_spec) - - return log_spec - - def _compute_audio_embed_size(self, audio_frames): - integer = audio_frames // self.audio_compression_rate - remainder = audio_frames % self.audio_compression_rate - result = integer + (remainder > 0).to(integer.dtype) - - integer = result // self.audio_downsample_rate - remainder = result % self.audio_downsample_rate - result = integer + (remainder > 0).to(integer.dtype) # qformer compression - - return result +Phi4MultimodalFeatureExtractor = deprecated_feature_extractor( + Phi4MultimodalAudioProcessor, "Phi4MultimodalFeatureExtractor" +) __all__ = ["Phi4MultimodalFeatureExtractor"] diff --git a/src/transformers/models/pop2piano/feature_extraction_pop2piano.py b/src/transformers/models/pop2piano/feature_extraction_pop2piano.py index 4e770fcb1b71..3ab91ec37d43 100644 --- a/src/transformers/models/pop2piano/feature_extraction_pop2piano.py +++ b/src/transformers/models/pop2piano/feature_extraction_pop2piano.py @@ -11,442 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Feature extractor class for Pop2Piano""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_pop2piano import Pop2PianoAudioProcessor -import warnings - -import numpy -import numpy as np - -from ...audio_utils import mel_filter_bank, spectrogram -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import ( - TensorType, - is_essentia_available, - is_librosa_available, - is_scipy_available, - logging, - requires_backends, -) -from ...utils.import_utils import requires - - -if is_essentia_available(): - import essentia.standard - -if is_librosa_available(): - import librosa - -if is_scipy_available(): - import scipy - - -logger = logging.get_logger(__name__) - - -@requires(backends=("essentia", "librosa", "scipy", "torch")) -class Pop2PianoFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a Pop2Piano feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts rhythm and preprocesses the audio before it is passed to the model. First the audio is passed - to `RhythmExtractor2013` algorithm which extracts the beat_times, beat positions and estimates their confidence as - well as tempo in bpm, then beat_times is interpolated and to get beatsteps. Later we calculate - extrapolated_beatsteps from it to be used in tokenizer. On the other hand audio is resampled to self.sampling_rate - and preprocessed and then log mel spectogram is computed from that to be used in our transformer model. - - Args: - sampling_rate (`int`, *optional*, defaults to 22050): - Target Sampling rate of audio signal. It's the sampling rate that we forward to the model. - padding_value (`int`, *optional*, defaults to 0): - Padding value used to pad the audio. Should correspond to silences. - window_size (`int`, *optional*, defaults to 4096): - Length of the window in samples to which the Fourier transform is applied. - hop_length (`int`, *optional*, defaults to 1024): - Step size between each window of the waveform, in samples. - min_frequency (`float`, *optional*, defaults to 10.0): - Lowest frequency that will be used in the log-mel spectrogram. - feature_size (`int`, *optional*, defaults to 512): - The feature dimension of the extracted features. - num_bars (`int`, *optional*, defaults to 2): - Determines interval between each sequence. 
- """ - - model_input_names = ["input_features", "beatsteps", "extrapolated_beatstep"] - - def __init__( - self, - sampling_rate: int = 22050, - padding_value: int = 0, - window_size: int = 4096, - hop_length: int = 1024, - min_frequency: float = 10.0, - feature_size: int = 512, - num_bars: int = 2, - **kwargs, - ): - super().__init__( - feature_size=feature_size, - sampling_rate=sampling_rate, - padding_value=padding_value, - **kwargs, - ) - self.sampling_rate = sampling_rate - self.padding_value = padding_value - self.window_size = window_size - self.hop_length = hop_length - self.min_frequency = min_frequency - self.feature_size = feature_size - self.num_bars = num_bars - self.mel_filters = mel_filter_bank( - num_frequency_bins=(self.window_size // 2) + 1, - num_mel_filters=self.feature_size, - min_frequency=self.min_frequency, - max_frequency=float(self.sampling_rate // 2), - sampling_rate=self.sampling_rate, - norm=None, - mel_scale="htk", - ) - - def mel_spectrogram(self, sequence: np.ndarray): - """ - Generates MelSpectrogram. - - Args: - sequence (`numpy.ndarray`): - The sequence of which the mel-spectrogram will be computed. - """ - mel_specs = [] - for seq in sequence: - window = np.hanning(self.window_size + 1)[:-1] - mel_specs.append( - spectrogram( - waveform=seq, - window=window, - frame_length=self.window_size, - hop_length=self.hop_length, - power=2.0, - mel_filters=self.mel_filters, - ) - ) - mel_specs = np.array(mel_specs) - - return mel_specs - - def extract_rhythm(self, audio: np.ndarray): - """ - This algorithm(`RhythmExtractor2013`) extracts the beat positions and estimates their confidence as well as - tempo in bpm for an audio signal. For more information please visit - https://essentia.upf.edu/reference/std_RhythmExtractor2013.html . - - Args: - audio(`numpy.ndarray`): - raw audio waveform which is passed to the Rhythm Extractor. - """ - requires_backends(self, ["essentia"]) - essentia_tracker = essentia.standard.RhythmExtractor2013(method="multifeature") - bpm, beat_times, confidence, estimates, essentia_beat_intervals = essentia_tracker(audio) - - return bpm, beat_times, confidence, estimates, essentia_beat_intervals - - def interpolate_beat_times( - self, beat_times: numpy.ndarray, steps_per_beat: numpy.ndarray, n_extend: numpy.ndarray - ): - """ - This method takes beat_times and then interpolates that using `scipy.interpolate.interp1d` and the output is - then used to convert raw audio to log-mel-spectrogram. - - Args: - beat_times (`numpy.ndarray`): - beat_times is passed into `scipy.interpolate.interp1d` for processing. - steps_per_beat (`int`): - used as an parameter to control the interpolation. - n_extend (`int`): - used as an parameter to control the interpolation. - """ - - requires_backends(self, ["scipy"]) - beat_times_function = scipy.interpolate.interp1d( - np.arange(beat_times.size), - beat_times, - bounds_error=False, - fill_value="extrapolate", - ) - - ext_beats = beat_times_function( - np.linspace(0, beat_times.size + n_extend - 1, beat_times.size * steps_per_beat + n_extend) - ) - - return ext_beats - - def preprocess_mel(self, audio: np.ndarray, beatstep: np.ndarray): - """ - Preprocessing for log-mel-spectrogram - - Args: - audio (`numpy.ndarray` of shape `(audio_length, )` ): - Raw audio waveform to be processed. - beatstep (`numpy.ndarray`): - Interpolated values of the raw audio. If beatstep[0] is greater than 0.0, then it will be shifted by - the value at beatstep[0]. 
- """ - - if audio is not None and len(audio.shape) != 1: - raise ValueError( - f"Expected `audio` to be a single channel audio input of shape `(n, )` but found shape {audio.shape}." - ) - if beatstep[0] > 0.0: - beatstep = beatstep - beatstep[0] - - num_steps = self.num_bars * 4 - num_target_steps = len(beatstep) - extrapolated_beatstep = self.interpolate_beat_times( - beat_times=beatstep, steps_per_beat=1, n_extend=(self.num_bars + 1) * 4 + 1 - ) - - sample_indices = [] - max_feature_length = 0 - for i in range(0, num_target_steps, num_steps): - start_idx = i - end_idx = min(i + num_steps, num_target_steps) - start_sample = int(extrapolated_beatstep[start_idx] * self.sampling_rate) - end_sample = int(extrapolated_beatstep[end_idx] * self.sampling_rate) - sample_indices.append((start_sample, end_sample)) - max_feature_length = max(max_feature_length, end_sample - start_sample) - padded_batch = [] - for start_sample, end_sample in sample_indices: - feature = audio[start_sample:end_sample] - padded_feature = np.pad( - feature, - ((0, max_feature_length - feature.shape[0]),), - "constant", - constant_values=0, - ) - padded_batch.append(padded_feature) - - padded_batch = np.asarray(padded_batch) - return padded_batch, extrapolated_beatstep - - def _pad(self, features: np.ndarray, add_zero_line=True): - features_shapes = [each_feature.shape for each_feature in features] - attention_masks, padded_features = [], [] - for i, each_feature in enumerate(features): - # To pad "input_features". - if len(each_feature.shape) == 3: - features_pad_value = max([*zip(*features_shapes)][1]) - features_shapes[i][1] - attention_mask = np.ones(features_shapes[i][:2], dtype=np.int64) - feature_padding = ((0, 0), (0, features_pad_value), (0, 0)) - attention_mask_padding = (feature_padding[0], feature_padding[1]) - - # To pad "beatsteps" and "extrapolated_beatstep". - else: - each_feature = each_feature.reshape(1, -1) - features_pad_value = max([*zip(*features_shapes)][0]) - features_shapes[i][0] - attention_mask = np.ones(features_shapes[i], dtype=np.int64).reshape(1, -1) - feature_padding = attention_mask_padding = ((0, 0), (0, features_pad_value)) - - each_padded_feature = np.pad(each_feature, feature_padding, "constant", constant_values=self.padding_value) - attention_mask = np.pad( - attention_mask, attention_mask_padding, "constant", constant_values=self.padding_value - ) - - if add_zero_line: - # if it is batched then we separate each examples using zero array - zero_array_len = max([*zip(*features_shapes)][1]) - - # we concatenate the zero array line here - each_padded_feature = np.concatenate( - [each_padded_feature, np.zeros([1, zero_array_len, self.feature_size])], axis=0 - ) - attention_mask = np.concatenate( - [attention_mask, np.zeros([1, zero_array_len], dtype=attention_mask.dtype)], axis=0 - ) - - padded_features.append(each_padded_feature) - attention_masks.append(attention_mask) - - padded_features = np.concatenate(padded_features, axis=0).astype(np.float32) - attention_masks = np.concatenate(attention_masks, axis=0).astype(np.int64) - - return padded_features, attention_masks - - def pad( - self, - inputs: BatchFeature, - is_batched: bool, - return_attention_mask: bool, - return_tensors: str | TensorType | None = None, - ): - """ - Pads the inputs to same length and returns attention_mask. - - Args: - inputs (`BatchFeature`): - Processed audio features. - is_batched (`bool`): - Whether inputs are batched or not. - return_attention_mask (`bool`): - Whether to return attention mask or not. 
- return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - If nothing is specified, it will return list of `np.ndarray` arrays. - Return: - `BatchFeature` with attention_mask, attention_mask_beatsteps and attention_mask_extrapolated_beatstep added - to it: - - **attention_mask** numpy.ndarray of shape `(batch_size, max_input_features_seq_length)` -- - Example : - 1, 1, 1, 0, 0 (audio 1, also here it is padded to max length of 5 that's why there are 2 zeros at - the end indicating they are padded) - - 0, 0, 0, 0, 0 (zero pad to separate audio 1 and 2) - - 1, 1, 1, 1, 1 (audio 2) - - 0, 0, 0, 0, 0 (zero pad to separate audio 2 and 3) - - 1, 1, 1, 1, 1 (audio 3) - - **attention_mask_beatsteps** numpy.ndarray of shape `(batch_size, max_beatsteps_seq_length)` - - **attention_mask_extrapolated_beatstep** numpy.ndarray of shape `(batch_size, - max_extrapolated_beatstep_seq_length)` - """ - - processed_features_dict = {} - for feature_name, feature_value in inputs.items(): - if feature_name == "input_features": - padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=True) - processed_features_dict[feature_name] = padded_feature_values - if return_attention_mask: - processed_features_dict["attention_mask"] = attention_mask - else: - padded_feature_values, attention_mask = self._pad(feature_value, add_zero_line=False) - processed_features_dict[feature_name] = padded_feature_values - if return_attention_mask: - processed_features_dict[f"attention_mask_{feature_name}"] = attention_mask - - # If we are processing only one example, we should remove the zero array line since we don't need it to - # separate examples from each other. - if not is_batched and not return_attention_mask: - processed_features_dict["input_features"] = processed_features_dict["input_features"][:-1, ...] - - outputs = BatchFeature(processed_features_dict, tensor_type=return_tensors) - - return outputs - - def __call__( - self, - audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - sampling_rate: int | list[int], - steps_per_beat: int = 2, - resample: bool | None = True, - return_attention_mask: bool | None = False, - return_tensors: str | TensorType | None = None, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model. - - Args: - audio (`np.ndarray`, `List`): - The audio or batch of audio to be processed. Each audio can be a numpy array, a list of float values, a - list of numpy arrays or a list of list of float values. - sampling_rate (`int`): - The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - steps_per_beat (`int`, *optional*, defaults to 2): - This is used in interpolating `beat_times`. - resample (`bool`, *optional*, defaults to `True`): - Determines whether to resample the audio to `sampling_rate` or not before processing. Must be True - during inference. - return_attention_mask (`bool` *optional*, defaults to `False`): - Denotes if attention_mask for input_features, beatsteps and extrapolated_beatstep will be given as - output or not. Automatically set to True for batched inputs. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. 
Acceptable values are: - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - If nothing is specified, it will return list of `np.ndarray` arrays. - """ - - requires_backends(self, ["librosa"]) - is_batched = isinstance(audio, (list, tuple)) and isinstance(audio[0], (np.ndarray, tuple, list)) - if is_batched: - # This enables the user to process files of different sampling_rate at same time - if not isinstance(sampling_rate, list): - raise ValueError( - "Please give sampling_rate of each audio separately when you are passing multiple raw_audios at the same time. " - f"Received {sampling_rate}, expected [audio_1_sr, ..., audio_n_sr]." - ) - return_attention_mask = True if return_attention_mask is None else return_attention_mask - else: - audio = [audio] - sampling_rate = [sampling_rate] - return_attention_mask = False if return_attention_mask is None else return_attention_mask - - batch_input_features, batch_beatsteps, batch_ext_beatstep = [], [], [] - for single_raw_audio, single_sampling_rate in zip(audio, sampling_rate): - bpm, beat_times, confidence, estimates, essentia_beat_intervals = self.extract_rhythm( - audio=single_raw_audio - ) - beatsteps = self.interpolate_beat_times(beat_times=beat_times, steps_per_beat=steps_per_beat, n_extend=1) - - if self.sampling_rate != single_sampling_rate and self.sampling_rate is not None: - if resample: - # Change sampling_rate to self.sampling_rate - single_raw_audio = librosa.core.resample( - single_raw_audio, - orig_sr=single_sampling_rate, - target_sr=self.sampling_rate, - res_type="kaiser_best", - ) - else: - warnings.warn( - f"The sampling_rate of the provided audio is different from the target sampling_rate " - f"of the Feature Extractor, {self.sampling_rate} vs {single_sampling_rate}. " - f"In these cases it is recommended to use `resample=True` in the `__call__` method to " - f"get the optimal behaviour." - ) - - single_sampling_rate = self.sampling_rate - start_sample = int(beatsteps[0] * single_sampling_rate) - end_sample = int(beatsteps[-1] * single_sampling_rate) - - input_features, extrapolated_beatstep = self.preprocess_mel( - single_raw_audio[start_sample:end_sample], beatsteps - beatsteps[0] - ) - - mel_specs = self.mel_spectrogram(input_features.astype(np.float32)) - - # apply np.log to get log mel-spectrograms - log_mel_specs = np.log(np.clip(mel_specs, a_min=1e-6, a_max=None)) - - input_features = np.transpose(log_mel_specs, (0, -1, -2)) - - batch_input_features.append(input_features) - batch_beatsteps.append(beatsteps) - batch_ext_beatstep.append(extrapolated_beatstep) - - output = BatchFeature( - { - "input_features": batch_input_features, - "beatsteps": batch_beatsteps, - "extrapolated_beatstep": batch_ext_beatstep, - } - ) - - output = self.pad( - output, - is_batched=is_batched, - return_attention_mask=return_attention_mask, - return_tensors=return_tensors, - ) - - return output +Pop2PianoFeatureExtractor = deprecated_feature_extractor(Pop2PianoAudioProcessor, "Pop2PianoFeatureExtractor") __all__ = ["Pop2PianoFeatureExtractor"] diff --git a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py index 1b18dcc33404..174bc72baa16 100644 --- a/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py @@ -11,295 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -""" -Feature extractor class for SeamlessM4T -""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_seamless_m4t import SeamlessM4tAudioProcessor -import numpy as np - -from ...utils import is_torch_available - - -if is_torch_available(): - import torch - -from ...audio_utils import mel_filter_bank, spectrogram, window_function -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class SeamlessM4TFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a SeamlessM4T feature extractor. - - This feature extractor inherits from [`SequenceFeatureExtractor`] which contains most of the main methods. Users - should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech. - - Args: - feature_size (`int`, *optional*, defaults to 80): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - num_mel_bins (`int`, *optional*, defaults to 80): - Number of Mel-frequency bins. - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding vectors. - stride (`int`, *optional*, defaults to 2): - Stride used to reshape audios from shape (batch_size,num_frames,num_mel_bins) to - (batch_size,num_frames//stride,num_mel_bins*stride). - """ - - model_input_names = ["input_features", "attention_mask"] - - def __init__( - self, - feature_size=80, - sampling_rate=16000, - num_mel_bins=80, - padding_value=0.0, - stride=2, - **kwargs, - ): - self.num_mel_bins = num_mel_bins - self.return_attention_mask = True - self.stride = stride - - mel_filters = mel_filter_bank( - num_frequency_bins=257, - num_mel_filters=self.num_mel_bins, - min_frequency=20, - max_frequency=sampling_rate // 2, - sampling_rate=sampling_rate, - norm=None, - mel_scale="kaldi", - triangularize_in_mel_space=True, - ) - - self.mel_filters = mel_filters - self.window = window_function(400, "povey", periodic=False) - - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - - @staticmethod - # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm - def zero_mean_unit_var_norm( - input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0 - ) -> list[np.ndarray]: - """ - Every array in the list is normalized to have zero mean and unit variance - """ - if attention_mask is not None: - attention_mask = np.array(attention_mask, np.int32) - normed_input_values = [] - - for vector, length in zip(input_values, attention_mask.sum(-1)): - normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) - if length < normed_slice.shape[0]: - normed_slice[length:] = padding_value - - normed_input_values.append(normed_slice) - else: - normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] - - return normed_input_values - - def _extract_fbank_features( - self, - waveform: np.ndarray, - ) -> np.ndarray: - """ - Get mel-filter bank features using TorchAudio. 
Note that TorchAudio requires 16-bit signed integers as inputs - and hence the waveform should not be normalized before feature extraction. - """ - # by default, it extracts the left channel if stereo - if len(waveform.shape) == 2: - waveform = waveform[0] - - waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed integers - features = spectrogram( - waveform, - self.window, - frame_length=400, - hop_length=160, - fft_length=512, - power=2.0, - center=False, - preemphasis=0.97, - mel_filters=self.mel_filters, - log_mel="log", - mel_floor=1.192092955078125e-07, - remove_dc_offset=True, - ).T - return features - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - padding: bool | str | PaddingStrategy = True, - pad_to_multiple_of: int | None = 2, - max_length: int | None = None, - truncation: bool = False, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - return_attention_mask: bool | None = None, - do_normalize_per_mel_bins: bool | None = True, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_speech (`np.ndarray`, `torch.Tensor`, `list[float]`, `list[np.ndarray]`, `list[torch.Tensor]`, - `list[list[float]]`, `list[list[list[float]]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, - a torch tensor, a list of float values, a list of numpy arrays, a list of torch tensors, - a list of list of float values or a list of a list of list of float values. - If `raw_speech` is a one-dimensional `np.ndarray`, `torch.Tensor` or a `list[float]`, `raw_speech` is - considered a single-channel, single-sample sound. In all other cases, the first dimension of - `raw_speech`, whether from an `np.ndarray`, a `torch.Tensor` or a `list[...]`, - corresponds to the number of samples in the batch, and the number of channels - (i.e. mono or stereo character) is derived from the other dimensions - (1D -> single-channel waveform batches; 2D-> stereo-channel waveform batches). - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - pad_to_multiple_of (`int`, *optional*, defaults to 2): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. 
- - [What are attention masks?](../glossary#attention-mask) - - - - For SeamlessM4T models, `attention_mask` should always be passed for batched inference, to avoid subtle - bugs. - - - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - do_normalize_per_mel_bins (`bool`, *optional*, defaults to `True`): - Whether or not to zero-mean unit-variance normalize the input per mel-channel. - kwargs (*optional*): - Remaining dictionary of keyword arguments that will be passed to the tokenizer or the feature - extractor. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - return_attention_mask = ( - return_attention_mask if return_attention_mask is not None else self.return_attention_mask - ) - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - if is_batched_numpy and len(raw_speech.shape) > 3: - raise ValueError(f"Only mono-channel or stereo-channel audio is supported for input to {self}") - - acceptable_types = ( - (torch.Tensor, np.ndarray, tuple, list) if is_torch_available() else (np.ndarray, tuple, list) - ) - is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], acceptable_types)) - ) - - if is_batched: - raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech, dtype=np.float32) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): - raw_speech = raw_speech.astype(np.float32) - - # always return batch - if not is_batched: - raw_speech = [raw_speech] - - # extract fbank features - features = [self._extract_fbank_features(waveform) for waveform in raw_speech] - - if do_normalize_per_mel_bins: - # torch defaults to ddof=1, and numpy defaults to ddof=0 - features = [ - (x - np.expand_dims(x.mean(0), 0)) / np.sqrt(np.expand_dims(x.var(0, ddof=1), 0) + 1e-7) - for x in features - ] - - # convert into correct format for padding - encoded_inputs = BatchFeature({"input_features": features}) - - padded_inputs = self.pad( - encoded_inputs, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=True, - return_tensors="np", - ) - - # SeamlessM4T needs to process extracted features - input_features = padded_inputs.get("input_features") - attention_mask = padded_inputs.pop("attention_mask") - - batch_size, num_frames, num_channels = input_features.shape - - remainder = num_frames % self.stride - if remainder != 0: - input_features = input_features[:, : num_frames - remainder, :] - 
attention_mask = attention_mask[:, : num_frames - remainder] - - input_features = np.reshape( - input_features, (batch_size, num_frames // self.stride, num_channels * self.stride) - ) - - indices = np.arange(0, num_frames - remainder) - attention_mask = attention_mask[:, indices % self.stride == 1] - - padded_inputs["input_features"] = input_features - if return_attention_mask: - padded_inputs["attention_mask"] = attention_mask - - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs +SeamlessM4TFeatureExtractor = deprecated_feature_extractor(SeamlessM4tAudioProcessor, "SeamlessM4TFeatureExtractor") __all__ = ["SeamlessM4TFeatureExtractor"] diff --git a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py index 9685e9be0134..584afc35f229 100644 --- a/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py +++ b/src/transformers/models/speech_to_text/feature_extraction_speech_to_text.py @@ -11,301 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Feature extractor class for Speech2Text -""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_speech_to_text import SpeechToTextAudioProcessor -import numpy as np - -from ...audio_utils import mel_filter_bank, spectrogram, window_function -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, is_speech_available, logging - - -if is_speech_available(): - import torch - import torchaudio.compliance.kaldi as ta_kaldi - -logger = logging.get_logger(__name__) - - -class Speech2TextFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a Speech2Text feature extractor. - - This feature extractor inherits from [`Speech2TextFeatureExtractor`] which contains most of the main methods. Users - should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech using TorchAudio if installed or using numpy - otherwise, and applies utterance-level cepstral mean and variance normalization to the extracted features. - - Args: - feature_size (`int`, *optional*, defaults to 80): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - num_mel_bins (`int`, *optional*, defaults to 80): - Number of Mel-frequency bins. - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding vectors. - dither (`float`, *optional*, defaults to 0.0): - Adds dithering. In other words, adds a small Gaussian noise to each frame. - E.g. use 4.0 to add dithering with a normal distribution centered - around 0.0 with standard deviation 4.0 (assuming [-32k,+32k] range of kaldi waveform). - The value 0.0 means no dithering. - Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank - values for signals with hard-zero sections, when VAD cutoff is present in the signal. 
- do_ceptral_normalize (`bool`, *optional*, defaults to `True`): - Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features. - normalize_means (`bool`, *optional*, defaults to `True`): - Whether or not to zero-mean normalize the extracted features. - normalize_vars (`bool`, *optional*, defaults to `True`): - Whether or not to unit-variance normalize the extracted features. - """ - - model_input_names = ["input_features", "attention_mask"] - - def __init__( - self, - feature_size=80, - sampling_rate=16000, - num_mel_bins=80, - padding_value=0.0, - dither=0.0, - do_ceptral_normalize=True, - normalize_means=True, - normalize_vars=True, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.num_mel_bins = num_mel_bins - self.dither = dither - self.do_ceptral_normalize = do_ceptral_normalize - self.normalize_means = normalize_means - self.normalize_vars = normalize_vars - self.return_attention_mask = True - - if not is_speech_available(): - mel_filters = mel_filter_bank( - num_frequency_bins=257, - num_mel_filters=self.num_mel_bins, - min_frequency=20, - max_frequency=sampling_rate // 2, - sampling_rate=sampling_rate, - norm=None, - mel_scale="kaldi", - triangularize_in_mel_space=True, - ) - - self.mel_filters = mel_filters - self.window = window_function(400, "povey", periodic=False) - - def _extract_fbank_features( - self, - waveform: np.ndarray, - ) -> np.ndarray: - """ - Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs - and hence the waveform should not be normalized before feature extraction. - """ - waveform = waveform * (2**15) # Kaldi compliance: 16-bit signed integers - if is_speech_available(): - waveform = torch.from_numpy(waveform).unsqueeze(0) - features = ta_kaldi.fbank( - waveform, - dither=self.dither, - num_mel_bins=self.num_mel_bins, - sample_frequency=self.sampling_rate, - ) - features = features.numpy() - else: - waveform = np.squeeze(waveform) - features = spectrogram( - waveform, - self.window, - frame_length=400, - hop_length=160, - fft_length=512, - power=2.0, - center=False, - dither=self.dither, - preemphasis=0.97, - mel_filters=self.mel_filters, - log_mel="log", - mel_floor=1.192092955078125e-07, - remove_dc_offset=True, - ).T - return features - - @staticmethod - def utterance_cmvn( - x: np.ndarray, - input_length: int, - normalize_means: bool | None = True, - normalize_vars: bool | None = True, - padding_value: float = 0.0, - ) -> np.ndarray: - # make sure we normalize float32 arrays - if normalize_means: - mean = x[:input_length].mean(axis=0) - x = np.subtract(x, mean) - if normalize_vars: - std = x[:input_length].std(axis=0) - x = np.divide(x, std) - - if input_length < x.shape[0]: - x[input_length:] = padding_value - - # make sure array is in float32 - x = x.astype(np.float32) - - return x - - def normalize( - self, input_features: list[np.ndarray], attention_mask: np.ndarray | None = None - ) -> list[np.ndarray]: - lengths = attention_mask.sum(-1) if attention_mask is not None else [x.shape[0] for x in input_features] - return [ - self.utterance_cmvn(x, n, self.normalize_means, self.normalize_vars, self.padding_value) - for x, n in zip(input_features, lengths) - ] - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - padding: bool | str | PaddingStrategy = False, - max_length: int | None = None, - truncation: bool = False, 
- pad_to_multiple_of: int | None = None, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - return_attention_mask: bool | None = None, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. - - [What are attention masks?](../glossary#attention-mask) - - - - For Speech2TextTransformer models, `attention_mask` should always be passed for batched inference, to - avoid subtle bugs. - - - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values / vectors. - """ - - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." 
- ) - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - if is_batched_numpy and len(raw_speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") - is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech, dtype=np.float32) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): - raw_speech = raw_speech.astype(np.float32) - - # always return batch - if not is_batched: - raw_speech = [raw_speech] - - # extract fbank features - features = [self._extract_fbank_features(waveform) for waveform in raw_speech] - - # convert into correct format for padding - encoded_inputs = BatchFeature({"input_features": features}) - - padded_inputs = self.pad( - encoded_inputs, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - **kwargs, - ) - - # make sure list is in array format - input_features = padded_inputs.get("input_features") - if isinstance(input_features[0], list): - padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] - - attention_mask = padded_inputs.get("attention_mask") - if attention_mask is not None: - padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] - - # Utterance-level cepstral mean and variance normalization - if self.do_ceptral_normalize: - attention_mask = ( - np.array(attention_mask, dtype=np.int32) - if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD - else None - ) - padded_inputs["input_features"] = self.normalize( - padded_inputs["input_features"], attention_mask=attention_mask - ) - - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs +Speech2TextFeatureExtractor = deprecated_feature_extractor(SpeechToTextAudioProcessor, "Speech2TextFeatureExtractor") __all__ = ["Speech2TextFeatureExtractor"] diff --git a/src/transformers/models/speecht5/feature_extraction_speecht5.py b/src/transformers/models/speecht5/feature_extraction_speecht5.py index 5b9ca2e1f954..1aece171a6f3 100644 --- a/src/transformers/models/speecht5/feature_extraction_speecht5.py +++ b/src/transformers/models/speecht5/feature_extraction_speecht5.py @@ -11,364 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Feature extractor class for SpeechT5.""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_speecht5 import SpeechT5AudioProcessor -from typing import Any - -import numpy as np - -from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class SpeechT5FeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a SpeechT5 feature extractor. 
- - This class can pre-process a raw speech signal by (optionally) normalizing to zero-mean unit-variance, for use by - the SpeechT5 speech encoder prenet. - - This class can also extract log-mel filter bank features from raw speech, for use by the SpeechT5 speech decoder - prenet. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values. - do_normalize (`bool`, *optional*, defaults to `False`): - Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly - improve the performance for some models. - num_mel_bins (`int`, *optional*, defaults to 80): - The number of mel-frequency bins in the extracted spectrogram features. - hop_length (`int`, *optional*, defaults to 16): - Number of ms between windows. Otherwise referred to as "shift" in many papers. - win_length (`int`, *optional*, defaults to 64): - Number of ms per window. - win_function (`str`, *optional*, defaults to `"hann_window"`): - Name for the window function used for windowing, must be accessible via `torch.{win_function}` - fmin (`float`, *optional*, defaults to 80): - Minimum mel frequency in Hz. - fmax (`float`, *optional*, defaults to 7600): - Maximum mel frequency in Hz. - mel_floor (`float`, *optional*, defaults to 1e-10): - Minimum value of mel frequency banks.. - return_attention_mask (`bool`, *optional*, defaults to `True`): - Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`. 
- """ - - model_input_names = ["input_values", "attention_mask"] - - def __init__( - self, - feature_size: int = 1, - sampling_rate: int = 16000, - padding_value: float = 0.0, - do_normalize: bool = False, - num_mel_bins: int = 80, - hop_length: int = 16, - win_length: int = 64, - win_function: str = "hann_window", - fmin: float = 80, - fmax: float = 7600, - mel_floor: float = 1e-10, - return_attention_mask: bool = True, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.do_normalize = do_normalize - self.return_attention_mask = return_attention_mask - - self.num_mel_bins = num_mel_bins - self.hop_length = hop_length - self.win_length = win_length - self.win_function = win_function - self.fmin = fmin - self.fmax = fmax - self.mel_floor = mel_floor - - self.sample_size = win_length * sampling_rate // 1000 - self.sample_stride = hop_length * sampling_rate // 1000 - self.n_fft = optimal_fft_length(self.sample_size) - self.n_freqs = (self.n_fft // 2) + 1 - - self.window = window_function(window_length=self.sample_size, name=self.win_function, periodic=True) - - self.mel_filters = mel_filter_bank( - num_frequency_bins=self.n_freqs, - num_mel_filters=self.num_mel_bins, - min_frequency=self.fmin, - max_frequency=self.fmax, - sampling_rate=self.sampling_rate, - norm="slaney", - mel_scale="slaney", - ) - - @staticmethod - # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm - def zero_mean_unit_var_norm( - input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0 - ) -> list[np.ndarray]: - """ - Every array in the list is normalized to have zero mean and unit variance - """ - if attention_mask is not None: - attention_mask = np.array(attention_mask, np.int32) - normed_input_values = [] - - for vector, length in zip(input_values, attention_mask.sum(-1)): - normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) - if length < normed_slice.shape[0]: - normed_slice[length:] = padding_value - - normed_input_values.append(normed_slice) - else: - normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] - - return normed_input_values - - def _extract_mel_features( - self, - one_waveform: np.ndarray, - ) -> np.ndarray: - """ - Extracts log-mel filterbank features for one waveform array (unbatched). - """ - log_mel_spec = spectrogram( - one_waveform, - window=self.window, - frame_length=self.sample_size, - hop_length=self.sample_stride, - fft_length=self.n_fft, - mel_filters=self.mel_filters, - mel_floor=self.mel_floor, - log_mel="log10", - ) - return log_mel_spec.T - - def __call__( - self, - audio: np.ndarray | list[float] | list[np.ndarray] | list[list[float]] | None = None, - audio_target: np.ndarray | list[float] | list[np.ndarray] | list[list[float]] | None = None, - padding: bool | str | PaddingStrategy = False, - max_length: int | None = None, - truncation: bool = False, - pad_to_multiple_of: int | None = None, - return_attention_mask: bool | None = None, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Pass in a value for `audio` to extract waveform features. Pass in a value for `audio_target` to extract log-mel - spectrogram features. 
- - Args: - audio (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`, *optional*): - The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. This outputs waveform features. Must - be mono channel audio, not stereo, i.e. single float per timestep. - audio_target (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`, *optional*): - The sequence or batch of sequences to be processed as targets. Each sequence can be a numpy array, a - list of float values, a list of numpy arrays or a list of list of float values. This outputs log-mel - spectrogram features. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. - - [What are attention masks?](../glossary#attention-mask) - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `audio` or `audio_target` input was sampled. It is strongly recommended - to pass `sampling_rate` at the forward call to prevent silent errors. - """ - if audio is None and audio_target is None: - raise ValueError("You must provide either `audio` or `audio_target` values.") - - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." 
- ) - - if audio is not None: - inputs = self._process_audio( - audio, - False, - padding, - max_length, - truncation, - pad_to_multiple_of, - return_attention_mask, - return_tensors, - **kwargs, - ) - else: - inputs = None - - if audio_target is not None: - inputs_target = self._process_audio( - audio_target, - True, - padding, - max_length, - truncation, - pad_to_multiple_of, - return_attention_mask, - return_tensors, - **kwargs, - ) - - if inputs is None: - return inputs_target - else: - inputs["labels"] = inputs_target["input_values"] - decoder_attention_mask = inputs_target.get("attention_mask") - if decoder_attention_mask is not None: - inputs["decoder_attention_mask"] = decoder_attention_mask - - return inputs - - def _process_audio( - self, - speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - is_target: bool = False, - padding: bool | str | PaddingStrategy = False, - max_length: int | None = None, - truncation: bool = False, - pad_to_multiple_of: int | None = None, - return_attention_mask: bool | None = None, - return_tensors: str | TensorType | None = None, - **kwargs, - ) -> BatchFeature: - is_batched_numpy = isinstance(speech, np.ndarray) and len(speech.shape) > 1 - if is_batched_numpy and len(speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") - is_batched = is_batched_numpy or ( - isinstance(speech, (list, tuple)) and (isinstance(speech[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - speech = [np.asarray(speech, dtype=np.float32) for speech in speech] - elif not is_batched and not isinstance(speech, np.ndarray): - speech = np.asarray(speech, dtype=np.float32) - elif isinstance(speech, np.ndarray) and speech.dtype is np.dtype(np.float64): - speech = speech.astype(np.float32) - - # always return batch - if not is_batched: - speech = [speech] - - # needed to make pad() work on spectrogram inputs - feature_size_hack = self.feature_size - - # convert into correct format for padding - if is_target: - features = [self._extract_mel_features(waveform) for waveform in speech] - encoded_inputs = BatchFeature({"input_values": features}) - self.feature_size = self.num_mel_bins - else: - encoded_inputs = BatchFeature({"input_values": speech}) - - padded_inputs = self.pad( - encoded_inputs, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - **kwargs, - ) - - self.feature_size = feature_size_hack - - # convert input values to correct format - input_values = padded_inputs["input_values"] - if not isinstance(input_values[0], np.ndarray): - padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values] - elif ( - not isinstance(input_values, np.ndarray) - and isinstance(input_values[0], np.ndarray) - and input_values[0].dtype is np.dtype(np.float64) - ): - padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values] - elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64): - padded_inputs["input_values"] = input_values.astype(np.float32) - - # convert attention_mask to correct format - attention_mask = padded_inputs.get("attention_mask") - if attention_mask is not None: - padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] - - # zero-mean and unit-variance normalization - if not is_target and self.do_normalize: - attention_mask = ( - attention_mask - if 
self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD - else None - ) - padded_inputs["input_values"] = self.zero_mean_unit_var_norm( - padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value - ) - - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs - - def to_dict(self) -> dict[str, Any]: - output = super().to_dict() - - # Don't serialize these as they are derived from the other properties. - names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"] - for name in names: - if name in output: - del output[name] - - return output +SpeechT5FeatureExtractor = deprecated_feature_extractor(SpeechT5AudioProcessor, "SpeechT5FeatureExtractor") __all__ = ["SpeechT5FeatureExtractor"] diff --git a/src/transformers/models/univnet/feature_extraction_univnet.py b/src/transformers/models/univnet/feature_extraction_univnet.py index 84e9420a0f75..73ae758ee708 100644 --- a/src/transformers/models/univnet/feature_extraction_univnet.py +++ b/src/transformers/models/univnet/feature_extraction_univnet.py @@ -11,448 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Feature extractor class for UnivNetModel.""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_univnet import UnivNetAudioProcessor -from typing import Any - -import numpy as np - -from ...audio_utils import mel_filter_bank, optimal_fft_length, spectrogram, window_function -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class UnivNetFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a UnivNet feature extractor. - - This class extracts log-mel-filter bank features from raw speech using the short time Fourier Transform (STFT). The - STFT implementation follows that of TacoTron 2 and Hifi-GAN. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 24000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - The value to pad with when applying the padding strategy defined by the `padding` argument to - [`UnivNetFeatureExtractor.__call__`]. Should correspond to audio silence. The `pad_end` argument to - `__call__` will also use this padding value. - do_normalize (`bool`, *optional*, defaults to `False`): - Whether to perform Tacotron 2 normalization on the input. Normalizing can help to significantly improve the - performance for some models. - num_mel_bins (`int`, *optional*, defaults to 100): - The number of mel-frequency bins in the extracted spectrogram features. This should match - `UnivNetModel.config.num_mel_bins`. - hop_length (`int`, *optional*, defaults to 256): - The direct number of samples between sliding windows. Otherwise referred to as "shift" in many papers. 
Note - that this is different from other audio feature extractors such as [`SpeechT5FeatureExtractor`] which take - the `hop_length` in ms. - win_length (`int`, *optional*, defaults to 1024): - The direct number of samples for each sliding window. Note that this is different from other audio feature - extractors such as [`SpeechT5FeatureExtractor`] which take the `win_length` in ms. - win_function (`str`, *optional*, defaults to `"hann_window"`): - Name for the window function used for windowing, must be accessible via `torch.{win_function}` - filter_length (`int`, *optional*, defaults to 1024): - The number of FFT components to use. If `None`, this is determined using - `transformers.audio_utils.optimal_fft_length`. - max_length_s (`int`, *optional*, defaults to 10): - The maximum input length of the model in seconds. This is used to pad the audio. - fmin (`float`, *optional*, defaults to 0.0): - Minimum mel frequency in Hz. - fmax (`float`, *optional*): - Maximum mel frequency in Hz. If not set, defaults to `sampling_rate / 2`. - mel_floor (`float`, *optional*, defaults to 1e-09): - Minimum value of mel frequency banks. Note that the way [`UnivNetFeatureExtractor`] uses `mel_floor` is - different than in [`transformers.audio_utils.spectrogram`]. - center (`bool`, *optional*, defaults to `False`): - Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame - `t` will start at time `t * hop_length`. - compression_factor (`float`, *optional*, defaults to 1.0): - The multiplicative compression factor for dynamic range compression during spectral normalization. - compression_clip_val (`float`, *optional*, defaults to 1e-05): - The clip value applied to the waveform before applying dynamic range compression during spectral - normalization. - normalize_min (`float`, *optional*, defaults to -11.512925148010254): - The min value used for Tacotron 2-style linear normalization. The default is the original value from the - Tacotron 2 implementation. - normalize_max (`float`, *optional*, defaults to 2.3143386840820312): - The max value used for Tacotron 2-style linear normalization. The default is the original value from the - Tacotron 2 implementation. - model_in_channels (`int`, *optional*, defaults to 64): - The number of input channels to the [`UnivNetModel`] model. This should match - `UnivNetModel.config.model_in_channels`. - pad_end_length (`int`, *optional*, defaults to 10): - If padding the end of each waveform, the number of spectrogram frames worth of samples to append. The - number of appended samples will be `pad_end_length * hop_length`. - return_attention_mask (`bool`, *optional*, defaults to `True`): - Whether or not [`~UnivNetFeatureExtractor.__call__`] should return `attention_mask`. 
- """ - - model_input_names = ["input_features", "noise_sequence", "padding_mask"] - - def __init__( - self, - feature_size: int = 1, - sampling_rate: int = 24000, - padding_value: float = 0.0, - do_normalize: bool = False, - num_mel_bins: int = 100, - hop_length: int = 256, - win_length: int = 1024, - win_function: str = "hann_window", - filter_length: int | None = 1024, - max_length_s: int = 10, - fmin: float = 0.0, - fmax: float | None = None, - mel_floor: float = 1e-9, - center: bool = False, - compression_factor: float = 1.0, - compression_clip_val: float = 1e-5, - normalize_min: float = -11.512925148010254, - normalize_max: float = 2.3143386840820312, - model_in_channels: int = 64, - pad_end_length: int = 10, - return_attention_mask=True, - **kwargs, - ): - super().__init__( - feature_size=feature_size, - sampling_rate=sampling_rate, - padding_value=padding_value, - return_attention_mask=return_attention_mask, - **kwargs, - ) - - self.do_normalize = do_normalize - - self.num_mel_bins = num_mel_bins - self.hop_length = hop_length - self.win_length = win_length - self.win_function = win_function - self.filter_length = filter_length - self.fmin = fmin - if fmax is None: - # Follows the librosa.filters.mel implementation - fmax = float(sampling_rate) / 2 - self.fmax = fmax - self.mel_floor = mel_floor - - self.max_length_s = max_length_s - self.num_max_samples = max_length_s * sampling_rate - - if self.filter_length is None: - self.n_fft = optimal_fft_length(self.win_length) - else: - self.n_fft = self.filter_length - self.n_freqs = (self.n_fft // 2) + 1 - - self.window = window_function(window_length=self.win_length, name=self.win_function, periodic=True) - - self.mel_filters = mel_filter_bank( - num_frequency_bins=self.n_freqs, - num_mel_filters=self.num_mel_bins, - min_frequency=self.fmin, - max_frequency=self.fmax, - sampling_rate=self.sampling_rate, - norm="slaney", - mel_scale="slaney", - ) - - self.center = center - self.compression_factor = compression_factor - self.compression_clip_val = compression_clip_val - self.normalize_min = normalize_min - self.normalize_max = normalize_max - self.model_in_channels = model_in_channels - self.pad_end_length = pad_end_length - - def normalize(self, spectrogram): - return 2 * ((spectrogram - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1 - - def denormalize(self, spectrogram): - return self.normalize_min + (self.normalize_max - self.normalize_min) * ((spectrogram + 1) / 2) - - def mel_spectrogram(self, waveform: np.ndarray) -> np.ndarray: - """ - Calculates log MEL spectrograms from a batch of waveforms. Note that the input waveform(s) will be padded by - `int(self.n_fft - self.hop_length) / 2` on both sides using the `reflect` padding mode. - - Args: - waveform (`np.ndarray` of shape `(length,)`): - The input waveform. This must be a single real-valued, mono waveform. - - Returns: - `numpy.ndarray`: Array containing a log-mel spectrogram of shape `(num_frames, num_mel_bins)`. - """ - # Do custom padding based on the official MelGAN and Hifi-GAN implementations - # See https://github.com/maum-ai/univnet/blob/9bb2b54838bb6d7ce767131cc7b8b61198bc7558/utils/stft.py#L84-L86 - waveform = np.pad( - waveform, - (int((self.n_fft - self.hop_length) / 2), int((self.n_fft - self.hop_length) / 2)), - mode="reflect", - ) - - # Get the complex spectrogram. - # Note: waveform must be unbatched currently due to the implementation of spectrogram(...). 
- complex_spectrogram = spectrogram( - waveform, - window=self.window, - frame_length=self.n_fft, - hop_length=self.hop_length, - fft_length=self.n_fft, - power=None, - center=self.center, - mel_filters=None, - mel_floor=None, - ) - - # Apply the MEL filter bank and MEL floor manually since UnivNet uses a slightly different implementation - amplitude_spectrogram = np.sqrt( - np.real(complex_spectrogram) ** 2 + np.imag(complex_spectrogram) ** 2 + self.mel_floor - ) - mel_spectrogram = np.matmul(self.mel_filters.T, amplitude_spectrogram) - - # Perform spectral normalization to get the log mel spectrogram. - log_mel_spectrogram = np.log( - np.clip(mel_spectrogram, a_min=self.compression_clip_val, a_max=None) * self.compression_factor - ) - - # Return spectrogram with num_mel_bins last - return log_mel_spectrogram.T - - def generate_noise( - self, - noise_length: int, - generator: np.random.Generator | None = None, - ) -> np.ndarray: - """ - Generates a random noise sequence of standard Gaussian noise for use in the `noise_sequence` argument of - [`UnivNetModel.forward`]. - - Args: - spectrogram_length (`int`): - The length (dim 0) of the generated noise. - model_in_channels (`int`, *optional*, defaults to `None`): - The number of features (dim 1) of the generated noise. This should correspond to the - `model_in_channels` of the [`UnivNetGan`] model. If not set, this will default to - `self.config.model_in_channels`. - generator (`numpy.random.Generator`, *optional*, defaults to `None`) - An optional `numpy.random.Generator` random number generator to control noise generation. If not set, a - new generator with fresh entropy will be created. - - Returns: - `numpy.ndarray`: Array containing random standard Gaussian noise of shape `(noise_length, - model_in_channels)`. - """ - if generator is None: - generator = np.random.default_rng() - - noise_shape = (noise_length, self.model_in_channels) - noise = generator.standard_normal(noise_shape, dtype=np.float32) - - return noise - - def batch_decode(self, waveforms, waveform_lengths=None) -> list[np.ndarray]: - r""" - Removes padding from generated audio after running [`UnivNetModel.forward`]. This returns a ragged list of 1D - audio waveform arrays and not a single tensor/array because in general the waveforms will have different - lengths after removing padding. - - Args: - waveforms (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): - The batched output waveforms from the [`UnivNetModel`]. - waveform_lengths (`torch.FloatTensor` of shape `(batch_size,)`, *optional*): - The batched lengths of each waveform before padding. - - Returns: - `list[np.ndarray]`: A ragged list of 1D waveform arrays with padding removed. 
- """ - # Collapse the batched waveform tensor to a list of 1D audio waveforms - waveforms = [waveform.detach().to(device="cpu", copy=True).numpy() for waveform in waveforms] - - if waveform_lengths is not None: - waveforms = [waveform[: waveform_lengths[i]] for i, waveform in enumerate(waveforms)] - - return waveforms - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - sampling_rate: int | None = None, - padding: bool | str | PaddingStrategy = True, - max_length: int | None = None, - truncation: bool = True, - pad_to_multiple_of: int | None = None, - return_noise: bool = True, - generator: np.random.Generator | None = None, - pad_end: bool = False, - pad_length: int | None = None, - do_normalize: str | None = None, - return_attention_mask: bool | None = None, - return_tensors: str | TensorType | None = None, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition - pipeline. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the input `raw_speech` waveforms (according to the model's padding side and - padding index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - - If `pad_end = True`, that padding will occur before the `padding` strategy is applied. - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`, *optional*, defaults to `True`): - Activates truncation to cut input sequences longer than `max_length` to `max_length`. - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - return_noise (`bool`, *optional*, defaults to `True`): - Whether to generate and return a noise waveform for use in [`UnivNetModel.forward`]. - generator (`numpy.random.Generator`, *optional*, defaults to `None`): - An optional `numpy.random.Generator` random number generator to use when generating noise. - pad_end (`bool`, *optional*, defaults to `False`): - Whether to pad the end of each waveform with silence. This can help reduce artifacts at the end of the - generated audio sample; see https://github.com/seungwonpark/melgan/issues/8 for more details. This - padding will be done before the padding strategy specified in `padding` is performed. 
- pad_length (`int`, *optional*, defaults to `None`): - If padding the end of each waveform, the length of the padding in spectrogram frames. If not set, this - will default to `self.config.pad_end_length`. - do_normalize (`bool`, *optional*): - Whether to perform Tacotron 2 normalization on the input. Normalizing can help to significantly improve - the performance for some models. If not set, this will default to `self.config.do_normalize`. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. - - [What are attention masks?](../glossary#attention-mask) - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.np.array` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - """ - do_normalize = do_normalize if do_normalize is not None else self.do_normalize - - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" - f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" - f" was sampled with {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - if is_batched_numpy and len(raw_speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") - is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_speech = [np.asarray(speech, dtype=np.float32) for speech in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech, dtype=np.float32) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): - raw_speech = raw_speech.astype(np.float32) - - # always return batch - if not is_batched: - raw_speech = [np.asarray(raw_speech, dtype=np.float32)] - - # Pad end to reduce artifacts - if pad_end: - pad_length = pad_length if pad_length is not None else self.pad_end_length - raw_speech = [ - np.pad(waveform, (0, pad_length * self.hop_length), constant_values=self.padding_value) - for waveform in raw_speech - ] - - batched_speech = BatchFeature({"input_features": raw_speech}) - - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length if max_length is not None else self.num_max_samples, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - ) - - # make sure list is in array format - # input_features = padded_inputs.get("input_features").transpose(2, 0, 1) - input_features = padded_inputs.get("input_features") - - mel_spectrograms = [self.mel_spectrogram(waveform) for waveform in input_features] - - if isinstance(input_features[0], list): - batched_speech["input_features"] = [np.asarray(mel, dtype=np.float32) for mel in mel_spectrograms] - else: - batched_speech["input_features"] = [mel.astype(np.float32) for mel in mel_spectrograms] - - # convert 
attention_mask to correct format - attention_mask = padded_inputs.get("attention_mask") - if attention_mask is not None: - batched_speech["padding_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] - - if return_noise: - noise = [ - self.generate_noise(spectrogram.shape[0], generator) - for spectrogram in batched_speech["input_features"] - ] - batched_speech["noise_sequence"] = noise - - if do_normalize: - batched_speech["input_features"] = [ - self.normalize(spectrogram) for spectrogram in batched_speech["input_features"] - ] - - if return_tensors is not None: - batched_speech = batched_speech.convert_to_tensors(return_tensors) - - return batched_speech - - def to_dict(self) -> dict[str, Any]: - output = super().to_dict() - - # Don't serialize these as they are derived from the other properties. - names = ["window", "mel_filters", "n_fft", "n_freqs", "num_max_samples"] - for name in names: - if name in output: - del output[name] - - return output +UnivNetFeatureExtractor = deprecated_feature_extractor(UnivNetAudioProcessor, "UnivNetFeatureExtractor") __all__ = ["UnivNetFeatureExtractor"] diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py index c42db2bc0d74..0f38cd9df814 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/feature_extraction_vibevoice_acoustic_tokenizer.py @@ -11,134 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_vibevoice_acoustic_tokenizer import VibevoiceAcousticTokenizerAudioProcessor -from ...audio_utils import AudioInput, make_list_of_audio -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, logging -from ...utils.import_utils import is_torch_available, requires - - -if is_torch_available(): - import torch - -logger = logging.get_logger(__name__) - - -@requires(backends=("torch",)) -class VibeVoiceAcousticTokenizerFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a VibeVoiceAcousticTokenizer feature extractor. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The number of channels. - sampling_rate (`int`, *optional*, defaults to 24000): - The sampling rate at which the audio waveform should be digitalized, expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used for padding. - normalize_audio (`bool`, *optional*, defaults to `True`): - Whether to normalize audio to a target dB FS. - target_dB_FS (`float`, *optional*, defaults to -25): - Target dB FS for normalization. - eps (`float`, *optional*, defaults to 1e-06): - A small value to avoid division by zero when normalizing. 
- - """ - - model_input_names = ["input_values", "padding_mask"] - - def __init__( - self, - feature_size=1, - sampling_rate=24000, - padding_value=0.0, - normalize_audio=True, - target_dB_FS=-25, - eps=1e-6, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - - self.normalize_audio = normalize_audio - self.target_dB_FS = target_dB_FS - self.eps = eps - - def __call__( - self, - audio: AudioInput, - sampling_rate: int | None = None, - padding: bool | str | PaddingStrategy | None = True, - pad_to_multiple_of: int | None = None, - return_attention_mask: bool | None = True, - ) -> BatchFeature: - """ - Args: - audio (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`: - The sequence or batch of sequences to be processed. Each sequence can be a numpy array, a torch tensor, - a list of numpy arrays or a list of torch tensors. - sampling_rate (`int`, *optional*): - The sampling rate at which the `audio` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. - - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided audio input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." 
- ) - - # Ensure batch of mono tensors - audio = make_list_of_audio(audio) - for idx, example in enumerate(audio): - example = torch.tensor(example, dtype=torch.float32) - if example.ndim != 1: - raise ValueError(f"Audio should be mono, got shape: {example.shape}") - audio[idx] = example - - if self.normalize_audio: - for idx, example in enumerate(audio): - rms = torch.sqrt(torch.mean(example**2)) - example *= 10 ** (self.target_dB_FS / 20) / (rms + self.eps) - max_val = torch.max(torch.abs(example)) - if max_val > 1.0: - example = example / (max_val + self.eps) - audio[idx] = example - - output_values = BatchFeature({"input_values": audio}) - if padding or pad_to_multiple_of: - output_values = self.pad( - output_values, - padding=padding, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - ) - if return_attention_mask: - output_values["padding_mask"] = output_values.pop("attention_mask") - - # add channel dimension - # output_values["input_values"] = output_values["input_values"][:, None, :] - - return output_values +VibeVoiceAcousticTokenizerFeatureExtractor = deprecated_feature_extractor( + VibevoiceAcousticTokenizerAudioProcessor, "VibeVoiceAcousticTokenizerFeatureExtractor" +) __all__ = ["VibeVoiceAcousticTokenizerFeatureExtractor"] diff --git a/src/transformers/models/voxtral_realtime/feature_extraction_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/feature_extraction_voxtral_realtime.py index 58355f3c0d7c..09e49995be51 100644 --- a/src/transformers/models/voxtral_realtime/feature_extraction_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/feature_extraction_voxtral_realtime.py @@ -11,236 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_voxtral_realtime import VoxtralRealtimeAudioProcessor -import numpy as np -import torch - -from ...audio_utils import mel_filter_bank -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, logging -from ...utils.import_utils import requires - - -logger = logging.get_logger(__name__) - - -@requires(backends=("torch",)) -class VoxtralRealtimeFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a VOXTRAL_REALTIME feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time - Fourier Transform` which should match pytorch's `torch.stft` equivalent. - - Args: - feature_size (`int`, *optional*, defaults to 128): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - hop_length (`int`, *optional*, defaults to 160): - Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. - n_fft (`int`, *optional*, defaults to 512): - Size of the Fourier transform. - win_length (`int`, *optional*, defaults to 400): - The window length for the STFT computation. 
- padding_value (`float`, *optional*, defaults to 0.0): - Padding value used to pad the audio. Should correspond to silences. - """ - - model_input_names = ["input_features", "attention_mask"] - - def __init__( - self, - feature_size=128, - sampling_rate=16000, - hop_length=160, - n_fft=400, - win_length=400, - padding_value=0.0, - global_log_mel_max=1.5, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - - self.hop_length = hop_length - self.n_fft = n_fft - self.win_length = win_length - self.mel_filters = mel_filter_bank( - num_frequency_bins=1 + n_fft // 2, - num_mel_filters=feature_size, - min_frequency=0.0, - max_frequency=8000.0, - sampling_rate=sampling_rate, - norm="slaney", - mel_scale="slaney", - ) - self.global_log_mel_max = global_log_mel_max - - def _torch_extract_fbank_features(self, waveform, device: str = "cpu", center: bool = True): - window = torch.hann_window(self.n_fft, device=device) - stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True, center=center) - magnitudes = stft[..., :-1].abs() ** 2 - - mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) - mel_spec = mel_filters.T @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - if self.global_log_mel_max is not None: - log_spec_max = torch.tensor( - self.global_log_mel_max, - device=log_spec.device, - dtype=log_spec.dtype, - ) - else: - log_spec_max = log_spec.max() - - log_spec = torch.maximum(log_spec, log_spec_max - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - if device != "cpu": - log_spec = log_spec.detach().cpu() - return log_spec - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - truncation: bool = False, - pad_to_multiple_of: int | None = None, - return_tensors: str | TensorType | None = None, - return_attention_mask: bool | None = None, - padding: str | None = "longest", - max_length: int | None = None, - sampling_rate: int | None = None, - do_normalize: bool | None = None, - device: str | None = "cpu", - return_token_timestamps: bool | None = None, - center: bool = True, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for - the STFT computation if available, otherwise a slower NumPy based one. - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - truncation (`bool`, *optional*, default to `True`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*, defaults to None): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. 
- - [What are attention masks?](../glossary#attention-mask) - - - - For Parakeet models, `attention_mask` should always be passed for batched inference, to avoid subtle - bugs. - - - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'tf'`: Return TensorFlow `tf.constant` objects. - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition - pipeline. - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values / vectors. - do_normalize (`bool`, *optional*, defaults to `False`): - Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly - improve the performance of the model. - device (`str`, *optional*, defaults to `'cpu'`): - Specifies the device for computation of the log-mel spectrogram of audio signals in the - `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda") - return_token_timestamps (`bool`, *optional*, defaults to `None`): - Deprecated. Use `return_attention_mask` instead from which the number of frames can be inferred. - - Whether or not to return the number of frames of the input raw_speech. - These num_frames can be used by the model to compute word level timestamps. - center (`bool`, *optional*, defaults to `True`): - Whether to use centering for the STFT computation. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" - f" sampling rate of {self.sampling_rate}. Please make sure that the provided `raw_speech` input" - f" was sampled with {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - # Convert to torch tensor - if isinstance(raw_speech, np.ndarray): - raw_speech = torch.tensor(raw_speech) - elif isinstance(raw_speech, (list, tuple)): - if isinstance(raw_speech[0], (list, np.ndarray)): - raw_speech = [torch.tensor(speech) for speech in raw_speech] - else: # list[float] - raw_speech = torch.tensor(raw_speech) - - is_batched_torch = isinstance(raw_speech, torch.Tensor) and len(raw_speech.shape) > 1 - if is_batched_torch and len(raw_speech.shape) > 2: - logger.warning( - f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " - "We will take the mean of the channels to convert to mono." - ) - raw_speech = raw_speech.mean(-1) - - is_batched_sequence = isinstance(raw_speech, (list, tuple)) - if is_batched_sequence: - for speech in raw_speech: - if len(speech.shape) > 1: - logger.warning( - f"Only mono-channel audio is supported for input to {self.__class__.__name__}. " - "We will take the mean of the channels to convert to mono." 
- ) - speech = speech.mean(-1) - - if is_batched_torch or is_batched_sequence: - raw_speech = [speech[:, None].to(torch.float32) for speech in raw_speech] - else: - raw_speech = [raw_speech[:, None].to(torch.float32)] - - batched_speech = BatchFeature({"input_features": raw_speech}) - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - return_tensors="pt", - ) - input_features = padded_inputs.input_features.squeeze(-1) - input_features = self._torch_extract_fbank_features(input_features, device, center) - data = { - "input_features": input_features.to(torch.float32), - } - - if return_attention_mask: - attention_mask = padded_inputs.attention_mask[:, self.win_length - 1 :: self.hop_length] - data["attention_mask"] = attention_mask.to(torch.bool) - - return BatchFeature(data=data, tensor_type=return_tensors) +VoxtralRealtimeFeatureExtractor = deprecated_feature_extractor( + VoxtralRealtimeAudioProcessor, "VoxtralRealtimeFeatureExtractor" +) __all__ = ["VoxtralRealtimeFeatureExtractor"] diff --git a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py index dea2f3af5b48..bc4c8fdee07e 100644 --- a/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py +++ b/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py @@ -11,229 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Feature extractor class for Wav2Vec2 -""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_wav2vec2 import Wav2Vec2AudioProcessor -import numpy as np - -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import PaddingStrategy, TensorType, logging - - -logger = logging.get_logger(__name__) - - -class Wav2Vec2FeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a Wav2Vec2 feature extractor. - - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - Args: - feature_size (`int`, *optional*, defaults to 1): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - padding_value (`float`, *optional*, defaults to 0.0): - The value that is used to fill the padding values. - do_normalize (`bool`, *optional*, defaults to `True`): - Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly - improve the performance for some models, *e.g.*, - [wav2vec2-lv60](https://huggingface.co/models?search=lv60). - return_attention_mask (`bool`, *optional*, defaults to `False`): - Whether or not [`~Wav2Vec2FeatureExtractor.__call__`] should return `attention_mask`. - - - - Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as - [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using - `attention_mask`. 
For such models, `input_values` should simply be padded with 0 and no `attention_mask` - should be passed. - - For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as - [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should be - passed for batched inference. - - """ - - model_input_names = ["input_values", "attention_mask"] - - def __init__( - self, - feature_size=1, - sampling_rate=16000, - padding_value=0.0, - return_attention_mask=False, - do_normalize=True, - **kwargs, - ): - super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs) - self.return_attention_mask = return_attention_mask - self.do_normalize = do_normalize - - @staticmethod - def zero_mean_unit_var_norm( - input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0 - ) -> list[np.ndarray]: - """ - Every array in the list is normalized to have zero mean and unit variance - """ - if attention_mask is not None: - attention_mask = np.array(attention_mask, np.int32) - normed_input_values = [] - - for vector, length in zip(input_values, attention_mask.sum(-1)): - normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) - if length < normed_slice.shape[0]: - normed_slice[length:] = padding_value - - normed_input_values.append(normed_slice) - else: - normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] - - return normed_input_values - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - padding: bool | str | PaddingStrategy = False, - max_length: int | None = None, - truncation: bool = False, - pad_to_multiple_of: int | None = None, - return_attention_mask: bool | None = None, - return_tensors: str | TensorType | None = None, - sampling_rate: int | None = None, - **kwargs, - ) -> BatchFeature: - """ - Main method to featurize and prepare for the model one or several sequence(s). - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): - Select a strategy to pad the returned sequences (according to the model's padding side and padding - index) among: - - - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single - sequence if provided). - - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum - acceptable input length for the model if that argument is not provided. - - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different - lengths). - max_length (`int`, *optional*): - Maximum length of the returned list and optionally padding length (see above). - truncation (`bool`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. 
- return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. - - [What are attention masks?](../glossary#attention-mask) - - - - Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as - [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using - `attention_mask`. For such models, `input_values` should simply be padded with 0 and no - `attention_mask` should be passed. - - For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as - [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should - be passed for batched inference. - - - - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors. - padding_value (`float`, *optional*, defaults to 0.0): - """ - - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of" - f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with" - f" {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." 
- ) - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - if is_batched_numpy and len(raw_speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") - is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) - ) - - # always return batch - if not is_batched: - raw_speech = [raw_speech] - - # convert into correct format for padding - encoded_inputs = BatchFeature({"input_values": raw_speech}) - - padded_inputs = self.pad( - encoded_inputs, - padding=padding, - max_length=max_length, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask, - ) - - # convert input values to correct format - input_values = padded_inputs["input_values"] - if not isinstance(input_values[0], np.ndarray): - padded_inputs["input_values"] = [np.asarray(array, dtype=np.float32) for array in input_values] - elif ( - not isinstance(input_values, np.ndarray) - and isinstance(input_values[0], np.ndarray) - and input_values[0].dtype is np.dtype(np.float64) - ): - padded_inputs["input_values"] = [array.astype(np.float32) for array in input_values] - elif isinstance(input_values, np.ndarray) and input_values.dtype is np.dtype(np.float64): - padded_inputs["input_values"] = input_values.astype(np.float32) - - # convert attention_mask to correct format - attention_mask = padded_inputs.get("attention_mask") - if attention_mask is not None: - padded_inputs["attention_mask"] = [np.asarray(array, dtype=np.int32) for array in attention_mask] - - # zero-mean and unit-variance normalization - if self.do_normalize: - attention_mask = ( - attention_mask - if self._get_padding_strategies(padding, max_length=max_length) is not PaddingStrategy.DO_NOT_PAD - else None - ) - padded_inputs["input_values"] = self.zero_mean_unit_var_norm( - padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value - ) - - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs +Wav2Vec2FeatureExtractor = deprecated_feature_extractor(Wav2Vec2AudioProcessor, "Wav2Vec2FeatureExtractor") __all__ = ["Wav2Vec2FeatureExtractor"] diff --git a/src/transformers/models/whisper/feature_extraction_whisper.py b/src/transformers/models/whisper/feature_extraction_whisper.py index 4151a3824dfd..4e4f49df3c2d 100644 --- a/src/transformers/models/whisper/feature_extraction_whisper.py +++ b/src/transformers/models/whisper/feature_extraction_whisper.py @@ -11,335 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Feature extractor class for Whisper -""" +from ...utils.deprecation import deprecated_feature_extractor +from .audio_processing_whisper import WhisperAudioProcessor -import numpy as np - -from ... import is_torch_available -from ...audio_utils import mel_filter_bank, spectrogram, window_function -from ...feature_extraction_sequence_utils import SequenceFeatureExtractor -from ...feature_extraction_utils import BatchFeature -from ...utils import TensorType, logging - - -if is_torch_available(): - import torch - -logger = logging.get_logger(__name__) - - -class WhisperFeatureExtractor(SequenceFeatureExtractor): - r""" - Constructs a Whisper feature extractor. 
- - This feature extractor inherits from [`~feature_extraction_sequence_utils.SequenceFeatureExtractor`] which contains - most of the main methods. Users should refer to this superclass for more information regarding those methods. - - This class extracts mel-filter bank features from raw speech using a custom numpy implementation of the `Short Time - Fourier Transform` which should match pytorch's `torch.stft` equivalent. - - Args: - feature_size (`int`, *optional*, defaults to 80): - The feature dimension of the extracted features. - sampling_rate (`int`, *optional*, defaults to 16000): - The sampling rate at which the audio files should be digitalized expressed in hertz (Hz). - hop_length (`int`, *optional*, defaults to 160): - Length of the overlapping windows for the STFT used to obtain the Mel Frequency coefficients. - chunk_length (`int`, *optional*, defaults to 30): - The maximum number of chunks of `sampling_rate` samples used to trim and pad longer or shorter audio - sequences. - n_fft (`int`, *optional*, defaults to 400): - Size of the Fourier transform. - padding_value (`float`, *optional*, defaults to 0.0): - Padding value used to pad the audio. Should correspond to silences. - dither (`float`, *optional*, defaults to 0.0): - Adds dithering. In other words, adds a small Gaussian noise to each frame. - E.g. use 0.0001 to add dithering with a normal distribution centered - around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range of raw_speech). - The value 0.0 means no dithering. - Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces - the high log_mel_fbank values for signals with hard-zero sections, - when VAD cutoff is present in the signal. - """ - - model_input_names = ["input_features"] - - def __init__( - self, - feature_size=80, - sampling_rate=16000, - hop_length=160, - chunk_length=30, - n_fft=400, - padding_value=0.0, - dither=0.0, - return_attention_mask=False, # pad inputs to max length with silence token (zero) and no attention mask - **kwargs, - ): - super().__init__( - feature_size=feature_size, - sampling_rate=sampling_rate, - padding_value=padding_value, - return_attention_mask=return_attention_mask, - **kwargs, - ) - self.n_fft = n_fft - self.hop_length = hop_length - self.chunk_length = chunk_length - self.n_samples = chunk_length * sampling_rate - self.nb_max_frames = self.n_samples // hop_length - self.sampling_rate = sampling_rate - self.dither = dither - self.mel_filters = mel_filter_bank( - num_frequency_bins=1 + n_fft // 2, - num_mel_filters=feature_size, - min_frequency=0.0, - max_frequency=8000.0, - sampling_rate=sampling_rate, - norm="slaney", - mel_scale="slaney", - ) - - def _np_extract_fbank_features(self, waveform_batch: np.ndarray, device: str) -> np.ndarray: - """ - Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch - implementation with 1e-5 tolerance. - """ - if device != "cpu": - raise ValueError( - f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator " - "devices requires torch, which is not installed. 
Either set `device='cpu'`, or " - "install torch according to the official instructions: https://pytorch.org/get-started/locally/" - ) - log_spec_batch = [] - for waveform in waveform_batch: - log_spec = spectrogram( - waveform, - window_function(self.n_fft, "hann"), - frame_length=self.n_fft, - hop_length=self.hop_length, - power=2.0, - dither=self.dither, - mel_filters=self.mel_filters, - log_mel="log10", - ) - log_spec = log_spec[:, :-1] - log_spec = np.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - log_spec_batch.append(log_spec) - log_spec_batch = np.array(log_spec_batch) - return log_spec_batch - - def _torch_extract_fbank_features(self, waveform: np.ndarray, device: str = "cpu") -> np.ndarray: - """ - Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching, - yielding results similar to cpu computing with 1e-5 tolerance. - """ - waveform = torch.from_numpy(waveform).to(device, torch.float32) - window = torch.hann_window(self.n_fft, device=device) - - # Note: it would be better to dither the chunked waveform, - # so overlapping signal does not get the same dithering. - # But, chunking is happening inside pytorch, so it is here. - if self.dither != 0.0: - waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device) - - stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True) - magnitudes = stft[..., :-1].abs() ** 2 - - mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32) - mel_spec = mel_filters.T @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - if waveform.dim() == 2: - max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0] - log_spec = torch.maximum(log_spec, max_val - 8.0) - else: - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - if device != "cpu": - log_spec = log_spec.detach().cpu() - return log_spec.numpy() - - @staticmethod - # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm - def zero_mean_unit_var_norm( - input_values: list[np.ndarray], attention_mask: list[np.ndarray], padding_value: float = 0.0 - ) -> list[np.ndarray]: - """ - Every array in the list is normalized to have zero mean and unit variance - """ - if attention_mask is not None: - attention_mask = np.array(attention_mask, np.int32) - normed_input_values = [] - - for vector, length in zip(input_values, attention_mask.sum(-1)): - normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7) - if length < normed_slice.shape[0]: - normed_slice[length:] = padding_value - - normed_input_values.append(normed_slice) - else: - normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values] - - return normed_input_values - - def __call__( - self, - raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], - truncation: bool = True, - pad_to_multiple_of: int | None = None, - return_tensors: str | TensorType | None = None, - return_attention_mask: bool | None = None, - padding: str | None = "max_length", - max_length: int | None = None, - sampling_rate: int | None = None, - do_normalize: bool | None = None, - device: str | None = "cpu", - **kwargs, - ) -> BatchFeature: - """Main method to featurize and prepare for the model one or several sequence(s). 
Implementation uses PyTorch - for the STFT computation if available, otherwise a slower NumPy based one. - - Args: - raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`): - The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float - values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not - stereo, i.e. single float per timestep. - truncation (`bool`, *optional*, default to `True`): - Activates truncation to cut input sequences longer than *max_length* to *max_length*. - pad_to_multiple_of (`int`, *optional*, defaults to None): - If set will pad the sequence to a multiple of the provided value. - - This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability - `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. - return_tensors (`str` or [`~utils.TensorType`], *optional*): - If set, will return tensors instead of list of python integers. Acceptable values are: - - - `'pt'`: Return PyTorch `torch.Tensor` objects. - - `'np'`: Return Numpy `np.ndarray` objects. - return_attention_mask (`bool`, *optional*): - Whether to return the attention mask. If left to the default, will return the attention mask according - to the specific feature_extractor's default. - - [What are attention masks?](../glossary#attention-mask) - - - - For Whisper models, `attention_mask` should always be passed for batched inference, to avoid subtle - bugs. - - - padding (`str` or [`~utils.PaddingStrategy`], *optional*, defaults to `'max_length'`): - Activates and controls padding. Accepts the following values: - - - `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence is - provided). - - `'max_length'` (default): Pad to a maximum length specified with the argument `max_length` or to the - maximum acceptable input length for the model if that argument is not provided. - - `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths). - max_length (`int`, *optional*): - Controls the maximum length to use by one of the truncation/padding parameters. - - If left unset or set to `None`, this will use the predefined model maximum length if a maximum length - is required by one of the truncation/padding parameters. If the model has no specific maximum input - length (like XLNet) truncation/padding to a maximum length will be deactivated. - sampling_rate (`int`, *optional*): - The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass - `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition - pipeline. - do_normalize (`bool`, *optional*, defaults to `False`): - Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly - improve the performance of the model. - device (`str`, *optional*, defaults to `'cpu'`): - Specifies the device for computation of the log-mel spectrogram of audio signals in the - `_torch_extract_fbank_features` method. (e.g., "cpu", "cuda") - **kwargs: Not supported by WhisperFeatureExtractor.__call__() and ignored. - """ - if sampling_rate is not None: - if sampling_rate != self.sampling_rate: - raise ValueError( - f"The model corresponding to this feature extractor: {self.__class__.__name__} was trained using a" - f" sampling rate of {self.sampling_rate}. 
Please make sure that the provided `raw_speech` input" - f" was sampled with {self.sampling_rate} and not {sampling_rate}." - ) - else: - logger.warning( - f"It is strongly recommended to pass the `sampling_rate` argument to `{self.__class__.__name__}()`. " - "Failing to do so can result in silent errors that might be hard to debug." - ) - - is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1 - if is_batched_numpy and len(raw_speech.shape) > 2: - raise ValueError(f"Only mono-channel audio is supported for input to {self}") - is_batched = is_batched_numpy or ( - isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list))) - ) - - if is_batched: - raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech] - elif not is_batched and not isinstance(raw_speech, np.ndarray): - raw_speech = np.asarray(raw_speech, dtype=np.float32) - elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64): - raw_speech = raw_speech.astype(np.float32) - - # always return batch - if not is_batched: - raw_speech = [np.asarray([raw_speech]).T] - - batched_speech = BatchFeature({"input_features": raw_speech}) - - # convert into correct format for padding - - padded_inputs = self.pad( - batched_speech, - padding=padding, - max_length=max_length if max_length else self.n_samples, - truncation=truncation, - pad_to_multiple_of=pad_to_multiple_of, - return_attention_mask=return_attention_mask or do_normalize, - ) - - # zero-mean and unit-variance normalization - if do_normalize: - padded_inputs["input_features"] = self.zero_mean_unit_var_norm( - padded_inputs["input_features"], - attention_mask=padded_inputs["attention_mask"], - padding_value=self.padding_value, - ) - padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0) - - # make sure list is in array format - input_features = padded_inputs.get("input_features").transpose(2, 0, 1) - - extract_fbank_features = ( - self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features - ) - input_features = extract_fbank_features(input_features[0], device) - - if isinstance(input_features[0], list): - padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features] - - else: - padded_inputs["input_features"] = input_features - - if return_attention_mask: - # rescale from sample (48000) to feature (3000) - rescaled_attention_mask = padded_inputs["attention_mask"][:, :: self.hop_length] - - # The STFT computation produces L//hop_length + 1 frames, but we skip the last frame (see `_torch_extract_fbank_features`). - # This means we need to trim the rescaled attention mask to match the actual number of frames (L//hop_length) when the input length - # is not perfectly divisible by the hop length. 
- if padded_inputs["attention_mask"].shape[1] % self.hop_length != 0: - rescaled_attention_mask = rescaled_attention_mask[:, :-1] - padded_inputs["attention_mask"] = rescaled_attention_mask - - if return_tensors is not None: - padded_inputs = padded_inputs.convert_to_tensors(return_tensors) - - return padded_inputs +WhisperFeatureExtractor = deprecated_feature_extractor(WhisperAudioProcessor, "WhisperFeatureExtractor") __all__ = ["WhisperFeatureExtractor"] diff --git a/src/transformers/utils/deprecation.py b/src/transformers/utils/deprecation.py index db0e67325d78..98af20e5df77 100644 --- a/src/transformers/utils/deprecation.py +++ b/src/transformers/utils/deprecation.py @@ -33,6 +33,41 @@ class Action(ExplicitEnum): RAISE = "raise" +def deprecated_feature_extractor(audio_processor_class, old_class_name, version="4.55"): + """Create a deprecated FeatureExtractor alias for an AudioProcessor. + + Uses dynamic class creation to reduce boilerplate across ~20 models. + """ + + def __init__(self, *args, **kwargs): + warnings.warn( + f"`{old_class_name}` is deprecated and will be removed in v{version}. " + f"Use `{audio_processor_class.__name__}` instead.", + FutureWarning, + stacklevel=2, + ) + super(type(self), self).__init__(*args, **kwargs) + + def __init_subclass__(cls, **kwargs): + warnings.warn( + f"`{old_class_name}` is deprecated and will be removed in v{version}. " + f"Use `{audio_processor_class.__name__}` instead.", + FutureWarning, + ) + super(type(cls), cls).__init_subclass__(**kwargs) + + return type( + old_class_name, + (audio_processor_class,), + { + "__init__": __init__, + "__init_subclass__": __init_subclass__, + "__module__": audio_processor_class.__module__, + "__doc__": f"Deprecated. Use {audio_processor_class.__name__} instead.", + }, + ) + + def deprecate_kwarg( old_name: str, version: str, diff --git a/tests/test_audio_processors_vs_feature_extractors.py b/tests/test_audio_processors_vs_feature_extractors.py index 5a0e4595a8a6..0d5d4fdc472f 100644 --- a/tests/test_audio_processors_vs_feature_extractors.py +++ b/tests/test_audio_processors_vs_feature_extractors.py @@ -22,66 +22,137 @@ 4. Assert torch.equal on the main output tensors """ +import importlib +import os +import sys + import numpy as np import pytest import torch +# --------------------------------------------------------------------------- +# Feature extractor classes are loaded from ~/transformers/src (upstream). +# We temporarily swap sys.path and clear cached transformers modules so that +# ``import transformers.models.X.feature_extraction_X`` resolves to the +# upstream checkout rather than the locally-installed (audio-processors) copy. 
+# --------------------------------------------------------------------------- +_UPSTREAM_SRC = os.path.expanduser("~/transformers/src") + +_fe_class_specs = [ + ("transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer", "ASTFeatureExtractor"), + ("transformers.models.clap.feature_extraction_clap", "ClapFeatureExtractor"), + ("transformers.models.clvp.feature_extraction_clvp", "ClvpFeatureExtractor"), + ("transformers.models.dac.feature_extraction_dac", "DacFeatureExtractor"), + ("transformers.models.dia.feature_extraction_dia", "DiaFeatureExtractor"), + ("transformers.models.encodec.feature_extraction_encodec", "EncodecFeatureExtractor"), + ("transformers.models.granite_speech.feature_extraction_granite_speech", "GraniteSpeechFeatureExtractor"), + ("transformers.models.kyutai_speech_to_text.feature_extraction_kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"), + ("transformers.models.lasr.feature_extraction_lasr", "LasrFeatureExtractor"), + ("transformers.models.musicgen_melody.feature_extraction_musicgen_melody", "MusicgenMelodyFeatureExtractor"), + ("transformers.models.parakeet.feature_extraction_parakeet", "ParakeetFeatureExtractor"), + ("transformers.models.phi4_multimodal.feature_extraction_phi4_multimodal", "Phi4MultimodalFeatureExtractor"), + ("transformers.models.pop2piano.feature_extraction_pop2piano", "Pop2PianoFeatureExtractor"), + ("transformers.models.seamless_m4t.feature_extraction_seamless_m4t", "SeamlessM4TFeatureExtractor"), + ("transformers.models.speech_to_text.feature_extraction_speech_to_text", "Speech2TextFeatureExtractor"), + ("transformers.models.speecht5.feature_extraction_speecht5", "SpeechT5FeatureExtractor"), + ("transformers.models.univnet.feature_extraction_univnet", "UnivNetFeatureExtractor"), + ("transformers.models.vibevoice_acoustic_tokenizer.feature_extraction_vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerFeatureExtractor"), + ("transformers.models.voxtral_realtime.feature_extraction_voxtral_realtime", "VoxtralRealtimeFeatureExtractor"), + ("transformers.models.wav2vec2.feature_extraction_wav2vec2", "Wav2Vec2FeatureExtractor"), + ("transformers.models.whisper.feature_extraction_whisper", "WhisperFeatureExtractor"), +] + + +def _load_upstream_classes(class_specs): + """Load feature extractor classes from ~/transformers/src. + + Temporarily replaces the transformers package in sys.modules so that + imports resolve to the upstream checkout. + """ + # 1. Save and remove all cached transformers modules + saved_modules = {} + for key in list(sys.modules.keys()): + if key == "transformers" or key.startswith("transformers."): + saved_modules[key] = sys.modules.pop(key) + + # 2. Prepend upstream src to sys.path + sys.path.insert(0, _UPSTREAM_SRC) + + results = {} + try: + for module_path, class_name in class_specs: + mod = importlib.import_module(module_path) + results[class_name] = getattr(mod, class_name) + finally: + # 3. Remove upstream from sys.path + sys.path.remove(_UPSTREAM_SRC) + # 4. Clear all upstream-loaded transformers modules + for key in list(sys.modules.keys()): + if key == "transformers" or key.startswith("transformers."): + del sys.modules[key] + # 5. 
Restore the local project's transformers modules + sys.modules.update(saved_modules) + + return results + + +def _load_upstream_class(module_path, class_name): + """Load a single class from ~/transformers/src.""" + return _load_upstream_classes([(module_path, class_name)])[class_name] + + +# Load all FE classes from upstream in one batch +_fe_classes = _load_upstream_classes(_fe_class_specs) +ASTFeatureExtractor = _fe_classes["ASTFeatureExtractor"] +ClapFeatureExtractor = _fe_classes["ClapFeatureExtractor"] +ClvpFeatureExtractor = _fe_classes["ClvpFeatureExtractor"] +DacFeatureExtractor = _fe_classes["DacFeatureExtractor"] +DiaFeatureExtractor = _fe_classes["DiaFeatureExtractor"] +EncodecFeatureExtractor = _fe_classes["EncodecFeatureExtractor"] +GraniteSpeechFeatureExtractor = _fe_classes["GraniteSpeechFeatureExtractor"] +KyutaiSpeechToTextFeatureExtractor = _fe_classes["KyutaiSpeechToTextFeatureExtractor"] +LasrFeatureExtractor = _fe_classes["LasrFeatureExtractor"] +MusicgenMelodyFeatureExtractor = _fe_classes["MusicgenMelodyFeatureExtractor"] +ParakeetFeatureExtractor = _fe_classes["ParakeetFeatureExtractor"] +Phi4MultimodalFeatureExtractor = _fe_classes["Phi4MultimodalFeatureExtractor"] +Pop2PianoFeatureExtractor = _fe_classes["Pop2PianoFeatureExtractor"] +SeamlessM4TFeatureExtractor = _fe_classes["SeamlessM4TFeatureExtractor"] +Speech2TextFeatureExtractor = _fe_classes["Speech2TextFeatureExtractor"] +SpeechT5FeatureExtractor = _fe_classes["SpeechT5FeatureExtractor"] +UnivNetFeatureExtractor = _fe_classes["UnivNetFeatureExtractor"] +VibeVoiceAcousticTokenizerFeatureExtractor = _fe_classes["VibeVoiceAcousticTokenizerFeatureExtractor"] +VoxtralRealtimeFeatureExtractor = _fe_classes["VoxtralRealtimeFeatureExtractor"] +Wav2Vec2FeatureExtractor = _fe_classes["Wav2Vec2FeatureExtractor"] +WhisperFeatureExtractor = _fe_classes["WhisperFeatureExtractor"] + +# Audio processor imports (from local project) from transformers.models.audio_spectrogram_transformer.audio_processing_audio_spectrogram_transformer import ( AudioSpectrogramTransformerAudioProcessor, ) -from transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer import ( - ASTFeatureExtractor, -) from transformers.models.clap.audio_processing_clap import ClapAudioProcessor -from transformers.models.clap.feature_extraction_clap import ClapFeatureExtractor from transformers.models.clvp.audio_processing_clvp import ClvpAudioProcessor -from transformers.models.clvp.feature_extraction_clvp import ClvpFeatureExtractor from transformers.models.dac.audio_processing_dac import DacAudioProcessor -from transformers.models.dac.feature_extraction_dac import DacFeatureExtractor from transformers.models.dia.audio_processing_dia import DiaAudioProcessor -from transformers.models.dia.feature_extraction_dia import DiaFeatureExtractor from transformers.models.encodec.audio_processing_encodec import EncodecAudioProcessor -from transformers.models.encodec.feature_extraction_encodec import EncodecFeatureExtractor -from transformers.models.gemma3n.audio_processing_gemma3n import Gemma3nAudioProcessor -from transformers.models.gemma3n.feature_extraction_gemma3n import Gemma3nAudioFeatureExtractor from transformers.models.granite_speech.audio_processing_granite_speech import GraniteSpeechAudioProcessor -from transformers.models.granite_speech.feature_extraction_granite_speech import GraniteSpeechFeatureExtractor from transformers.models.kyutai_speech_to_text.audio_processing_kyutai_speech_to_text import ( 
KyutaiSpeechToTextAudioProcessor, ) -from transformers.models.kyutai_speech_to_text.feature_extraction_kyutai_speech_to_text import ( - KyutaiSpeechToTextFeatureExtractor, -) from transformers.models.lasr.audio_processing_lasr import LasrAudioProcessor -from transformers.models.lasr.feature_extraction_lasr import LasrFeatureExtractor from transformers.models.musicgen_melody.audio_processing_musicgen_melody import MusicgenMelodyAudioProcessor -from transformers.models.musicgen_melody.feature_extraction_musicgen_melody import MusicgenMelodyFeatureExtractor from transformers.models.parakeet.audio_processing_parakeet import ParakeetAudioProcessor -from transformers.models.parakeet.feature_extraction_parakeet import ParakeetFeatureExtractor from transformers.models.phi4_multimodal.audio_processing_phi4_multimodal import Phi4MultimodalAudioProcessor -from transformers.models.phi4_multimodal.feature_extraction_phi4_multimodal import Phi4MultimodalFeatureExtractor from transformers.models.pop2piano.audio_processing_pop2piano import Pop2PianoAudioProcessor -from transformers.models.pop2piano.feature_extraction_pop2piano import Pop2PianoFeatureExtractor from transformers.models.seamless_m4t.audio_processing_seamless_m4t import SeamlessM4tAudioProcessor -from transformers.models.seamless_m4t.feature_extraction_seamless_m4t import SeamlessM4TFeatureExtractor from transformers.models.speech_to_text.audio_processing_speech_to_text import SpeechToTextAudioProcessor -from transformers.models.speech_to_text.feature_extraction_speech_to_text import Speech2TextFeatureExtractor from transformers.models.speecht5.audio_processing_speecht5 import SpeechT5AudioProcessor -from transformers.models.speecht5.feature_extraction_speecht5 import SpeechT5FeatureExtractor from transformers.models.univnet.audio_processing_univnet import UnivNetAudioProcessor -from transformers.models.univnet.feature_extraction_univnet import UnivNetFeatureExtractor from transformers.models.vibevoice_acoustic_tokenizer.audio_processing_vibevoice_acoustic_tokenizer import ( VibevoiceAcousticTokenizerAudioProcessor, ) -from transformers.models.vibevoice_acoustic_tokenizer.feature_extraction_vibevoice_acoustic_tokenizer import ( - VibeVoiceAcousticTokenizerFeatureExtractor, -) from transformers.models.voxtral_realtime.audio_processing_voxtral_realtime import VoxtralRealtimeAudioProcessor -from transformers.models.voxtral_realtime.feature_extraction_voxtral_realtime import VoxtralRealtimeFeatureExtractor from transformers.models.wav2vec2.audio_processing_wav2vec2 import Wav2Vec2AudioProcessor -from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor from transformers.models.whisper.audio_processing_whisper import WhisperAudioProcessor -from transformers.models.whisper.feature_extraction_whisper import WhisperFeatureExtractor # Sentinel to exclude a key from default kwargs @@ -99,6 +170,7 @@ "ap_class": AudioSpectrogramTransformerAudioProcessor, "fe_output_key": "input_values", "sample_rate": 16000, + "atol": 1e-6, }, { "name": "clap", @@ -347,10 +419,29 @@ def test_audio_processor_matches_feature_extractor(config): "audio_input_features": "audio_features", } + # Mapping for attention mask and padding mask keys depending on the primary input key + mask_key_map = { + "input_values": "audio_values_mask", + "input_features": "audio_features_mask", + } + + # Find out if this output contains input_values or input_features (to key mask mapping) + has_input_values = "input_values" in fe_output + has_input_features = 
"input_features" in fe_output + for fe_key in fe_output.keys(): - if fe_key == "attention_mask" or fe_key == "padding_mask" or fe_key == "input_features_mask": - continue + # Remap the primary data keys ap_key = fe_to_ap_key_map.get(fe_key, fe_key) + + # Special handling for attention_mask and padding_mask mapping + if fe_key in ("attention_mask", "padding_mask"): + if has_input_values: + ap_key = mask_key_map["input_values"] + elif has_input_features: + ap_key = mask_key_map["input_features"] + else: + ap_key = fe_key # fallback/default + assert ap_key in ap_output, f"Key {ap_key} (from FE key {fe_key}) not found in audio processor output" fe_tensor = fe_output[fe_key] ap_tensor = ap_output[ap_key] @@ -363,6 +454,130 @@ def test_audio_processor_matches_feature_extractor(config): assert fe_tensor.shape == ap_tensor.shape, ( f"Shape mismatch for key '{fe_key}' (ap key '{ap_key}'): fe {fe_tensor.shape} vs ap {ap_tensor.shape}" ) - assert torch.equal(fe_tensor, ap_tensor), ( - f"Value mismatch for key '{fe_key}' (ap key '{ap_key}'): max abs diff = {(fe_tensor - ap_tensor).abs().max().item():.6e}" - ) + atol = config.get("atol", 0.0) + if atol > 0: + assert torch.allclose(fe_tensor, ap_tensor, atol=atol, rtol=0), ( + f"Value mismatch for key '{fe_key}' (ap key '{ap_key}'): max abs diff = {(fe_tensor - ap_tensor).abs().max().item():.6e}, atol={atol}" + ) + else: + assert torch.equal(fe_tensor, ap_tensor), ( + f"Value mismatch for key '{fe_key}' (ap key '{ap_key}'): max abs diff = {(fe_tensor - ap_tensor).abs().max().item():.6e}" + ) + + +# --------------------------------------------------------------------------- +# Backward compatibility tests +# --------------------------------------------------------------------------- + +# Pairs of (fe_module_path, fe_class_name, ap_class) +_COMPAT_PAIRS = [ + ("transformers.models.whisper.feature_extraction_whisper", "WhisperFeatureExtractor", WhisperAudioProcessor), + ("transformers.models.clap.feature_extraction_clap", "ClapFeatureExtractor", ClapAudioProcessor), + ("transformers.models.encodec.feature_extraction_encodec", "EncodecFeatureExtractor", EncodecAudioProcessor), + ("transformers.models.dac.feature_extraction_dac", "DacFeatureExtractor", DacAudioProcessor), + ("transformers.models.wav2vec2.feature_extraction_wav2vec2", "Wav2Vec2FeatureExtractor", Wav2Vec2AudioProcessor), +] + + +@pytest.mark.parametrize( + "module_path, class_name, ap_class", + _COMPAT_PAIRS, + ids=[p[1] for p in _COMPAT_PAIRS], +) +class TestFeatureExtractorBackwardCompat: + """Tests that deprecated FeatureExtractor wrappers work correctly.""" + + def test_importable_and_warns(self, module_path, class_name, ap_class): + """Old class names are importable and emit FutureWarning.""" + import importlib + + mod = importlib.import_module(module_path) + fe_cls = getattr(mod, class_name) + with pytest.warns(FutureWarning, match="deprecated"): + fe_cls() + + def test_isinstance_check(self, module_path, class_name, ap_class): + """Deprecated FE instances pass isinstance checks against AudioProcessor.""" + import importlib + import warnings + + mod = importlib.import_module(module_path) + fe_cls = getattr(mod, class_name) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + fe = fe_cls() + assert isinstance(fe, ap_class) + assert issubclass(fe_cls, ap_class) + + def test_issubclass(self, module_path, class_name, ap_class): + """Deprecated FE class is a subclass of AudioProcessor.""" + import importlib + + mod = importlib.import_module(module_path) + 
fe_cls = getattr(mod, class_name) + assert issubclass(fe_cls, ap_class) + + +class TestBatchFeatureLegacyKeys: + """Tests that old output key names are accessible via BatchFeature.""" + + def setup_method(self): + from transformers.audio_processing_base import BatchFeature as AudioBatchFeature + + # Reset warned keys so each test gets fresh warnings + AudioBatchFeature._warned_keys.clear() + + def test_input_features_resolves_to_audio_features(self): + from transformers.audio_processing_base import BatchFeature as AudioBatchFeature + + bf = AudioBatchFeature({"audio_features": np.array([1, 2, 3])}) + with pytest.warns(FutureWarning, match="input_features"): + result = bf["input_features"] + assert np.array_equal(result, np.array([1, 2, 3])) + + def test_input_values_resolves_to_audio_values(self): + from transformers.audio_processing_base import BatchFeature as AudioBatchFeature + + bf = AudioBatchFeature({"audio_values": np.array([4, 5, 6])}) + with pytest.warns(FutureWarning, match="input_values"): + result = bf["input_values"] + assert np.array_equal(result, np.array([4, 5, 6])) + + def test_attention_mask_resolves_to_audio_features_mask(self): + from transformers.audio_processing_base import BatchFeature as AudioBatchFeature + + bf = AudioBatchFeature({"audio_features": np.array([1]), "audio_features_mask": np.array([1, 1, 0])}) + with pytest.warns(FutureWarning, match="attention_mask"): + result = bf["attention_mask"] + assert np.array_equal(result, np.array([1, 1, 0])) + + def test_attention_mask_resolves_to_audio_values_mask(self): + from transformers.audio_processing_base import BatchFeature as AudioBatchFeature + + bf = AudioBatchFeature({"audio_values": np.array([1]), "audio_values_mask": np.array([0, 1, 1])}) + with pytest.warns(FutureWarning, match="attention_mask"): + result = bf["attention_mask"] + assert np.array_equal(result, np.array([0, 1, 1])) + + def test_contains_legacy_key(self): + from transformers.audio_processing_base import BatchFeature as AudioBatchFeature + + bf = AudioBatchFeature({"audio_features": np.array([1])}) + assert "input_features" in bf + assert "audio_features" in bf + assert "nonexistent_key" not in bf + + def test_warning_fires_once(self): + from transformers.audio_processing_base import BatchFeature as AudioBatchFeature + + bf = AudioBatchFeature({"audio_features": np.array([1, 2, 3])}) + with pytest.warns(FutureWarning, match="input_features"): + bf["input_features"] + # Second access should not warn + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + bf["input_features"] + future_warnings = [x for x in w if issubclass(x.category, FutureWarning) and "input_features" in str(x.message)] + assert len(future_warnings) == 0 From 0e1653324721a32665076ea30f054d8c1d26dbfc Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 17:20:22 +0100 Subject: [PATCH 14/28] ensure BC + deprecate --- src/transformers/audio_processing_base.py | 41 ++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/transformers/audio_processing_base.py b/src/transformers/audio_processing_base.py index c8fd46a98bef..6fba8be02082 100644 --- a/src/transformers/audio_processing_base.py +++ b/src/transformers/audio_processing_base.py @@ -13,7 +13,8 @@ # limitations under the License. 
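# A rough usage sketch of the legacy-key resolution defined below (array shapes
# here are illustrative only):
#
#     import numpy as np
#     from transformers.audio_processing_base import BatchFeature
#
#     feats = BatchFeature({"audio_features": np.zeros((1, 80, 3000)), "audio_features_mask": np.ones((1, 3000))})
#     feats["input_features"]    # resolves to feats["audio_features"], emitting a FutureWarning on first access
#     "attention_mask" in feats  # True, resolved to "audio_features_mask"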
import os -from typing import Any, TypeVar +import warnings +from typing import Any, ClassVar, TypeVar from .audio_utils import is_valid_audio, load_audio from .feature_extraction_utils import BatchFeature as BaseBatchFeature @@ -25,6 +26,13 @@ ) +_LEGACY_KEY_MAP = { + "input_features": "audio_features", + "input_values": "audio_values", + "audio_input_features": "audio_features", +} + + AudioProcessorType = TypeVar("AudioProcessorType", bound="AudioProcessingMixin") @@ -45,6 +53,37 @@ class BatchFeature(BaseBatchFeature): initialization. """ + _warned_keys: ClassVar[set] = set() + + def __getitem__(self, item): + if isinstance(item, str) and item not in self.data: + new_key = self._resolve_legacy_key(item) + if new_key is not None and new_key in self.data: + if item not in BatchFeature._warned_keys: + warnings.warn( + f"Accessing '{item}' is deprecated, use '{new_key}' instead.", + FutureWarning, + stacklevel=2, + ) + BatchFeature._warned_keys.add(item) + return self.data[new_key] + return super().__getitem__(item) + + def __contains__(self, item): + if item in self.data: + return True + new_key = self._resolve_legacy_key(item) + return new_key is not None and new_key in self.data + + def _resolve_legacy_key(self, old_key): + if old_key in ("attention_mask", "padding_mask"): + if "audio_features_mask" in self.data: + return "audio_features_mask" + if "audio_values_mask" in self.data: + return "audio_values_mask" + return None + return _LEGACY_KEY_MAP.get(old_key) + class AudioProcessingMixin(PreprocessingMixin): """ From 1d095f30087fd6a5a0728d4d1dd097b02f0c50a5 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 17:20:50 +0100 Subject: [PATCH 15/28] update audio processors --- ...rocessing_audio_spectrogram_transformer.py | 54 +++++++---------- .../models/clap/audio_processing_clap.py | 17 +----- .../models/clvp/audio_processing_clvp.py | 59 +++++-------------- .../models/dac/audio_processing_dac.py | 17 +++--- .../models/dia/audio_processing_dia.py | 19 +++--- .../encodec/audio_processing_encodec.py | 17 +++--- .../audio_processing_granite_speech.py | 2 +- .../audio_processing_kyutai_speech_to_text.py | 36 ++++++----- .../models/lasr/audio_processing_lasr.py | 5 ++ .../audio_processing_musicgen_melody.py | 2 +- .../audio_processing_phi4_multimodal.py | 41 +++++++++---- .../audio_processing_seamless_m4t.py | 43 ++++++-------- .../audio_processing_speech_to_text.py | 16 +++-- .../univnet/audio_processing_univnet.py | 23 ++++---- ...processing_vibevoice_acoustic_tokenizer.py | 6 ++ .../audio_processing_voxtral_realtime.py | 13 +++- .../whisper/audio_processing_whisper.py | 14 ++--- 17 files changed, 178 insertions(+), 206 deletions(-) diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py index c7c58ba22743..c7dce5ad462d 100644 --- a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py @@ -17,19 +17,13 @@ from ...audio_processing_backends import NumpyAudioBackend, TorchAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature -from ...utils import is_speech_available, is_torch_available +from ...utils import is_torch_available -if is_speech_available(): - import 
torchaudio.compliance.kaldi as ta_kaldi - -if is_torch_available(): - import torch - - -class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend if not is_speech_available() else TorchAudioBackend): +class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend if not is_torch_available() else TorchAudioBackend): sample_rate = 16000 force_mono = True + return_attention_mask = False max_length_frames = 1024 do_normalize = True @@ -61,29 +55,19 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend if not is_spee mel_floor=1.192092955078125e-07, ) - def _extract_fbank_features_torchaudio(self, waveform) -> np.ndarray: - """Extract mel-filter bank features using torchaudio Kaldi (matches ASTFeatureExtractor).""" - if isinstance(waveform, np.ndarray): - waveform = torch.from_numpy(waveform) - waveform = waveform.unsqueeze(0) - fbank = ta_kaldi.fbank( - waveform, - sample_frequency=self.sample_rate, - window_type="hanning", - num_mel_bins=128, - ) - return fbank.numpy() + def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): + if isinstance(audio, np.ndarray) and audio.ndim > 1: + audio = [audio[i] for i in range(audio.shape[0])] + elif hasattr(audio, 'dim') and audio.dim() > 1: + audio = [audio[i] for i in range(audio.shape[0])] + elif not isinstance(audio, list): + audio = [audio] - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Compute spectrogram per-sample using the same method as ASTFeatureExtractor - if is_speech_available(): - # Use torchaudio Kaldi for exact match with ASTFeatureExtractor - features = [self._extract_fbank_features_torchaudio(waveform) for waveform in audio] - else: - # Use numpy spectrogram (fallback) - features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) - # (n_mels, frames) -> (frames, n_mels) - features = [f.T for f in features] + if spectrogram_config is None: + spectrogram_config = self.spectrogram_config + features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) + # (n_mels, frames) -> (frames, n_mels) + features = [f.T for f in features] # Pad or truncate to max_length_frames padded = [] @@ -100,8 +84,12 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of if self.do_normalize: padded = [(f - self.ast_mean) / (self.ast_std * 2) for f in padded] - stacked = np.stack(padded, axis=0) - return BatchFeature({"audio_values": stacked}, tensor_type=return_tensors) + return np.stack(padded, axis=0) + + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # AST does all processing at the feature level in extract_spectrogram + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + return BatchFeature({"audio_values": features}, tensor_type=return_tensors) __all__ = ["AudioSpectrogramTransformerAudioProcessor"] diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py index 7922773499ab..4358af841b2d 100644 --- a/src/transformers/models/clap/audio_processing_clap.py +++ b/src/transformers/models/clap/audio_processing_clap.py @@ -15,7 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank, spectrogram, window_function +from ...audio_utils import MelScaleConfig, SpectrogramConfig, 
StftConfig, spectrogram, window_function from ...feature_extraction_utils import BatchFeature @@ -89,19 +89,6 @@ def _pad_single_clap(self, audio: np.ndarray, max_length: int, padding_mode: str # For other modes, use standard padding via parent's _pad_single return super()._pad_single(audio, max_length) - def _mel_filter_bank(self, spectrogram_config): - stft_cfg = spectrogram_config.stft_config - mel_cfg = spectrogram_config.mel_scale_config - return mel_filter_bank( - num_frequency_bins=(stft_cfg.n_fft // 2) + 1, - num_mel_filters=mel_cfg.n_mels, - min_frequency=mel_cfg.f_min, - max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, - sampling_rate=self.sample_rate, - norm=mel_cfg.norm, - mel_scale=mel_cfg.mel_scale, - ) - def _extract_single_mel(self, waveform, spectrogram_config=None): """Extract mel spectrogram for a single waveform using audio_utils.spectrogram.""" if spectrogram_config is None: @@ -229,7 +216,7 @@ def _preprocess( padding_strategy = PaddingStrategy.LONGEST # Default to longest for unknown string values else: padding_strategy = padding_mode - audio = self.pad(audio, padding_strategy, nb_max_samples, truncation=False, pad_to_multiple_of=pad_to_multiple_of) + audio, _audio_ranges = self.pad(audio, padding_strategy, nb_max_samples, truncation=False, pad_to_multiple_of=pad_to_multiple_of) # Process each waveform through CLAP's mel extraction (handles truncation internally) padded_inputs = [ diff --git a/src/transformers/models/clvp/audio_processing_clvp.py b/src/transformers/models/clvp/audio_processing_clvp.py index 624607bff742..bb2d134fa13f 100644 --- a/src/transformers/models/clvp/audio_processing_clvp.py +++ b/src/transformers/models/clvp/audio_processing_clvp.py @@ -15,17 +15,14 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank, spectrogram, window_function -from ...feature_extraction_utils import BatchFeature +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function class ClvpAudioProcessor(NumpyAudioBackend): sample_rate = 22050 force_mono = True - n_fft = 1024 - hop_length = 256 - n_mels = 80 max_length = 132300 # 6 seconds at 22050 Hz + truncation = True spectrogram_config = SpectrogramConfig( stft_config=StftConfig( @@ -40,6 +37,7 @@ class ClvpAudioProcessor(NumpyAudioBackend): f_max=8000.0, norm="slaney", mel_scale="htk", + frequency_bin_mode="linspace", ), log_mode="log", mel_floor=1e-5, @@ -49,24 +47,13 @@ def __init__(self, mel_norms=None, **kwargs): super().__init__(**kwargs) self.mel_norms = mel_norms - def _mel_filter_bank(self, spectrogram_config): - mel_cfg = spectrogram_config.mel_scale_config - stft_cfg = spectrogram_config.stft_config - return mel_filter_bank( - num_frequency_bins=1 + (stft_cfg.n_fft // 2), - num_mel_filters=mel_cfg.n_mels, - min_frequency=mel_cfg.f_min, - max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else 8000.0, - sampling_rate=self.sample_rate, - norm=mel_cfg.norm, - mel_scale=mel_cfg.mel_scale, - ) - def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): if spectrogram_config is None: spectrogram_config = self.spectrogram_config - if not isinstance(audio, list): + if isinstance(audio, np.ndarray) and audio.ndim > 1: + audio = [audio[i] for i in range(audio.shape[0])] + elif not isinstance(audio, list): audio = [audio] stft_cfg = spectrogram_config.stft_config @@ -87,34 +74,16 @@ def 
extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): if self.mel_norms is not None: log_spec = log_spec / np.array(self.mel_norms)[:, None] - features.append(log_spec) - - return features - - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Determine the raw-audio target length - if max_length is None: - max_length = self.max_length - - # Truncate to max_length first - audio = [a[..., :max_length] for a in audio] - - # Pad raw audio: if padding=True, pad to longest in batch; otherwise pad to max_length - if padding is True or padding == "longest": - pad_length = max(a.shape[-1] for a in audio) - else: - pad_length = max_length - audio = self.pad(audio, padding=True, max_length=pad_length) - - # Extract spectrogram via audio_utils (with mel_norms applied) - features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) + features.append(log_spec.astype(np.float32)) - # Cast to float32 to match the legacy FeatureExtractor - features = [f.astype(np.float32) for f in features] + return np.stack(features, axis=0) if len(features) > 1 else features - output_key = "audio_features" - stacked = np.stack(features, axis=0) - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + """CLVP uses raw-audio-level mask even for spectrogram output.""" + mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return {"audio_features_mask": mask} __all__ = ["ClvpAudioProcessor"] diff --git a/src/transformers/models/dac/audio_processing_dac.py b/src/transformers/models/dac/audio_processing_dac.py index 077f00ea3697..80a8590c8c54 100644 --- a/src/transformers/models/dac/audio_processing_dac.py +++ b/src/transformers/models/dac/audio_processing_dac.py @@ -15,7 +15,6 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...feature_extraction_utils import BatchFeature class DacAudioProcessor(NumpyAudioBackend): @@ -23,14 +22,14 @@ class DacAudioProcessor(NumpyAudioBackend): force_mono = True add_channel_dim = True - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - lengths = [a.shape[-1] for a in audio] - audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - padded_length = max(a.shape[-1] for a in audio) - padding_mask = np.array([[1] * l + [0] * (padded_length - l) for l in lengths]) - stacked = np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - output = BatchFeature({"audio_values": stacked, "padding_mask": padding_mask}, tensor_type=return_tensors) - return output + def _to_batch(self, audio): + return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return {"audio_values_mask": mask} __all__ = ["DacAudioProcessor"] diff --git a/src/transformers/models/dia/audio_processing_dia.py b/src/transformers/models/dia/audio_processing_dia.py index ef1a0b38c6d0..5766acd746b5 100644 --- a/src/transformers/models/dia/audio_processing_dia.py +++ b/src/transformers/models/dia/audio_processing_dia.py @@ -15,7 +15,6 @@ import numpy as np from 
...audio_processing_backends import NumpyAudioBackend -from ...feature_extraction_utils import BatchFeature class DiaAudioProcessor(NumpyAudioBackend): @@ -24,16 +23,14 @@ class DiaAudioProcessor(NumpyAudioBackend): add_channel_dim = True pad_to_multiple_of = 512 - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - if pad_to_multiple_of is None: - pad_to_multiple_of = self.pad_to_multiple_of - lengths = [a.shape[-1] for a in audio] - audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - padded_length = max(a.shape[-1] for a in audio) - padding_mask = np.array([[1] * l + [0] * (padded_length - l) for l in lengths]) - stacked = np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - output = BatchFeature({"audio_values": stacked, "padding_mask": padding_mask}, tensor_type=return_tensors) - return output + def _to_batch(self, audio): + return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return {"audio_values_mask": mask} __all__ = ["DiaAudioProcessor"] diff --git a/src/transformers/models/encodec/audio_processing_encodec.py b/src/transformers/models/encodec/audio_processing_encodec.py index 89376fbe7d5b..4208cc5c1ec8 100644 --- a/src/transformers/models/encodec/audio_processing_encodec.py +++ b/src/transformers/models/encodec/audio_processing_encodec.py @@ -15,7 +15,6 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...feature_extraction_utils import BatchFeature class EncodecAudioProcessor(NumpyAudioBackend): @@ -23,14 +22,14 @@ class EncodecAudioProcessor(NumpyAudioBackend): force_mono = True add_channel_dim = True - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - lengths = [a.shape[-1] for a in audio] - audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - padded_length = max(a.shape[-1] for a in audio) - padding_mask = np.array([[1] * l + [0] * (padded_length - l) for l in lengths]) - stacked = np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - output = BatchFeature({"audio_values": stacked, "padding_mask": padding_mask}, tensor_type=return_tensors) - return output + def _to_batch(self, audio): + return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return {"audio_values_mask": mask} __all__ = ["EncodecAudioProcessor"] diff --git a/src/transformers/models/granite_speech/audio_processing_granite_speech.py b/src/transformers/models/granite_speech/audio_processing_granite_speech.py index 3ea66476ae66..099b0bcedb48 100644 --- a/src/transformers/models/granite_speech/audio_processing_granite_speech.py +++ b/src/transformers/models/granite_speech/audio_processing_granite_speech.py @@ -70,7 +70,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of audio_lengths = [a.shape[-1] for a in audio] # Pad audio to longest in batch - audio = self.pad(audio, padding=True, max_length=max_length) + audio, _audio_ranges = self.pad(audio, padding=True, 
max_length=max_length) # Stack and extract spectrogram audio_stacked = torch.stack(audio) diff --git a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py index fcee9eee0313..09713556577b 100644 --- a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py @@ -28,35 +28,33 @@ def __init__(self, audio_delay_seconds=2.5, audio_silence_prefix_seconds=1.0, ** self.audio_silence_prefix_seconds = audio_silence_prefix_seconds super().__init__(**kwargs) - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Track lengths for padding_mask - lengths = [a.shape[-1] for a in audio] + def _to_batch(self, audio): + return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - # Pad audio to batch longest - audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - padded_length = max(a.shape[-1] for a in audio) + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return {"audio_values_mask": mask} - # Create padding_mask (1 for real audio, 0 for padding) - padding_mask = np.array([[1] * l + [0] * (padded_length - l) for l in lengths]) + def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + # Pad audio to batch longest + audio, audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + padded_length = audio[0].shape[-1] - # Stack audio with channel dim - stacked = np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + stacked = self._to_batch(audio) + mask_dict = self._get_mask(audio_ranges, padded_length, do_extract_spectrogram=False, spectrogram_config=None) + audio_values_mask = mask_dict["audio_values_mask"] # Add silence prefix (left) and delay (right) padding pad_left = int(self.audio_silence_prefix_seconds * self.sample_rate) pad_right = int((self.audio_delay_seconds + 1.0) * self.sample_rate) if pad_left > 0 or pad_right > 0: - # Pad audio - audio_pad_width = [(0, 0), (0, 0), (pad_left, pad_right)] - stacked = np.pad(stacked, audio_pad_width, mode="constant", constant_values=0.0) - - # Pad padding_mask - mask_pad_width = [(0, 0), (pad_left, pad_right)] - padding_mask = np.pad(padding_mask, mask_pad_width, mode="constant", constant_values=0) + stacked = np.pad(stacked, [(0, 0), (0, 0), (pad_left, pad_right)], mode="constant", constant_values=0.0) + audio_values_mask = np.pad(audio_values_mask, [(0, 0), (pad_left, pad_right)], mode="constant", constant_values=0) - output = BatchFeature({"audio_values": stacked, "padding_mask": padding_mask}, tensor_type=return_tensors) - return output + return BatchFeature({"audio_values": stacked, "audio_values_mask": audio_values_mask}, tensor_type=return_tensors) __all__ = ["KyutaiSpeechToTextAudioProcessor"] diff --git a/src/transformers/models/lasr/audio_processing_lasr.py b/src/transformers/models/lasr/audio_processing_lasr.py index 5ba188a6e1d7..400e82d46829 100644 --- a/src/transformers/models/lasr/audio_processing_lasr.py +++ b/src/transformers/models/lasr/audio_processing_lasr.py @@ -62,6 +62,11 @@ def __init__(self, **kwargs): upper_edge_hertz=7500.0, ) + 
def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): + stft_cfg = spectrogram_config.stft_config + win_length = stft_cfg.win_length or stft_cfg.n_fft + return (audio_lengths - win_length) // stft_cfg.hop_length + 1 + def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): import torch diff --git a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py index 88723799ed11..0373dce62f86 100644 --- a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py @@ -81,7 +81,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of # Pad raw audio if padding: - audio = self.pad(audio, padding=True, max_length=max_length) + audio, _audio_ranges = self.pad(audio, padding=True, max_length=max_length) # Extract chroma features features = self.extract_spectrogram(audio, spectrogram_config=None) diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py index d778d5ebcc5a..1bdd232bb372 100644 --- a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from spectrograms import numpy_mel_spectrogram as _np_spec + from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import mel_filter_bank +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature @@ -30,18 +32,33 @@ class Phi4MultimodalAudioProcessor(TorchAudioBackend): audio_compression_rate = 8 audio_downsample_rate = 1 audio_feat_stride = 1 + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig(n_fft=512), + mel_scale_config=MelScaleConfig( + n_mels=80, + f_min=0, + f_max=7690, + mel_scale="kaldi", + triangularize_in_mel_space=True, + ), + ) - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.mel_filters = mel_filter_bank( - num_frequency_bins=self.n_fft // 2 + 1, - num_mel_filters=self.n_mels, - min_frequency=self.mel_min_frequency, - max_frequency=self.mel_max_frequency, + def _mel_filter_bank(self, spectrogram_config): + import torch + + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + mel_filters_np = _np_spec.mel_filter_bank( + num_frequency_bins=1 + stft_cfg.n_fft // 2, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, sampling_rate=self.sample_rate, - triangularize_in_mel_space=True, - mel_scale="kaldi", + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, ) + return torch.from_numpy(mel_filters_np).to(torch.float32) def extract_spectrogram(self, audio, **kwargs): import torch @@ -84,7 +101,7 @@ def extract_spectrogram(self, audio, **kwargs): spec_power = torch.abs(S) ** 2 # Mel filterbank + log - mel_filters = torch.from_numpy(self.mel_filters).to(torch.float32) + mel_filters = self.mel_filters.to(torch.float32) log_spec = torch.clamp(spec_power @ mel_filters, min=1.0) log_spec = torch.log(log_spec) @@ -118,7 
+135,7 @@ def _preprocess( audio_lengths = torch.tensor([a.shape[-1] for a in audio]) # Pad and truncate - audio = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + audio, _audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) # Extract spectrogram features = self.extract_spectrogram(audio, audio_lengths=audio_lengths) diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py index a4e96e8178c6..5ba746c95608 100644 --- a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py @@ -15,7 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank, spectrogram, window_function +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function from ...feature_extraction_utils import BatchFeature @@ -52,19 +52,6 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.window = window_function(400, "povey", periodic=False) - def _mel_filter_bank(self, spectrogram_config): - mel_cfg = spectrogram_config.mel_scale_config - return mel_filter_bank( - num_frequency_bins=257, - num_mel_filters=mel_cfg.n_mels, - min_frequency=mel_cfg.f_min, - max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate // 2, - sampling_rate=self.sample_rate, - norm=mel_cfg.norm, - mel_scale=mel_cfg.mel_scale, - triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, - ) - def _extract_fbank_features(self, waveform): waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed integers features = spectrogram( @@ -84,7 +71,6 @@ def _extract_fbank_features(self, waveform): return features def feature_normalize(self, features): - # Per-mel-bin normalization with ddof=1 for variance normalized = [] for f in features: mean = np.expand_dims(f.mean(axis=0), 0) @@ -93,36 +79,43 @@ def feature_normalize(self, features): return normalized def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Extract Kaldi-style features matching the FE exactly + # Extract features from raw (unpadded) audio, then pad at feature level features = [self._extract_fbank_features(waveform) for waveform in audio] - - # Per-mel-bin normalization features = self.feature_normalize(features) - # Pad features to longest (pad_to_multiple_of=2 for stride) - max_len = max(f.shape[0] for f in features) + feature_lengths = [f.shape[0] for f in features] + + # Pad features to longest (pad_to_multiple_of stride) + max_len = max(feature_lengths) if max_len % self.stride != 0: max_len = ((max_len // self.stride) + 1) * self.stride padded = [] for f in features: if f.shape[0] < max_len: - pad_amount = max_len - f.shape[0] - f = np.pad(f, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) + f = np.pad(f, ((0, max_len - f.shape[0]), (0, 0)), mode="constant", constant_values=0.0) padded.append(f) - stacked = np.stack(padded, axis=0) # (batch, frames, n_mels) + stacked = np.stack(padded, axis=0) batch_size, num_frames, num_channels = stacked.shape + # Feature-level attention_mask + attention_mask = np.zeros((batch_size, num_frames), dtype=np.int32) + for i, length in enumerate(feature_lengths): + attention_mask[i, :length] = 1 + # Stride concatenation remainder = 
num_frames % self.stride if remainder != 0: stacked = stacked[:, : num_frames - remainder, :] + attention_mask = attention_mask[:, : num_frames - remainder] num_frames = num_frames - remainder stacked = stacked.reshape(batch_size, num_frames // self.stride, num_channels * self.stride) + indices = np.arange(0, num_frames) + attention_mask = attention_mask[:, indices % self.stride == 1] - output_key = "audio_features" - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + data = {"audio_features": stacked, "audio_features_mask": attention_mask} + return BatchFeature(data=data, tensor_type=return_tensors) __all__ = ["SeamlessM4tAudioProcessor"] diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py index b1cf5ba1f4a0..3a075b43720b 100644 --- a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -96,7 +96,7 @@ def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, p return x.astype(np.float32) def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Extract Kaldi-style features matching the FE exactly + # Extract features from raw (unpadded) audio, then pad at feature level features = [self._extract_fbank_features(waveform) for waveform in audio] lengths = [f.shape[0] for f in features] @@ -105,8 +105,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of padded = [] for f in features: if f.shape[0] < max_len: - pad_amount = max_len - f.shape[0] - f = np.pad(f, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) + f = np.pad(f, ((0, max_len - f.shape[0]), (0, 0)), mode="constant", constant_values=0.0) padded.append(f) # Utterance CMVN normalization @@ -115,9 +114,16 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of for f, length in zip(padded, lengths) ] - output_key = "audio_features" stacked = np.stack(normalized, axis=0) - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + data = {"audio_features": stacked} + + if self.return_attention_mask: + attention_mask = np.zeros((len(lengths), max_len), dtype=np.int32) + for i, length in enumerate(lengths): + attention_mask[i, :length] = 1 + data["audio_features_mask"] = attention_mask + + return BatchFeature(data=data, tensor_type=return_tensors) __all__ = ["SpeechToTextAudioProcessor"] diff --git a/src/transformers/models/univnet/audio_processing_univnet.py b/src/transformers/models/univnet/audio_processing_univnet.py index 449646726727..65a25c85eeb7 100644 --- a/src/transformers/models/univnet/audio_processing_univnet.py +++ b/src/transformers/models/univnet/audio_processing_univnet.py @@ -15,7 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import mel_filter_bank, spectrogram, window_function +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function from ...feature_extraction_utils import BatchFeature @@ -34,19 +34,20 @@ class UnivNetAudioProcessor(NumpyAudioBackend): normalize_min = -11.512925148010254 normalize_max = 2.3143386840820312 max_length_s = 10 + spectrogram_config = SpectrogramConfig( + stft_config=StftConfig(n_fft=1024), + mel_scale_config=MelScaleConfig( + n_mels=100, + f_min=0.0, + f_max=12000.0, + 
mel_scale="slaney", + norm="slaney", + ), + ) def __init__(self, **kwargs): super().__init__(**kwargs) self.num_max_samples = self.max_length_s * self.sample_rate - self.mel_filters = mel_filter_bank( - num_frequency_bins=1 + self.n_fft // 2, - num_mel_filters=self.n_mels, - min_frequency=self.fmin, - max_frequency=self.fmax, - sampling_rate=self.sample_rate, - norm="slaney", - mel_scale="slaney", - ) self.window = window_function(self.n_fft, "hann", periodic=True) def mel_spectrogram(self, waveform): @@ -94,7 +95,7 @@ def extract_spectrogram(self, audio, *, spectrogram_config): def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, generator=None, **kwargs): # Pad raw audio if padding: - audio = self.pad(audio, padding=True, max_length=max_length) + audio, _audio_ranges = self.pad(audio, padding=True, max_length=max_length) # Extract mel spectrograms features = self.extract_spectrogram(audio, spectrogram_config=None) diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py index 866113b39b82..0fbaec66b74c 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py @@ -34,5 +34,11 @@ def _process_audio(self, audio_el): audio_el = audio_el / (max_val + self.eps) return audio_el + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return {"audio_values_mask": mask} + __all__ = ["VibevoiceAcousticTokenizerAudioProcessor"] diff --git a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py index 60bead1ebc59..1554b3dcfbb1 100644 --- a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py @@ -14,8 +14,10 @@ import torch +from spectrograms import numpy_mel_spectrogram as _np_spec + from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig class VoxtralRealtimeAudioProcessor(TorchAudioBackend): @@ -73,11 +75,15 @@ def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): processed.append(spec) return processed + def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): + stft_cfg = spectrogram_config.stft_config + win_length = stft_cfg.win_length or stft_cfg.n_fft + return (audio_lengths - win_length) // stft_cfg.hop_length + 1 + def _mel_filter_bank(self, spectrogram_config): - """Override to use numpy mel_filter_bank for exact match with feature extractor.""" stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - mel_filters_np = mel_filter_bank( + mel_filters_np = _np_spec.mel_filter_bank( num_frequency_bins=1 + stft_cfg.n_fft // 2, num_mel_filters=mel_cfg.n_mels, min_frequency=mel_cfg.f_min, @@ -85,6 +91,7 @@ def _mel_filter_bank(self, spectrogram_config): 
sampling_rate=self.sample_rate, norm=mel_cfg.norm, mel_scale=mel_cfg.mel_scale, + triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, ) return torch.from_numpy(mel_filters_np).to(torch.float32) diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index c7ea01d00de9..f2850bc10cb8 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -14,13 +14,16 @@ import torch +from spectrograms import numpy_mel_spectrogram as _np_spec + from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig class WhisperAudioProcessor(TorchAudioBackend): sample_rate = 16000 force_mono = True + return_attention_mask = False truncation = True max_length = 480000 # 30 seconds at 16000 Hz spectrogram_config = SpectrogramConfig( @@ -48,13 +51,9 @@ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): return features def _mel_filter_bank(self, spectrogram_config): - """ - Override to use the same numpy-based mel filter bank as WhisperFeatureExtractor - for exact numerical compatibility. - """ stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - mel_filters = mel_filter_bank( + mel_filters_np = _np_spec.mel_filter_bank( num_frequency_bins=1 + stft_cfg.n_fft // 2, num_mel_filters=mel_cfg.n_mels, min_frequency=mel_cfg.f_min, @@ -62,8 +61,9 @@ def _mel_filter_bank(self, spectrogram_config): sampling_rate=self.sample_rate, norm=mel_cfg.norm, mel_scale=mel_cfg.mel_scale, + triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, ) - return torch.from_numpy(mel_filters).to(torch.float32) + return torch.from_numpy(mel_filters_np).to(torch.float32) def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): """ From 4fcd3e2e6e5475cd465331b892ebbe912ca647b9 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 17:29:34 +0100 Subject: [PATCH 16/28] remove test files for another repo --- tests/test_audio_processing_common.py | 360 ----------- ..._audio_processors_vs_feature_extractors.py | 583 ------------------ 2 files changed, 943 deletions(-) delete mode 100644 tests/test_audio_processing_common.py delete mode 100644 tests/test_audio_processors_vs_feature_extractors.py diff --git a/tests/test_audio_processing_common.py b/tests/test_audio_processing_common.py deleted file mode 100644 index 9bc72955f5c6..000000000000 --- a/tests/test_audio_processing_common.py +++ /dev/null @@ -1,360 +0,0 @@ -# Copyright 2025 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import os -import tempfile - -import numpy as np - -from transformers.testing_utils import ( - check_json_file_has_correct_format, - require_torch, -) -from transformers.utils import is_torch_available - - -if is_torch_available(): - import torch - - -def prepare_audio_inputs( - batch_size, - min_length=400, - max_length=2000, - num_channels=1, - equal_length=False, - numpify=False, - torchify=False, -): - """This function prepares a list of numpy arrays, or a list of PyTorch tensors if one specifies torchify=True. - - One can specify whether the audio inputs are of the same length or not. - """ - - assert not (numpify and torchify), "You cannot specify both numpy and PyTorch tensors at the same time" - - audio_inputs = [] - for _ in range(batch_size): - if equal_length: - length = max_length - else: - length = np.random.randint(min_length, max_length) - - if num_channels > 1: - audio_inputs.append(np.random.randn(length, num_channels).astype(np.float32)) - else: - audio_inputs.append(np.random.randn(length).astype(np.float32)) - - if torchify: - audio_inputs = [torch.from_numpy(audio) for audio in audio_inputs] - - return audio_inputs - - -class AudioProcessingTestMixin: - """Mixin class for testing audio processors, analogous to ``ImageProcessingTestMixin``. - - Subclasses must set the following in ``setUp``: - - * ``self.audio_processing_classes`` – ``dict[str, type]`` mapping backend name to class - * ``self.audio_processor_dict`` – kwargs to instantiate the processor - * ``self.audio_processor_tester`` – object with ``prepare_audio_inputs()`` and ``batch_size`` - """ - - # ─── Serialization ──────────────────────────────────────────────── - - def test_audio_processor_to_json_string(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processor = audio_processing_class(**self.audio_processor_dict) - obj = json.loads(audio_processor.to_json_string()) - for key, value in self.audio_processor_dict.items(): - self.assertEqual(obj[key], value) - - def test_audio_processor_to_json_file(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processor_first = audio_processing_class(**self.audio_processor_dict) - - with tempfile.TemporaryDirectory() as tmpdirname: - json_file_path = os.path.join(tmpdirname, "audio_processor.json") - audio_processor_first.to_json_file(json_file_path) - audio_processor_second = audio_processing_class.from_json_file(json_file_path) - - self.assertEqual(audio_processor_second.to_dict(), audio_processor_first.to_dict()) - - def test_audio_processor_from_and_save_pretrained(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processor_first = audio_processing_class(**self.audio_processor_dict) - - with tempfile.TemporaryDirectory() as tmpdirname: - saved_file = audio_processor_first.save_pretrained(tmpdirname)[0] - check_json_file_has_correct_format(saved_file) - audio_processor_second = audio_processing_class.from_pretrained(tmpdirname) - - self.assertEqual(audio_processor_second.to_dict(), audio_processor_first.to_dict()) - - def test_init_without_params(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processor = audio_processing_class() - self.assertIsNotNone(audio_processor) - - # ─── Backend equivalence ────────────────────────────────────────── - - @require_torch - def test_backends_equivalence(self): - if len(self.audio_processing_classes) < 2: - 
self.skipTest(reason="Skipping backends equivalence test as there are less than 2 backends") - - audio_input = np.random.randn(16000).astype(np.float32) - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - - encodings = {} - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processor = audio_processing_class(**self.audio_processor_dict) - encodings[backend_name] = audio_processor(audio_input, sample_rate=sample_rate, return_tensors="pt") - - backend_names = list(encodings.keys()) - reference_backend = backend_names[0] - reference_key = list(encodings[reference_backend].keys())[0] - reference_values = encodings[reference_backend][reference_key] - for backend_name in backend_names[1:]: - torch.testing.assert_close(reference_values, encodings[backend_name][reference_key], atol=1e-5, rtol=1e-5) - - @require_torch - def test_backends_equivalence_batched(self): - if len(self.audio_processing_classes) < 2: - self.skipTest(reason="Skipping backends equivalence test as there are less than 2 backends") - - audio_inputs = self.audio_processor_tester.prepare_audio_inputs(equal_length=False) - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - - encodings = {} - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processor = audio_processing_class(**self.audio_processor_dict) - encodings[backend_name] = audio_processor(audio_inputs, sample_rate=sample_rate, return_tensors="pt") - - backend_names = list(encodings.keys()) - reference_backend = backend_names[0] - reference_key = list(encodings[reference_backend].keys())[0] - reference_values = encodings[reference_backend][reference_key] - for backend_name in backend_names[1:]: - torch.testing.assert_close(reference_values, encodings[backend_name][reference_key], atol=1e-5, rtol=1e-5) - - # ─── Cross-backend save / load ──────────────────────────────────── - - def test_save_load_backends(self): - """Test that we can load audio processors saved by one backend with another.""" - if len(self.audio_processing_classes) < 2: - self.skipTest("Skipping backend save/load test as there are less than 2 backends") - - backend_names = list(self.audio_processing_classes.keys()) - - for backend1 in backend_names: - processor1 = self.audio_processing_classes[backend1](**self.audio_processor_dict) - - for backend2 in backend_names: - if backend1 == backend2: - continue - - with tempfile.TemporaryDirectory() as tmpdirname: - processor1.save_pretrained(tmpdirname) - processor2 = self.audio_processing_classes[backend2].from_pretrained(tmpdirname) - - dict1 = processor1.to_dict() - dict2 = processor2.to_dict() - common_keys = set(dict1) & set(dict2) - self.assertEqual( - {k: dict1[k] for k in common_keys}, - {k: dict2[k] for k in common_keys}, - f"Backends {backend1} and {backend2} differ in common keys", - ) - - # ─── Input type tests ───────────────────────────────────────────── - - @require_torch - def test_call_numpy(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processing = audio_processing_class(**self.audio_processor_dict) - audio_inputs = self.audio_processor_tester.prepare_audio_inputs(equal_length=False) - for audio in audio_inputs: - self.assertIsInstance(audio, np.ndarray) - - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - - # Test not batched input - encoded = audio_processing(audio_inputs[0], sample_rate=sample_rate, return_tensors="pt") - output_key = list(encoded.keys())[0] - 
self.assertEqual(len(encoded[output_key].shape), 2) # (1, length) - - # Test batched - encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") - self.assertEqual(encoded[output_key].shape[0], self.audio_processor_tester.batch_size) - - @require_torch - def test_call_pytorch(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processing = audio_processing_class(**self.audio_processor_dict) - audio_inputs = self.audio_processor_tester.prepare_audio_inputs(equal_length=False, torchify=True) - - for audio in audio_inputs: - self.assertIsInstance(audio, torch.Tensor) - - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - - # Test not batched input - encoded = audio_processing(audio_inputs[0], sample_rate=sample_rate, return_tensors="pt") - output_key = list(encoded.keys())[0] - self.assertEqual(len(encoded[output_key].shape), 2) - - # Test batched - encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") - self.assertEqual(encoded[output_key].shape[0], self.audio_processor_tester.batch_size) - - @require_torch - def test_call_multichannel_force_mono(self): - """Test that multi-channel audio is correctly averaged to mono.""" - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - processor_dict = {**self.audio_processor_dict, "force_mono": True} - audio_processing = audio_processing_class(**processor_dict) - - audio_inputs = prepare_audio_inputs( - batch_size=self.audio_processor_tester.batch_size, - num_channels=2, - min_length=self.audio_processor_tester.min_length, - max_length=self.audio_processor_tester.max_length, - equal_length=True, - ) - - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") - output_key = list(encoded.keys())[0] - # After force_mono, output should be 2D: (batch, length) - self.assertEqual(len(encoded[output_key].shape), 2) - - # ─── Padding tests ──────────────────────────────────────────────── - - @require_torch - def test_padding_right(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - processor_dict = {**self.audio_processor_dict, "padding_side": "right"} - audio_processing = audio_processing_class(**processor_dict) - - audio_inputs = [ - np.random.randn(100).astype(np.float32), - np.random.randn(200).astype(np.float32), - ] - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") - output_key = list(encoded.keys())[0] - self.assertEqual(encoded[output_key].shape[-1], 200) - - @require_torch - def test_padding_left(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - processor_dict = {**self.audio_processor_dict, "padding_side": "left"} - audio_processing = audio_processing_class(**processor_dict) - - audio_inputs = [ - np.random.randn(100).astype(np.float32), - np.random.randn(200).astype(np.float32), - ] - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - encoded = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") - output_key = list(encoded.keys())[0] - self.assertEqual(encoded[output_key].shape[-1], 200) - - # ─── Truncation tests ───────────────────────────────────────────── - - @require_torch - def test_truncation(self): - for backend_name, audio_processing_class in 
self.audio_processing_classes.items(): - audio_processing = audio_processing_class(**self.audio_processor_dict) - - audio_inputs = [ - np.random.randn(500).astype(np.float32), - np.random.randn(1000).astype(np.float32), - ] - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - encoded = audio_processing( - audio_inputs, sample_rate=sample_rate, truncation=True, max_length=300, return_tensors="pt" - ) - output_key = list(encoded.keys())[0] - self.assertEqual(encoded[output_key].shape[-1], 300) - - @require_torch - def test_truncation_without_max_length_raises(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processing = audio_processing_class(**self.audio_processor_dict) - - audio_inputs = [np.random.randn(500).astype(np.float32)] - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - with self.assertRaises(ValueError): - audio_processing( - audio_inputs, sample_rate=sample_rate, truncation=True, max_length=None, return_tensors="pt" - ) - - # ─── pad_to_multiple_of ─────────────────────────────────────────── - - @require_torch - def test_pad_to_multiple_of(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processing = audio_processing_class(**self.audio_processor_dict) - - audio_inputs = [np.random.randn(100).astype(np.float32)] - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - encoded = audio_processing( - audio_inputs, - sample_rate=sample_rate, - truncation=True, - max_length=150, - pad_to_multiple_of=64, - return_tensors="pt", - ) - output_key = list(encoded.keys())[0] - # max_length=150 rounded up to next multiple of 64 → 192 - self.assertEqual(encoded[output_key].shape[-1] % 64, 0) - - # ─── Sample rate validation ─────────────────────────────────────── - - def test_wrong_sample_rate_raises(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processing = audio_processing_class(**self.audio_processor_dict) - - audio_inputs = [np.random.randn(100).astype(np.float32)] - expected_sr = self.audio_processor_dict.get("sample_rate", 16000) - with self.assertRaises(ValueError): - audio_processing(audio_inputs, sample_rate=expected_sr + 1000, return_tensors="pt") - - # ─── Dtype casting ──────────────────────────────────────────────── - - @require_torch - def test_cast_dtype(self): - for backend_name, audio_processing_class in self.audio_processing_classes.items(): - audio_processing = audio_processing_class(**self.audio_processor_dict) - - audio_inputs = self.audio_processor_tester.prepare_audio_inputs(equal_length=True) - sample_rate = self.audio_processor_dict.get("sample_rate", 16000) - - encoding = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt") - output_key = list(encoding.keys())[0] - self.assertEqual(encoding[output_key].dtype, torch.float32) - - encoding = encoding.to(torch.float16) - self.assertEqual(encoding[output_key].dtype, torch.float16) - - encoding = audio_processing(audio_inputs, sample_rate=sample_rate, return_tensors="pt").to( - "cpu", torch.bfloat16 - ) - self.assertEqual(encoding[output_key].device, torch.device("cpu")) - self.assertEqual(encoding[output_key].dtype, torch.bfloat16) diff --git a/tests/test_audio_processors_vs_feature_extractors.py b/tests/test_audio_processors_vs_feature_extractors.py deleted file mode 100644 index 0d5d4fdc472f..000000000000 --- a/tests/test_audio_processors_vs_feature_extractors.py +++ /dev/null @@ -1,583 +0,0 @@ 
-# Copyright 2025 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Tests comparing the new AudioProcessor classes against the legacy FeatureExtractor classes. - -For each model, we: -1. Instantiate the FeatureExtractor via from_pretrained (from the Hub) -2. Instantiate the corresponding AudioProcessor directly -3. Run both on the same batched audio input -4. Assert torch.equal on the main output tensors -""" - -import importlib -import os -import sys - -import numpy as np -import pytest -import torch - -# --------------------------------------------------------------------------- -# Feature extractor classes are loaded from ~/transformers/src (upstream). -# We temporarily swap sys.path and clear cached transformers modules so that -# ``import transformers.models.X.feature_extraction_X`` resolves to the -# upstream checkout rather than the locally-installed (audio-processors) copy. -# --------------------------------------------------------------------------- -_UPSTREAM_SRC = os.path.expanduser("~/transformers/src") - -_fe_class_specs = [ - ("transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer", "ASTFeatureExtractor"), - ("transformers.models.clap.feature_extraction_clap", "ClapFeatureExtractor"), - ("transformers.models.clvp.feature_extraction_clvp", "ClvpFeatureExtractor"), - ("transformers.models.dac.feature_extraction_dac", "DacFeatureExtractor"), - ("transformers.models.dia.feature_extraction_dia", "DiaFeatureExtractor"), - ("transformers.models.encodec.feature_extraction_encodec", "EncodecFeatureExtractor"), - ("transformers.models.granite_speech.feature_extraction_granite_speech", "GraniteSpeechFeatureExtractor"), - ("transformers.models.kyutai_speech_to_text.feature_extraction_kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"), - ("transformers.models.lasr.feature_extraction_lasr", "LasrFeatureExtractor"), - ("transformers.models.musicgen_melody.feature_extraction_musicgen_melody", "MusicgenMelodyFeatureExtractor"), - ("transformers.models.parakeet.feature_extraction_parakeet", "ParakeetFeatureExtractor"), - ("transformers.models.phi4_multimodal.feature_extraction_phi4_multimodal", "Phi4MultimodalFeatureExtractor"), - ("transformers.models.pop2piano.feature_extraction_pop2piano", "Pop2PianoFeatureExtractor"), - ("transformers.models.seamless_m4t.feature_extraction_seamless_m4t", "SeamlessM4TFeatureExtractor"), - ("transformers.models.speech_to_text.feature_extraction_speech_to_text", "Speech2TextFeatureExtractor"), - ("transformers.models.speecht5.feature_extraction_speecht5", "SpeechT5FeatureExtractor"), - ("transformers.models.univnet.feature_extraction_univnet", "UnivNetFeatureExtractor"), - ("transformers.models.vibevoice_acoustic_tokenizer.feature_extraction_vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerFeatureExtractor"), - ("transformers.models.voxtral_realtime.feature_extraction_voxtral_realtime", "VoxtralRealtimeFeatureExtractor"), - 
("transformers.models.wav2vec2.feature_extraction_wav2vec2", "Wav2Vec2FeatureExtractor"), - ("transformers.models.whisper.feature_extraction_whisper", "WhisperFeatureExtractor"), -] - - -def _load_upstream_classes(class_specs): - """Load feature extractor classes from ~/transformers/src. - - Temporarily replaces the transformers package in sys.modules so that - imports resolve to the upstream checkout. - """ - # 1. Save and remove all cached transformers modules - saved_modules = {} - for key in list(sys.modules.keys()): - if key == "transformers" or key.startswith("transformers."): - saved_modules[key] = sys.modules.pop(key) - - # 2. Prepend upstream src to sys.path - sys.path.insert(0, _UPSTREAM_SRC) - - results = {} - try: - for module_path, class_name in class_specs: - mod = importlib.import_module(module_path) - results[class_name] = getattr(mod, class_name) - finally: - # 3. Remove upstream from sys.path - sys.path.remove(_UPSTREAM_SRC) - # 4. Clear all upstream-loaded transformers modules - for key in list(sys.modules.keys()): - if key == "transformers" or key.startswith("transformers."): - del sys.modules[key] - # 5. Restore the local project's transformers modules - sys.modules.update(saved_modules) - - return results - - -def _load_upstream_class(module_path, class_name): - """Load a single class from ~/transformers/src.""" - return _load_upstream_classes([(module_path, class_name)])[class_name] - - -# Load all FE classes from upstream in one batch -_fe_classes = _load_upstream_classes(_fe_class_specs) -ASTFeatureExtractor = _fe_classes["ASTFeatureExtractor"] -ClapFeatureExtractor = _fe_classes["ClapFeatureExtractor"] -ClvpFeatureExtractor = _fe_classes["ClvpFeatureExtractor"] -DacFeatureExtractor = _fe_classes["DacFeatureExtractor"] -DiaFeatureExtractor = _fe_classes["DiaFeatureExtractor"] -EncodecFeatureExtractor = _fe_classes["EncodecFeatureExtractor"] -GraniteSpeechFeatureExtractor = _fe_classes["GraniteSpeechFeatureExtractor"] -KyutaiSpeechToTextFeatureExtractor = _fe_classes["KyutaiSpeechToTextFeatureExtractor"] -LasrFeatureExtractor = _fe_classes["LasrFeatureExtractor"] -MusicgenMelodyFeatureExtractor = _fe_classes["MusicgenMelodyFeatureExtractor"] -ParakeetFeatureExtractor = _fe_classes["ParakeetFeatureExtractor"] -Phi4MultimodalFeatureExtractor = _fe_classes["Phi4MultimodalFeatureExtractor"] -Pop2PianoFeatureExtractor = _fe_classes["Pop2PianoFeatureExtractor"] -SeamlessM4TFeatureExtractor = _fe_classes["SeamlessM4TFeatureExtractor"] -Speech2TextFeatureExtractor = _fe_classes["Speech2TextFeatureExtractor"] -SpeechT5FeatureExtractor = _fe_classes["SpeechT5FeatureExtractor"] -UnivNetFeatureExtractor = _fe_classes["UnivNetFeatureExtractor"] -VibeVoiceAcousticTokenizerFeatureExtractor = _fe_classes["VibeVoiceAcousticTokenizerFeatureExtractor"] -VoxtralRealtimeFeatureExtractor = _fe_classes["VoxtralRealtimeFeatureExtractor"] -Wav2Vec2FeatureExtractor = _fe_classes["Wav2Vec2FeatureExtractor"] -WhisperFeatureExtractor = _fe_classes["WhisperFeatureExtractor"] - -# Audio processor imports (from local project) -from transformers.models.audio_spectrogram_transformer.audio_processing_audio_spectrogram_transformer import ( - AudioSpectrogramTransformerAudioProcessor, -) -from transformers.models.clap.audio_processing_clap import ClapAudioProcessor -from transformers.models.clvp.audio_processing_clvp import ClvpAudioProcessor -from transformers.models.dac.audio_processing_dac import DacAudioProcessor -from transformers.models.dia.audio_processing_dia import DiaAudioProcessor -from 
transformers.models.encodec.audio_processing_encodec import EncodecAudioProcessor -from transformers.models.granite_speech.audio_processing_granite_speech import GraniteSpeechAudioProcessor -from transformers.models.kyutai_speech_to_text.audio_processing_kyutai_speech_to_text import ( - KyutaiSpeechToTextAudioProcessor, -) -from transformers.models.lasr.audio_processing_lasr import LasrAudioProcessor -from transformers.models.musicgen_melody.audio_processing_musicgen_melody import MusicgenMelodyAudioProcessor -from transformers.models.parakeet.audio_processing_parakeet import ParakeetAudioProcessor -from transformers.models.phi4_multimodal.audio_processing_phi4_multimodal import Phi4MultimodalAudioProcessor -from transformers.models.pop2piano.audio_processing_pop2piano import Pop2PianoAudioProcessor -from transformers.models.seamless_m4t.audio_processing_seamless_m4t import SeamlessM4tAudioProcessor -from transformers.models.speech_to_text.audio_processing_speech_to_text import SpeechToTextAudioProcessor -from transformers.models.speecht5.audio_processing_speecht5 import SpeechT5AudioProcessor -from transformers.models.univnet.audio_processing_univnet import UnivNetAudioProcessor -from transformers.models.vibevoice_acoustic_tokenizer.audio_processing_vibevoice_acoustic_tokenizer import ( - VibevoiceAcousticTokenizerAudioProcessor, -) -from transformers.models.voxtral_realtime.audio_processing_voxtral_realtime import VoxtralRealtimeAudioProcessor -from transformers.models.wav2vec2.audio_processing_wav2vec2 import Wav2Vec2AudioProcessor -from transformers.models.whisper.audio_processing_whisper import WhisperAudioProcessor - - -# Sentinel to exclude a key from default kwargs -_EXCLUDE = object() - -# Each entry is a dict with model config. Keys: -# name, hub_repo, fe_class, ap_class, fe_output_key, sample_rate -# fe_kwargs (optional): extra kwargs for the FE call (use _EXCLUDE to remove a default key) -# ap_kwargs (optional): extra kwargs for the AP call -MODEL_CONFIGS = [ - { - "name": "audio_spectrogram_transformer", - "hub_repo": "MIT/ast-finetuned-audioset-10-10-0.4593", - "fe_class": ASTFeatureExtractor, - "ap_class": AudioSpectrogramTransformerAudioProcessor, - "fe_output_key": "input_values", - "sample_rate": 16000, - "atol": 1e-6, - }, - { - "name": "clap", - "hub_repo": "laion/clap-htsat-unfused", - "fe_class": ClapFeatureExtractor, - "ap_class": ClapAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 48000, - }, - { - "name": "clvp", - "hub_repo": "susnato/clvp_dev", - "fe_class": ClvpFeatureExtractor, - "ap_class": ClvpAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 22050, - "ap_init_kwargs": { - "mel_norms": [-7.0095, -6.0832, -4.644, -3.3562, -2.4548, -2.0097, -1.6036, -1.8641, -2.3728, -2.3455, -2.5947, -2.6695, -2.7129, -2.8555, -3.0251, -3.0889, -3.4261, -3.6759, -4.078, -4.4624, -4.7812, -5.0075, -5.1284, -5.2717, -5.4006, -5.4993, -5.531, -5.5878, -5.6726, -5.7016, -5.7943, -5.8831, -5.9537, -5.9989, -6.0305, -6.0539, -6.0748, -6.1163, -6.1481, -6.2476, -6.3195, -6.4457, -6.5377, -6.611, -6.6481, -6.6671, -6.6539, -6.6499, -6.6794, -6.7833, -6.9307, -7.0818, -7.1894, -7.2439, -7.3168, -7.3779, -7.4491, -7.5233, -7.6224, -7.7473, -7.8994, -8.0604, -8.2181, -8.3998, -8.5556, -8.7161, -8.8481, -8.9582, -9.0371, -9.0867, -9.1546, -9.2038, -9.2334, -9.2292, -9.2304, -9.268, -9.3156, -9.3716, -9.4165, -9.4822], - }, - }, - { - "name": "dac", - "hub_repo": "descript/dac_16khz", - "fe_class": DacFeatureExtractor, - "ap_class": 
DacAudioProcessor, - "fe_output_key": "input_values", - "sample_rate": 16000, - }, - { - "name": "dia", - "hub_repo": "nari-labs/Dia-1.6B-0626", - "fe_class": DiaFeatureExtractor, - "ap_class": DiaAudioProcessor, - "fe_output_key": "input_values", - "sample_rate": 44100, - }, - { - "name": "encodec", - "hub_repo": "facebook/encodec_24khz", - "fe_class": EncodecFeatureExtractor, - "ap_class": EncodecAudioProcessor, - "fe_output_key": "input_values", - "sample_rate": 24000, - }, - # { - # "name": "gemma3n", - # "hub_repo": "google/gemma-3n-e4b-it", - # "fe_class": Gemma3nAudioFeatureExtractor, - # "ap_class": Gemma3nAudioProcessor, - # "fe_output_key": "input_features", - # "sample_rate": 16000, - # # AP now implements custom FFT with HTK preemphasis and FFT overdrive - # }, - { - "name": "granite_speech", - "hub_repo": "ibm-granite/granite-speech-3.2-8b", - "fe_class": GraniteSpeechFeatureExtractor, - "ap_class": GraniteSpeechAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 16000, - "fe_kwargs": {"sampling_rate": _EXCLUDE, "return_tensors": _EXCLUDE, "padding": _EXCLUDE}, - }, - { - "name": "kyutai_speech_to_text", - "hub_repo": "kyutai/stt-2.6b-en-trfs", - "fe_class": KyutaiSpeechToTextFeatureExtractor, - "ap_class": KyutaiSpeechToTextAudioProcessor, - "fe_output_key": "input_values", - "sample_rate": 24000, - # AP now implements 1-second delay padding - }, - { - "name": "lasr", - "hub_repo": None, - "fe_class": LasrFeatureExtractor, - "ap_class": LasrAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 16000, - }, - { - "name": "musicgen_melody", - "hub_repo": "facebook/musicgen-melody", - "fe_class": MusicgenMelodyFeatureExtractor, - "ap_class": MusicgenMelodyAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 32000, - }, - { - "name": "parakeet", - "hub_repo": "nvidia/parakeet-ctc-1.1b", - "fe_class": ParakeetFeatureExtractor, - "ap_class": ParakeetAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 16000, - # AP now implements preemphasis, natural log, and slaney mel filters - }, - { - "name": "phi4_multimodal", - "hub_repo": "microsoft/Phi-4-multimodal-instruct", - "fe_class": Phi4MultimodalFeatureExtractor, - "ap_class": Phi4MultimodalAudioProcessor, - "fe_output_key": "audio_input_features", - "sample_rate": 16000, - }, - # { - # "name": "pop2piano", - # "hub_repo": "sweetcocoa/pop2piano", - # "fe_class": Pop2PianoFeatureExtractor, - # "ap_class": Pop2PianoAudioProcessor, - # "fe_output_key": "input_features", - # "sample_rate": 22050, - # "fe_kwargs": {"sampling_rate": [22050, 22050]}, - # # Skipped: Requires essentia library - # }, - { - "name": "seamless_m4t", - "hub_repo": "facebook/hf-seamless-m4t-medium", - "fe_class": SeamlessM4TFeatureExtractor, - "ap_class": SeamlessM4tAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 16000, - # AP now implements Kaldi-style features with stride concatenation - }, - { - "name": "speech_to_text", - "hub_repo": "facebook/s2t-small-librispeech-asr", - "fe_class": Speech2TextFeatureExtractor, - "ap_class": SpeechToTextAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 16000, - }, - { - "name": "speecht5", - "hub_repo": "microsoft/speecht5_asr", - "fe_class": SpeechT5FeatureExtractor, - "ap_class": SpeechT5AudioProcessor, - "fe_output_key": "input_values", - "sample_rate": 16000, - }, - # { - # "name": "univnet", - # "hub_repo": "dg845/univnet-dev", - # "fe_class": UnivNetFeatureExtractor, - # "ap_class": UnivNetAudioProcessor, - # 
"fe_output_key": "input_features", - # "sample_rate": 24000, - # }, - { - "name": "vibevoice_acoustic_tokenizer", - "hub_repo": "microsoft/VibeVoice-AcousticTokenizer", - "fe_class": VibeVoiceAcousticTokenizerFeatureExtractor, - "ap_class": VibevoiceAcousticTokenizerAudioProcessor, - "fe_output_key": "input_values", - "sample_rate": 24000, - "fe_kwargs": {"return_tensors": _EXCLUDE, "padding": _EXCLUDE}, - }, - { - "name": "voxtral_realtime", - "hub_repo": "mistralai/Voxtral-Mini-4B-Realtime-2602", - "fe_class": VoxtralRealtimeFeatureExtractor, - "ap_class": VoxtralRealtimeAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 16000, - }, - { - "name": "wav2vec2", - "hub_repo": "facebook/wav2vec2-large-960h-lv60-self", - "fe_class": Wav2Vec2FeatureExtractor, - "ap_class": Wav2Vec2AudioProcessor, - "fe_output_key": "input_values", - "sample_rate": 16000, - }, - { - "name": "whisper", - "hub_repo": "openai/whisper-small", - "fe_class": WhisperFeatureExtractor, - "ap_class": WhisperAudioProcessor, - "fe_output_key": "input_features", - "sample_rate": 16000, - "ap_kwargs": {"max_length": None, "truncation": False}, - }, -] - - -def _make_audio_batch(sample_rate: int, seed: int = 42) -> list[np.ndarray]: - """Create a deterministic batched audio input: two clips of different lengths.""" - rng = np.random.default_rng(seed) - return [ - rng.standard_normal(sample_rate).astype(np.float32), # 1 second - rng.standard_normal(sample_rate * 2).astype(np.float32), # 2 seconds - ] - - -@pytest.mark.parametrize( - "config", - MODEL_CONFIGS, - ids=[c["name"] if isinstance(c, dict) else c.values[0]["name"] for c in MODEL_CONFIGS], -) -def test_audio_processor_matches_feature_extractor(config): - hub_repo = config["hub_repo"] - fe_class = config["fe_class"] - ap_class = config["ap_class"] - fe_output_key = config["fe_output_key"] - sample_rate = config["sample_rate"] - - # Instantiate feature extractor from the Hub (or with defaults if hub_repo is None) - if hub_repo is not None: - fe = fe_class.from_pretrained(hub_repo) - else: - fe = fe_class() - - # Instantiate audio processor directly - ap_init_kwargs = config.get("ap_init_kwargs", {}) - ap = ap_class(**ap_init_kwargs) - - # Create batched audio input (deterministic) - audio_batch = _make_audio_batch(sample_rate) - - # Default kwargs - default_fe_kwargs = { - "sampling_rate": sample_rate, - "return_tensors": "pt", - "padding": True, - } - default_ap_kwargs = { - "sampling_rate": sample_rate, - "return_tensors": "pt", - "padding": True, - } - - # Apply per-model overrides (use _EXCLUDE sentinel to remove default keys) - fe_kwargs = {**default_fe_kwargs, **config.get("fe_kwargs", {})} - fe_kwargs = {k: v for k, v in fe_kwargs.items() if v is not _EXCLUDE} - ap_kwargs = {**default_ap_kwargs, **config.get("ap_kwargs", {})} - ap_kwargs = {k: v for k, v in ap_kwargs.items() if v is not _EXCLUDE} - - # Run feature extractor (copy inputs since some FEs mutate the list in-place) - fe_output = fe([x.copy() for x in audio_batch], **fe_kwargs) - - # Run audio processor - ap_output = ap([x.copy() for x in audio_batch], **ap_kwargs) - - fe_to_ap_key_map = { - "input_features": "audio_features", - "input_values": "audio_values", - "audio_input_features": "audio_features", - } - - # Mapping for attention mask and padding mask keys depending on the primary input key - mask_key_map = { - "input_values": "audio_values_mask", - "input_features": "audio_features_mask", - } - - # Find out if this output contains input_values or input_features (to key mask 
mapping) - has_input_values = "input_values" in fe_output - has_input_features = "input_features" in fe_output - - for fe_key in fe_output.keys(): - # Remap the primary data keys - ap_key = fe_to_ap_key_map.get(fe_key, fe_key) - - # Special handling for attention_mask and padding_mask mapping - if fe_key in ("attention_mask", "padding_mask"): - if has_input_values: - ap_key = mask_key_map["input_values"] - elif has_input_features: - ap_key = mask_key_map["input_features"] - else: - ap_key = fe_key # fallback/default - - assert ap_key in ap_output, f"Key {ap_key} (from FE key {fe_key}) not found in audio processor output" - fe_tensor = fe_output[fe_key] - ap_tensor = ap_output[ap_key] - - if not isinstance(fe_tensor, torch.Tensor): - fe_tensor = torch.tensor(fe_tensor) - if not isinstance(ap_tensor, torch.Tensor): - ap_tensor = torch.tensor(ap_tensor) - - assert fe_tensor.shape == ap_tensor.shape, ( - f"Shape mismatch for key '{fe_key}' (ap key '{ap_key}'): fe {fe_tensor.shape} vs ap {ap_tensor.shape}" - ) - atol = config.get("atol", 0.0) - if atol > 0: - assert torch.allclose(fe_tensor, ap_tensor, atol=atol, rtol=0), ( - f"Value mismatch for key '{fe_key}' (ap key '{ap_key}'): max abs diff = {(fe_tensor - ap_tensor).abs().max().item():.6e}, atol={atol}" - ) - else: - assert torch.equal(fe_tensor, ap_tensor), ( - f"Value mismatch for key '{fe_key}' (ap key '{ap_key}'): max abs diff = {(fe_tensor - ap_tensor).abs().max().item():.6e}" - ) - - -# --------------------------------------------------------------------------- -# Backward compatibility tests -# --------------------------------------------------------------------------- - -# Pairs of (fe_module_path, fe_class_name, ap_class) -_COMPAT_PAIRS = [ - ("transformers.models.whisper.feature_extraction_whisper", "WhisperFeatureExtractor", WhisperAudioProcessor), - ("transformers.models.clap.feature_extraction_clap", "ClapFeatureExtractor", ClapAudioProcessor), - ("transformers.models.encodec.feature_extraction_encodec", "EncodecFeatureExtractor", EncodecAudioProcessor), - ("transformers.models.dac.feature_extraction_dac", "DacFeatureExtractor", DacAudioProcessor), - ("transformers.models.wav2vec2.feature_extraction_wav2vec2", "Wav2Vec2FeatureExtractor", Wav2Vec2AudioProcessor), -] - - -@pytest.mark.parametrize( - "module_path, class_name, ap_class", - _COMPAT_PAIRS, - ids=[p[1] for p in _COMPAT_PAIRS], -) -class TestFeatureExtractorBackwardCompat: - """Tests that deprecated FeatureExtractor wrappers work correctly.""" - - def test_importable_and_warns(self, module_path, class_name, ap_class): - """Old class names are importable and emit FutureWarning.""" - import importlib - - mod = importlib.import_module(module_path) - fe_cls = getattr(mod, class_name) - with pytest.warns(FutureWarning, match="deprecated"): - fe_cls() - - def test_isinstance_check(self, module_path, class_name, ap_class): - """Deprecated FE instances pass isinstance checks against AudioProcessor.""" - import importlib - import warnings - - mod = importlib.import_module(module_path) - fe_cls = getattr(mod, class_name) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - fe = fe_cls() - assert isinstance(fe, ap_class) - assert issubclass(fe_cls, ap_class) - - def test_issubclass(self, module_path, class_name, ap_class): - """Deprecated FE class is a subclass of AudioProcessor.""" - import importlib - - mod = importlib.import_module(module_path) - fe_cls = getattr(mod, class_name) - assert issubclass(fe_cls, ap_class) - - -class 
TestBatchFeatureLegacyKeys: - """Tests that old output key names are accessible via BatchFeature.""" - - def setup_method(self): - from transformers.audio_processing_base import BatchFeature as AudioBatchFeature - - # Reset warned keys so each test gets fresh warnings - AudioBatchFeature._warned_keys.clear() - - def test_input_features_resolves_to_audio_features(self): - from transformers.audio_processing_base import BatchFeature as AudioBatchFeature - - bf = AudioBatchFeature({"audio_features": np.array([1, 2, 3])}) - with pytest.warns(FutureWarning, match="input_features"): - result = bf["input_features"] - assert np.array_equal(result, np.array([1, 2, 3])) - - def test_input_values_resolves_to_audio_values(self): - from transformers.audio_processing_base import BatchFeature as AudioBatchFeature - - bf = AudioBatchFeature({"audio_values": np.array([4, 5, 6])}) - with pytest.warns(FutureWarning, match="input_values"): - result = bf["input_values"] - assert np.array_equal(result, np.array([4, 5, 6])) - - def test_attention_mask_resolves_to_audio_features_mask(self): - from transformers.audio_processing_base import BatchFeature as AudioBatchFeature - - bf = AudioBatchFeature({"audio_features": np.array([1]), "audio_features_mask": np.array([1, 1, 0])}) - with pytest.warns(FutureWarning, match="attention_mask"): - result = bf["attention_mask"] - assert np.array_equal(result, np.array([1, 1, 0])) - - def test_attention_mask_resolves_to_audio_values_mask(self): - from transformers.audio_processing_base import BatchFeature as AudioBatchFeature - - bf = AudioBatchFeature({"audio_values": np.array([1]), "audio_values_mask": np.array([0, 1, 1])}) - with pytest.warns(FutureWarning, match="attention_mask"): - result = bf["attention_mask"] - assert np.array_equal(result, np.array([0, 1, 1])) - - def test_contains_legacy_key(self): - from transformers.audio_processing_base import BatchFeature as AudioBatchFeature - - bf = AudioBatchFeature({"audio_features": np.array([1])}) - assert "input_features" in bf - assert "audio_features" in bf - assert "nonexistent_key" not in bf - - def test_warning_fires_once(self): - from transformers.audio_processing_base import BatchFeature as AudioBatchFeature - - bf = AudioBatchFeature({"audio_features": np.array([1, 2, 3])}) - with pytest.warns(FutureWarning, match="input_features"): - bf["input_features"] - # Second access should not warn - import warnings - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - bf["input_features"] - future_warnings = [x for x in w if issubclass(x.category, FutureWarning) and "input_features" in str(x.message)] - assert len(future_warnings) == 0 From 5b5aa0a265cf4d2215d796cf48072897fc0b38c3 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 17:46:19 +0100 Subject: [PATCH 17/28] add computation_dtype to have matching torch/ numpy implems --- src/transformers/audio_processing_backends.py | 2 ++ src/transformers/audio_utils.py | 1 + .../whisper/audio_processing_whisper.py | 25 ++----------------- 3 files changed, 5 insertions(+), 23 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 1107dde05e9f..af48806547ac 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -359,6 +359,7 @@ def _normalize_magnitude( def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = 
spectrogram_config.mel_scale_config + computation_dtype = getattr(torch, mel_cfg.computation_dtype) if mel_cfg.computation_dtype else None return _torch_spec.mel_filter_bank_torch( num_frequency_bins=1 + stft_cfg.n_fft // 2, num_mel_filters=mel_cfg.n_mels, @@ -369,6 +370,7 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): mel_scale=mel_cfg.mel_scale, triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, frequency_bin_mode=mel_cfg.frequency_bin_mode, + computation_dtype=computation_dtype, ) def _to_batch(self, audio): diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index d26ad649a0c3..2cc3d71b37f4 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -103,6 +103,7 @@ class MelScaleConfig: norm: str | None = None triangularize_in_mel_space: bool = False frequency_bin_mode: str = "rfft" + computation_dtype: str | None = None def to_dict(self) -> dict: return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index f2850bc10cb8..63bc1058fc3e 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -14,8 +14,6 @@ import torch -from spectrograms import numpy_mel_spectrogram as _np_spec - from ...audio_processing_backends import TorchAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig @@ -36,6 +34,7 @@ class WhisperAudioProcessor(TorchAudioBackend): n_mels=80, mel_scale="slaney", norm="slaney", + computation_dtype="float64", ), log_mode="log10", ) @@ -50,28 +49,8 @@ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): return features - def _mel_filter_bank(self, spectrogram_config): - stft_cfg = spectrogram_config.stft_config - mel_cfg = spectrogram_config.mel_scale_config - mel_filters_np = _np_spec.mel_filter_bank( - num_frequency_bins=1 + stft_cfg.n_fft // 2, - num_mel_filters=mel_cfg.n_mels, - min_frequency=mel_cfg.f_min, - max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, - sampling_rate=self.sample_rate, - norm=mel_cfg.norm, - mel_scale=mel_cfg.mel_scale, - triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, - ) - return torch.from_numpy(mel_filters_np).to(torch.float32) - def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): - """ - Override to use the same matrix multiplication order as WhisperFeatureExtractor - for exact numerical compatibility. FeatureExtractor uses (n_mels, n_freq) @ (n_freq, time), - while the generic spectrograms module uses (time, n_freq) @ (n_freq, n_mels) then transpose. - The different summation order produces slightly different rounding (1 ULP). - """ + # Override to match WhisperFeatureExtractor's mel transformation order for numerical compatibility. 
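+        # As the removed docstring noted: the FE applies the filters as (n_mels, n_freq) @ (n_freq, time),
+        # while the generic spectrograms path computes (time, n_freq) @ (n_freq, n_mels) and transposes;
+        # the different summation order shifts results by roughly 1 ULP, hence this ordering override.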
stacked = torch.stack(features) if isinstance(features, list) else features mel_spec = torch.matmul(self.mel_filters.T, stacked) return torch.clamp(mel_spec, min=spectrogram_config.mel_floor) From 14c9cae2ad5898bdc6b77481a6ae6ba4825535b0 Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 18:31:55 +0100 Subject: [PATCH 18/28] use 5.5 for deprecation --- src/transformers/utils/deprecation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/deprecation.py b/src/transformers/utils/deprecation.py index 98af20e5df77..9b44e549df1b 100644 --- a/src/transformers/utils/deprecation.py +++ b/src/transformers/utils/deprecation.py @@ -33,7 +33,7 @@ class Action(ExplicitEnum): RAISE = "raise" -def deprecated_feature_extractor(audio_processor_class, old_class_name, version="4.55"): +def deprecated_feature_extractor(audio_processor_class, old_class_name, version="5.5"): """Create a deprecated FeatureExtractor alias for an AudioProcessor. Uses dynamic class creation to reduce boilerplate across ~20 models. From b35918e864196cda2dc4846bf272a00ba941906b Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 18:36:31 +0100 Subject: [PATCH 19/28] lasr update --- src/transformers/audio_processing_backends.py | 24 +++-- src/transformers/audio_utils.py | 3 + .../models/lasr/audio_processing_lasr.py | 93 ++++++------------- 3 files changed, 45 insertions(+), 75 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index af48806547ac..6aff5c9832a1 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -262,14 +262,11 @@ def _extract_spectrogram( """Compute the (power) spectrogram via STFT using the torch backend.""" stft_cfg = spectrogram_config.stft_config - - # if spectrogram_config.preemphasis is not None: - # audio_ranges = kwargs.get("audio_ranges", None) - # if audio_ranges is not None: - # device = waveform.device - # timemask = torch.arange(waveform.shape[1], device=device).unsqueeze(0) - # timemask = timemask < audio_ranges.unsqueeze(1) - # waveform = waveform.masked_fill(~timemask, 0.0) + computation_dtype = ( + getattr(torch, spectrogram_config.computation_dtype) + if spectrogram_config.computation_dtype + else None + ) magnitudes = _torch_spec._extract_spectrogram( audio, @@ -287,6 +284,8 @@ def _extract_spectrogram( periodic=stft_cfg.periodic, preemphasis=spectrogram_config.preemphasis, remove_dc_offset=spectrogram_config.remove_dc_offset, + computation_dtype=computation_dtype, + left_align_fft=stft_cfg.left_align_fft, ) return magnitudes @@ -360,7 +359,7 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config computation_dtype = getattr(torch, mel_cfg.computation_dtype) if mel_cfg.computation_dtype else None - return _torch_spec.mel_filter_bank_torch( + mel_filters = _torch_spec.mel_filter_bank_torch( num_frequency_bins=1 + stft_cfg.n_fft // 2, num_mel_filters=mel_cfg.n_mels, min_frequency=mel_cfg.f_min, @@ -371,7 +370,14 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, frequency_bin_mode=mel_cfg.frequency_bin_mode, computation_dtype=computation_dtype, + bands_to_zero=mel_cfg.bands_to_zero, ) + # When computation_dtype is set only on the mel config (not on the + # spectrogram config), the filters were computed in high 
precision for + # accuracy but the spectrogram will be in the default dtype — cast back. + if computation_dtype is not None and not spectrogram_config.computation_dtype: + mel_filters = mel_filters.to(torch.get_default_dtype()) + return mel_filters def _to_batch(self, audio): return torch.stack(audio) diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 2cc3d71b37f4..7f1d1ae4eefb 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -78,6 +78,7 @@ class StftConfig: onesided: bool | None = None pad: int = 0 periodic: bool = True + left_align_fft: bool = False def to_dict(self) -> dict: return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} @@ -104,6 +105,7 @@ class MelScaleConfig: triangularize_in_mel_space: bool = False frequency_bin_mode: str = "rfft" computation_dtype: str | None = None + bands_to_zero: int = 0 def to_dict(self) -> dict: return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} @@ -127,6 +129,7 @@ class SpectrogramConfig: remove_dc_offset: bool = False mel_floor: float = 1e-10 waveform_scale: float | None = None + computation_dtype: str | None = None def __getitem__(self, key): if hasattr(self, key): diff --git a/src/transformers/models/lasr/audio_processing_lasr.py b/src/transformers/models/lasr/audio_processing_lasr.py index 400e82d46829..a795fbbb5325 100644 --- a/src/transformers/models/lasr/audio_processing_lasr.py +++ b/src/transformers/models/lasr/audio_processing_lasr.py @@ -12,34 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +import torch from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, hertz_to_mel - - -def _linear_to_mel_weight_matrix(num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, upper_edge_hertz): - """Kaldi-style mel weight matrix matching the LASR FE implementation.""" - internal_dtype = np.float64 - bands_to_zero = 1 - nyquist_hertz = sample_rate / 2.0 - linear_frequencies = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins, dtype=internal_dtype)[bands_to_zero:] - spectrogram_bins_mel = hertz_to_mel(linear_frequencies, mel_scale="kaldi")[:, np.newaxis] - - edges = np.linspace( - hertz_to_mel(lower_edge_hertz, mel_scale="kaldi"), - hertz_to_mel(upper_edge_hertz, mel_scale="kaldi"), - num_mel_bins + 2, - dtype=internal_dtype, - ) - lower_edge_mel = edges[:-2][np.newaxis, :] - center_mel = edges[1:-1][np.newaxis, :] - upper_edge_mel = edges[2:][np.newaxis, :] - - lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (center_mel - lower_edge_mel) - upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (upper_edge_mel - center_mel) - mel_weights = np.maximum(0.0, np.minimum(lower_slopes, upper_slopes)) - return np.pad(mel_weights, [[bands_to_zero, 0], [0, 0]]).astype(np.float64) +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig class LasrAudioProcessor(TorchAudioBackend): @@ -47,54 +23,39 @@ class LasrAudioProcessor(TorchAudioBackend): force_mono = True add_channel_dim = True spectrogram_config = SpectrogramConfig( - stft_config=StftConfig(n_fft=512, hop_length=160, win_length=400, power=2.0), - mel_scale_config=MelScaleConfig(n_mels=128, f_min=125.0, f_max=7500.0, mel_scale="kaldi"), + stft_config=StftConfig( + n_fft=512, + hop_length=160, + win_length=400, + power=2.0, + center=False, + 
periodic=False, + left_align_fft=True, + ), + mel_scale_config=MelScaleConfig( + n_mels=128, + f_min=125.0, + f_max=7500.0, + mel_scale="kaldi", + triangularize_in_mel_space=True, + bands_to_zero=1, + computation_dtype="float64", + ), log_mode="log", + mel_floor=1e-5, + computation_dtype="float64", ) - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.mel_filters = _linear_to_mel_weight_matrix( - num_mel_bins=128, - num_spectrogram_bins=512 // 2 + 1, - sample_rate=self.sample_rate, - lower_edge_hertz=125.0, - upper_edge_hertz=7500.0, - ) + def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): + # LASR uses (time, freq) @ (freq, mels) -> (time, mels) ordering, + # matching the upstream FE's unfold-based output layout. + mel_spec = torch.matmul(features.transpose(-2, -1), self.mel_filters.to(device=features.device, dtype=features.dtype)) + return torch.clamp(mel_spec, min=spectrogram_config.mel_floor) def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): stft_cfg = spectrogram_config.stft_config win_length = stft_cfg.win_length or stft_cfg.n_fft return (audio_lengths - win_length) // stft_cfg.hop_length + 1 - def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): - import torch - - if spectrogram_config is None: - spectrogram_config = self.spectrogram_config - - stft_cfg = spectrogram_config.stft_config - n_fft = stft_cfg.n_fft - hop_length = stft_cfg.hop_length - win_length = stft_cfg.win_length or n_fft - - if isinstance(audio, list): - waveform = torch.stack(audio, dim=0).to(torch.float64) - else: - waveform = audio.to(torch.float64) - - device = waveform.device - - window = torch.hann_window(win_length, periodic=False, device=device, dtype=torch.float64) - frames = waveform.unfold(-1, win_length, hop_length) - stft = torch.fft.rfft(window * frames, n=n_fft) - power_spec = torch.abs(stft) ** 2 - - mel_filters = torch.from_numpy(self.mel_filters).to(device) - mel_spec = torch.clamp(power_spec @ mel_filters, min=1e-5) - mel_spec = torch.log(mel_spec) - - return [mel_spec[i].to(torch.float32) for i in range(mel_spec.shape[0])] - __all__ = ["LasrAudioProcessor"] From 7ded1aab81d064b9d98813d581bd75947843503c Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 19:16:53 +0100 Subject: [PATCH 20/28] gemma3n update --- .../gemma3n/audio_processing_gemma3n.py | 83 +++++-------------- 1 file changed, 21 insertions(+), 62 deletions(-) diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py index 58a41a9f4c13..ab9ea1317d8d 100644 --- a/src/transformers/models/gemma3n/audio_processing_gemma3n.py +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -12,32 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math - import numpy as np from ...audio_processing_backends import NumpyAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig -from ...feature_extraction_utils import BatchFeature - - -def _create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate, fft_length, norm=None): - """HTK-style mel filterbank matrix matching Gemma3n FE implementation.""" - all_freqs = np.arange(n_freqs, dtype=np.float32) * (sample_rate / fft_length) - m_min = 2595.0 * math.log10(1.0 + (f_min / 700.0)) - m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0)) - m_pts = np.linspace(m_min, m_max, n_mels + 2) - f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0) - f_diff = f_pts[1:] - f_pts[:-1] - slopes = np.expand_dims(f_pts, 0) - np.expand_dims(all_freqs, 1) - zero = np.zeros(1, dtype=np.float32) - down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] - up_slopes = slopes[:, 2:] / f_diff[1:] - fb = np.maximum(zero, np.minimum(down_slopes, up_slopes)) - if norm == "slaney": - enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) - fb *= np.expand_dims(enorm, 0) - return fb def _unfold(array, dimension, size, step): @@ -100,32 +78,6 @@ def __init__(self, per_bin_mean=None, per_bin_stddev=None, **kwargs): else: self.per_bin_stddev = None - def _mel_filter_bank(self, spectrogram_config): - """Custom HTK-style mel filterbank matching the original Gemma3n FE.""" - sc = spectrogram_config - msc = sc.mel_scale_config - return _create_fb_matrix( - n_freqs=sc.stft_config.n_fft // 2 + 1, - f_min=msc.f_min, - f_max=msc.f_max, - n_mels=msc.n_mels, - sample_rate=self.sample_rate, - fft_length=sc.stft_config.n_fft, - ) - - def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): - if spectrogram_config is None: - spectrogram_config = self.spectrogram_config - - # Process all waveforms at once (bypass base class per-element iteration) - if not isinstance(audio, list): - audio = [audio] - - features = self._extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) - features = self._apply_mel_scale(features, spectrogram_config=spectrogram_config, **kwargs) - # Skip _normalize_magnitude: _apply_mel_scale already applies log + per-bin normalization - return features - def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): stft_cfg = spectrogram_config.stft_config preemphasis = spectrogram_config.preemphasis @@ -147,22 +99,29 @@ def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): frames = frames * self.window # Broadcasting window stft = np.fft.rfft(frames, n=stft_cfg.n_fft, axis=-1) - magnitude_spec = np.abs(stft) + return np.abs(stft) def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): - """Apply mel filterbank, log compression, and per-bin normalization.""" - result = [] - for mag_spec in features: - mel_spec = np.matmul(mag_spec, self.mel_filters) - log_mel_spec = np.log(np.maximum(mel_spec, spectrogram_config.mel_floor)) - - if self.per_bin_mean is not None: - log_mel_spec = log_mel_spec - self.per_bin_mean - if self.per_bin_stddev is not None: - log_mel_spec = log_mel_spec / self.per_bin_stddev - - result.append(log_mel_spec.astype(np.float32)) - return result + """Apply mel filterbank. 
Features are in (batch, time, freq) format.""" + mel_spec = np.matmul(features, self.mel_filters) + return np.maximum(spectrogram_config.mel_floor, mel_spec) + + def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): + """Apply log compression and per-bin normalization.""" + result = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) + + if self.per_bin_mean is not None: + result = result - self.per_bin_mean + if self.per_bin_stddev is not None: + result = result / self.per_bin_stddev + + return result.astype(np.float32) + + def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): + """Frame count for unfold-based STFT (no centering).""" + hop_length = spectrogram_config.stft_config.hop_length + frame_size = spectrogram_config.stft_config.win_length + 1 + return (audio_lengths - frame_size) // hop_length + 1 __all__ = ["Gemma3nAudioProcessor"] From 2d5b8b36055745be4d20068d5f500069cd9de12b Mon Sep 17 00:00:00 2001 From: Eustache Le Bihan Date: Thu, 12 Mar 2026 19:28:14 +0100 Subject: [PATCH 21/28] temporarily use separate backends files --- src/transformers/torch_mel_spectrogram.py | 575 ++++++++++++++++++++++ 1 file changed, 575 insertions(+) create mode 100644 src/transformers/torch_mel_spectrogram.py diff --git a/src/transformers/torch_mel_spectrogram.py b/src/transformers/torch_mel_spectrogram.py new file mode 100644 index 000000000000..8a59d0b61fd2 --- /dev/null +++ b/src/transformers/torch_mel_spectrogram.py @@ -0,0 +1,575 @@ +"""PyTorch implementation of mel spectrogram computation.""" + +import math + +import torch + + +# --- Frequency conversion utilities --- + +def _hertz_to_mel_scalar(freq: float, mel_scale: str = "htk") -> float: + """Convert a single Hz value to mel using Python math (float64).""" + if mel_scale == "htk": + return 2595.0 * math.log10(1.0 + freq / 700.0) + elif mel_scale == "kaldi": + return 1127.0 * math.log(1.0 + freq / 700.0) + # slaney + f_sp = 200.0 / 3 + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - 0.0) / f_sp + logstep = math.log(6.4) / 27.0 + if freq >= min_log_hz: + return min_log_mel + math.log(freq / min_log_hz) / logstep + return (freq - 0.0) / f_sp + + +def hertz_to_mel(freq: torch.Tensor, mel_scale: str = "htk") -> torch.Tensor: + if mel_scale == "htk": + return 2595.0 * torch.log10(1.0 + freq / 700.0) + elif mel_scale == "kaldi": + return 1127.0 * torch.log(1.0 + freq / 700.0) + # slaney + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = 27.0 / torch.log(torch.tensor(6.4)) + mels = 3.0 * freq / 200.0 + log_region = freq >= min_log_hertz + mels[log_region] = min_log_mel + torch.log(freq[log_region] / min_log_hertz) * logstep + return mels + + +def mel_to_hertz(mels: torch.Tensor, mel_scale: str = "htk") -> torch.Tensor: + if mel_scale == "htk": + return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + elif mel_scale == "kaldi": + return 700.0 * (torch.exp(mels / 1127.0) - 1.0) + # slaney + f_sp = 200.0 / 3 + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - 0.0) / f_sp + logstep = math.log(6.4) / 27.0 + freq = 0.0 + f_sp * mels + log_region = mels >= min_log_mel + freq[log_region] = min_log_hz * torch.exp(logstep * (mels[log_region] - min_log_mel)) + return freq + + +def _create_triangular_filter_bank( + fft_freqs: torch.Tensor, filter_freqs: torch.Tensor +) -> torch.Tensor: + filter_diff = filter_freqs[1:] - filter_freqs[:-1] + slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes 
= slopes[:, 2:] / filter_diff[1:] + return torch.clamp(torch.minimum(down_slopes, up_slopes), min=0) + + +def _kaldi_mel_filter_bank( + num_frequency_bins: int, + num_mel_filters: int, + min_frequency: float, + max_frequency: float, + sampling_rate: int, +) -> torch.Tensor: + """Compute mel filter bank matching kaldi's exact construction. + + Replicates torchaudio.compliance.kaldi.get_mel_banks exactly: + - Uses 1127*ln mel scale (not 2595*log10) + - Computes mel points via mel_low + i * delta (not torch.linspace) + - Uses n_fft/2 FFT bins (excludes Nyquist), then pads with zero column + + Returns: + Tensor of shape (num_frequency_bins, num_mel_filters). + """ + n_fft = (num_frequency_bins - 1) * 2 + num_fft_bins = n_fft // 2 # kaldi excludes Nyquist bin + fft_bin_width = sampling_rate / n_fft + + mel_low = 1127.0 * math.log(1.0 + min_frequency / 700.0) + mel_high = 1127.0 * math.log(1.0 + max_frequency / 700.0) + mel_delta = (mel_high - mel_low) / (num_mel_filters + 1) + + bin_idx = torch.arange(num_mel_filters).unsqueeze(1) + left_mel = mel_low + bin_idx * mel_delta + center_mel = mel_low + (bin_idx + 1.0) * mel_delta + right_mel = mel_low + (bin_idx + 2.0) * mel_delta + + mel = 1127.0 * (1.0 + fft_bin_width * torch.arange(num_fft_bins) / 700.0).log() + mel = mel.unsqueeze(0) + + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + banks = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + + # kaldi pads a zero column for the Nyquist bin + banks = torch.nn.functional.pad(banks, (0, 1), mode="constant", value=0) + + return banks.T # (num_frequency_bins, num_mel_filters) + + +def mel_filter_bank_torch( + num_frequency_bins: int, + num_mel_filters: int, + min_frequency: float, + max_frequency: float, + sampling_rate: int, + norm: str | None = None, + mel_scale: str = "htk", + triangularize_in_mel_space: bool = False, + frequency_bin_mode: str = "rfft", + computation_dtype: "torch.dtype | None" = None, + bands_to_zero: int = 0, +) -> torch.Tensor: + """Compute mel filter bank as a pure PyTorch tensor. + + Matches torchaudio's melscale_fbanks: mel range endpoints are computed in + float64 (Python math), then all tensor work is done in the default dtype + (float32). + + Args: + computation_dtype: If provided, all intermediate tensor operations are + performed in this dtype (e.g. ``torch.float64``), and the result is + cast back to the default dtype. This is useful to obtain results + that are numerically identical to a NumPy (float64) reference + implementation. + bands_to_zero: Number of lowest frequency bins to zero out before + building the filter bank. The zeroed rows are restored (as zeros) + in the output. Set to 1 to exclude the DC bin (HTK / LASR style). + + Returns: + Tensor of shape (num_frequency_bins, num_mel_filters). + """ + if triangularize_in_mel_space and bands_to_zero == 0: + # Kaldi-exact path: matches torchaudio.compliance.kaldi.get_mel_banks. + # Kept for backward compatibility with models that rely on this behaviour + # (AST, SeamlessM4T, Speech2Text, etc.). + return _kaldi_mel_filter_bank( + num_frequency_bins, num_mel_filters, min_frequency, max_frequency, sampling_rate, + ) + + mel_min = _hertz_to_mel_scalar(min_frequency, mel_scale=mel_scale) + mel_max = _hertz_to_mel_scalar(max_frequency, mel_scale=mel_scale) + + n_fft = (num_frequency_bins - 1) * 2 + + if triangularize_in_mel_space: + # Kaldi-style direct slope computation in mel space. 
+ # Uses mel_low + i * delta (not linspace) and direct per-band slopes + # to match the exact numerical behaviour of kaldi/HTK filter banks. + mel_delta = (mel_max - mel_min) / (num_mel_filters + 1) + bin_idx = torch.arange(num_mel_filters, dtype=computation_dtype).unsqueeze(1) + left_mel = mel_min + bin_idx * mel_delta + center_mel = mel_min + (bin_idx + 1.0) * mel_delta + right_mel = mel_min + (bin_idx + 2.0) * mel_delta + + fft_bin_width = sampling_rate / n_fft + num_fft_bins = num_frequency_bins - bands_to_zero + hz_freqs = fft_bin_width * torch.arange(bands_to_zero, num_frequency_bins, dtype=computation_dtype) + mel = hertz_to_mel(hz_freqs, mel_scale=mel_scale).unsqueeze(0) + + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + mel_filters = torch.max(torch.zeros(1, dtype=computation_dtype), torch.min(up_slope, down_slope)) + + # Transpose to (num_fft_bins, num_mel_filters) and restore zeroed bands + mel_filters = mel_filters.T + if bands_to_zero > 0: + mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, bands_to_zero, 0)) + + return mel_filters + + mel_freqs = torch.linspace(mel_min, mel_max, num_mel_filters + 2, dtype=computation_dtype) + + filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale) + if frequency_bin_mode == "rfft": + fft_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate) + else: + fft_freqs = torch.linspace(0, sampling_rate // 2, num_frequency_bins) + if computation_dtype is not None: + fft_freqs = fft_freqs.to(computation_dtype) + + mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) + + if norm == "slaney": + enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]) + mel_filters = mel_filters * enorm.unsqueeze(0) + + if bands_to_zero > 0: + mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, bands_to_zero, 0)) + + return mel_filters + + +def window_function(window_length, name="hann_window", periodic=True, wkwargs=None): + """Create a window tensor using torch window functions.""" + if wkwargs is None: + wkwargs = {} + if name in ["hann", "hann_window"]: + return torch.hann_window(window_length, periodic=periodic, **wkwargs) + elif name in ["hamming", "hamming_window"]: + return torch.hamming_window(window_length, periodic=periodic, **wkwargs) + elif name == "boxcar": + return torch.ones(window_length) + elif name == "povey": + return torch.hann_window(window_length, periodic=periodic, **wkwargs).pow(0.85) + else: + raise ValueError(f"Unknown window function '{name}'") + + +# --- Sub-methods --- + +def _extract_spectrogram( + waveform: torch.Tensor, + sampling_rate: int, + *, + n_fft: int = 400, + win_length: int | None = None, + hop_length: int | None = None, + window_fn: str = "hann_window", + wkwargs: dict | None = None, + power: float = 2.0, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + pad: int = 0, + periodic: bool = True, + dither: float = 0.0, + preemphasis: float | None = None, + remove_dc_offset: bool = False, + computation_dtype: "torch.dtype | None" = None, + left_align_fft: bool = False, +) -> torch.Tensor: + """Compute the (power) spectrogram via STFT. + + Args: + waveform: Input waveform of shape (..., time). + sampling_rate: Sample rate in Hz. + left_align_fft: If True, use manual framing with unfold(win_length) + zero-pad + right + rfft(n_fft). This left-aligns the window in the FFT buffer (kaldi + style), instead of center-padding it (torch.stft default). 
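A minimal sketch of the left-aligned framing this flag selects, assuming a
hann window and illustrative sizes (the default path would instead pad the
window symmetrically to ``n_fft`` and hand it to ``torch.stft``)::

    import torch

    wav = torch.randn(16000)
    win_length, hop_length, n_fft = 400, 160, 512

    window = torch.hann_window(win_length, periodic=False)
    frames = wav.unfold(-1, win_length, hop_length)                              # (num_frames, 400)
    frames = torch.nn.functional.pad(frames * window, (0, n_fft - win_length))   # right-pad to 512
    power_spec = torch.fft.rfft(frames, n=n_fft).abs() ** 2                      # (num_frames, 257)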
+ + Returns: + Power spectrogram of shape (..., freq, time). + """ + if win_length is None: + win_length = n_fft + if hop_length is None: + hop_length = win_length // 2 + if computation_dtype is not None: + waveform = waveform.to(computation_dtype) + device = waveform.device + dtype = waveform.dtype + + needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset or left_align_fft + + window_wkwargs = {**(wkwargs or {}), "dtype": dtype} + window = window_function(win_length, name=window_fn, periodic=periodic, wkwargs=window_wkwargs) + window = window.to(device=device) + + if needs_manual_framing and win_length < n_fft: + frame_length = win_length + else: + if win_length < n_fft: + left_pad = (n_fft - win_length) // 2 + right_pad = n_fft - win_length - left_pad + window = torch.nn.functional.pad(window, (left_pad, right_pad)) + frame_length = n_fft + + fft_length = n_fft + num_frequency_bins = (fft_length // 2) + 1 + + is_1d = waveform.ndim == 1 + if is_1d: + waveform = waveform.unsqueeze(0) + + leading_shape = waveform.shape[:-1] + waveform = waveform.reshape(-1, waveform.shape[-1]) + + if pad > 0: + waveform = torch.nn.functional.pad(waveform, (pad, pad)) + + if needs_manual_framing: + result = _manual_stft( + waveform, window, frame_length, hop_length, fft_length, + num_frequency_bins, power, normalized, center, pad_mode, + dither, preemphasis, remove_dc_offset, + ) + else: + result = _torch_stft( + waveform, window, frame_length, hop_length, fft_length, + power, normalized, center, pad_mode, + ) + + result = result.reshape(*leading_shape, result.shape[-2], result.shape[-1]) + + if is_1d: + result = result.squeeze(0) + + if computation_dtype is not None: + return result + return result.float() + + +def _apply_mel_scale( + spectrogram: torch.Tensor, + mel_filters: torch.Tensor, + mel_floor: float = 1e-10, +) -> torch.Tensor: + """Apply mel filterbank to a spectrogram. + + Args: + spectrogram: Power spectrogram of shape (..., freq, time). + mel_filters: Mel filterbank of shape (freq, n_mels). + mel_floor: Minimum value for clamping. + + Returns: + Mel spectrogram of shape (..., n_mels, time). + """ + # (..., time, freq) @ (freq, n_mels) -> (..., time, n_mels) -> (..., n_mels, time) + mel_spec = torch.matmul(spectrogram.transpose(-2, -1), mel_filters).transpose(-2, -1) + return torch.clamp(mel_spec, min=mel_floor) + + +def _torch_stft( + waveform, window, frame_length, hop_length, fft_length, + power, normalized, center, pad_mode, +): + """Fast path using torch.stft. Returns power spectrogram of shape (batch, freq, time).""" + stft_out = torch.stft( + waveform, + n_fft=fft_length, + hop_length=hop_length, + win_length=frame_length, + window=window, + center=center, + pad_mode=pad_mode, + normalized=False, + return_complex=True, + ) + if normalized: + stft_out = stft_out / window.pow(2.0).sum().sqrt() + return stft_out.abs() ** power + + +def _manual_stft( + waveform, window, frame_length, hop_length, fft_length, + num_frequency_bins, power, normalized, center, pad_mode, + dither, preemphasis, remove_dc_offset, +): + """Manual framing STFT for kaldi-specific features. 
Returns power spectrogram of shape (batch, freq, time).""" + if center: + waveform = torch.nn.functional.pad( + waveform, (frame_length // 2, frame_length // 2), mode=pad_mode + ) + + # Extract all frames at once: (batch, num_frames, frame_length) + frames = waveform.unfold(-1, frame_length, hop_length) + + if dither != 0.0: + frames = frames + dither * torch.randn_like(frames) + + if remove_dc_offset: + frames = frames - frames.mean(dim=-1, keepdim=True) + + if preemphasis is not None: + frames = torch.cat([ + frames[..., :1] * (1 - preemphasis), + frames[..., 1:] - preemphasis * frames[..., :-1], + ], dim=-1) + + frames = frames * window + + # Zero-pad frames to fft_length if frame_length < fft_length (kaldi left-aligns in FFT buffer) + if frame_length < fft_length: + frames = torch.nn.functional.pad(frames, (0, fft_length - frame_length)) + + # Batched FFT: (batch, num_frames, fft_length) -> (batch, num_frames, num_frequency_bins) + spec = torch.fft.rfft(frames, n=fft_length) + + if normalized: + spec = spec / window.pow(2.0).sum().sqrt() + + spec = spec.abs() ** power + + # (batch, num_frames, freq) -> (batch, freq, num_frames) + return spec.transpose(-2, -1) + + +# --- Main function --- + +def mel_spectrogram( + waveform: torch.Tensor, + sampling_rate: int, + *, + n_fft: int = 400, + win_length: int | None = None, + hop_length: int | None = None, + window_fn: str = "hann_window", + wkwargs: dict | None = None, + power: float = 2.0, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + pad: int = 0, + periodic: bool = True, + # mel scale kwargs + n_mels: int = 128, + f_min: float = 0.0, + f_max: float | None = None, + mel_scale: str = "htk", + norm: str | None = None, + triangularize_in_mel_space: bool = False, + # kaldi-specific kwargs + dither: float = 0.0, + preemphasis: float | None = None, + remove_dc_offset: bool = False, + mel_floor: float = 1e-10, +) -> torch.Tensor: + """Compute mel spectrogram using PyTorch. + + Args: + waveform: Input waveform of shape (..., time). + sampling_rate: Sample rate in Hz. + + Returns: + Mel spectrogram of shape (..., n_mels, time). + """ + if f_max is None: + f_max = sampling_rate / 2.0 + + spectrogram = _extract_spectrogram( + waveform, sampling_rate, + n_fft=n_fft, win_length=win_length, hop_length=hop_length, + window_fn=window_fn, wkwargs=wkwargs, power=power, + center=center, pad_mode=pad_mode, normalized=normalized, pad=pad, periodic=periodic, + dither=dither, preemphasis=preemphasis, remove_dc_offset=remove_dc_offset, + ) + + num_frequency_bins = spectrogram.shape[-2] + mel_filters = mel_filter_bank_torch( + num_frequency_bins, n_mels, f_min, f_max, sampling_rate, + norm=norm, mel_scale=mel_scale, + triangularize_in_mel_space=triangularize_in_mel_space, + ).to(spectrogram.device) + + return _apply_mel_scale(spectrogram, mel_filters, mel_floor=mel_floor) + + +class MelSpectrogram(torch.nn.Module): + """Cached mel spectrogram transform — precomputes window and mel filterbank. + + Same API and exact same results as the functional ``mel_spectrogram``, but + avoids recomputing the window and mel filterbank on every call. 
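For reference, a rough call sketch of the functional path above (parameter
values are illustrative, not defaults required by any model)::

    import torch

    wav = torch.randn(16000)  # 1 s of 16 kHz audio
    mel = mel_spectrogram(
        wav, 16000,
        n_fft=512, win_length=400, hop_length=160,
        n_mels=80, mel_scale="kaldi",
        preemphasis=0.97, remove_dc_offset=True,
        center=False, mel_floor=1e-5,
    )
    # preemphasis/remove_dc_offset select the manual-framing path; with
    # center=False this gives (16000 - 400) // 160 + 1 = 98 frames, so
    # `mel` has shape (80, 98).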
+ + Usage:: + + transform = MelSpectrogram(sampling_rate=16000, n_fft=1024, n_mels=80) + transform = transform.cuda() # move buffers to GPU once + mel = transform(waveform) # fast repeated calls + """ + + def __init__( + self, + sampling_rate: int, + *, + n_fft: int = 400, + win_length: int | None = None, + hop_length: int | None = None, + window_fn: str = "hann_window", + wkwargs: dict | None = None, + power: float = 2.0, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + pad: int = 0, + periodic: bool = True, + n_mels: int = 128, + f_min: float = 0.0, + f_max: float | None = None, + mel_scale: str = "htk", + norm: str | None = None, + triangularize_in_mel_space: bool = False, + dither: float = 0.0, + preemphasis: float | None = None, + remove_dc_offset: bool = False, + mel_floor: float = 1e-10, + ): + super().__init__() + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + self.power = power + self.center = center + self.pad_mode = pad_mode + self.normalized = normalized + self.pad = pad + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max if f_max is not None else sampling_rate / 2.0 + self.mel_floor = mel_floor + self.dither = dither + self.preemphasis = preemphasis + self.remove_dc_offset = remove_dc_offset + + self._needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset + + # Build window + window = window_function(self.win_length, name=window_fn, periodic=periodic, wkwargs=wkwargs) + if self._needs_manual_framing and self.win_length < n_fft: + self._frame_length = self.win_length + else: + if self.win_length < n_fft: + left_pad = (n_fft - self.win_length) // 2 + right_pad = n_fft - self.win_length - left_pad + window = torch.nn.functional.pad(window, (left_pad, right_pad)) + self._frame_length = n_fft + self.register_buffer("window", window) + + # Build mel filterbank + num_frequency_bins = n_fft // 2 + 1 + mel_fb = mel_filter_bank_torch( + num_frequency_bins, n_mels, self.f_min, self.f_max, sampling_rate, + norm=norm, mel_scale=mel_scale, + triangularize_in_mel_space=triangularize_in_mel_space, + ) + self.register_buffer("mel_filters", mel_fb) + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + """Compute mel spectrogram. + + Args: + waveform: Input of shape (..., time). + + Returns: + Mel spectrogram of shape (..., n_mels, time). 
+ """ + is_1d = waveform.ndim == 1 + if is_1d: + waveform = waveform.unsqueeze(0) + + leading_shape = waveform.shape[:-1] + waveform = waveform.reshape(-1, waveform.shape[-1]) + + if self.pad > 0: + waveform = torch.nn.functional.pad(waveform, (self.pad, self.pad)) + + if self._needs_manual_framing: + spec = _manual_stft( + waveform, self.window, self._frame_length, self.hop_length, + self.n_fft, self.n_fft // 2 + 1, self.power, self.normalized, + self.center, self.pad_mode, self.dither, self.preemphasis, + self.remove_dc_offset, + ) + else: + spec = _torch_stft( + waveform, self.window, self._frame_length, self.hop_length, + self.n_fft, self.power, self.normalized, self.center, self.pad_mode, + ) + + spec = spec.reshape(*leading_shape, spec.shape[-2], spec.shape[-1]) + if is_1d: + spec = spec.squeeze(0) + spec = spec.float() + + return _apply_mel_scale(spec, self.mel_filters, mel_floor=self.mel_floor) From c7199451a8975482ea3f81a82549107c9b943911 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Mon, 23 Mar 2026 12:18:38 +0100 Subject: [PATCH 22/28] all tests passing --- src/transformers/audio_processing_backends.py | 130 ++++-- src/transformers/audio_processing_utils.py | 71 ++- src/transformers/audio_utils.py | 11 +- .../models/clap/audio_processing_clap.py | 8 +- .../gemma3n/audio_processing_gemma3n.py | 29 +- .../models/lasr/audio_processing_lasr.py | 1 - .../parakeet/audio_processing_parakeet.py | 102 +++-- .../audio_processing_phi4_multimodal.py | 2 + .../audio_processing_seamless_m4t.py | 40 +- .../audio_processing_speech_to_text.py | 37 +- ...processing_vibevoice_acoustic_tokenizer.py | 6 +- src/transformers/numpy_mel_spectrogram.py | 413 ++++++++++++++++++ src/transformers/torch_mel_spectrogram.py | 187 +++----- 13 files changed, 748 insertions(+), 289 deletions(-) create mode 100644 src/transformers/numpy_mel_spectrogram.py diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 6aff5c9832a1..90b9b66b4c95 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -14,11 +14,9 @@ # limitations under the License. -import sys -from pathlib import Path - import numpy as np +from . import numpy_mel_spectrogram as _np_spec from .audio_processing_utils import BaseAudioProcessor from .audio_utils import SpectrogramConfig, amplitude_to_db, power_to_db from .feature_extraction_utils import BatchFeature @@ -27,16 +25,10 @@ logger = logging.get_logger(__name__) -_WORKSPACE_ROOT = str(Path(__file__).resolve().parents[3]) -if _WORKSPACE_ROOT not in sys.path: - sys.path.insert(0, _WORKSPACE_ROOT) - -from spectrograms import numpy_mel_spectrogram as _np_spec - if is_torch_available(): import torch - from spectrograms import torch_mel_spectrogram as _torch_spec + from . 
import torch_mel_spectrogram as _torch_spec class NumpyAudioBackend(BaseAudioProcessor): @@ -88,7 +80,7 @@ def _pad_single(self, audio: np.ndarray, max_length: int) -> np.ndarray: return np.pad(audio, pad_width, mode="constant", constant_values=self.padding_value) - def _extract_spectrogram( + def _stft( self, audio: list[np.ndarray], *, @@ -97,23 +89,29 @@ def _extract_spectrogram( ) -> list[np.ndarray]: """Compute the (power) spectrogram via STFT using the numpy backend.""" stft_cfg = spectrogram_config.stft_config + n_fft = stft_cfg.n_fft + win_length = stft_cfg.win_length or n_fft + hop_length = stft_cfg.hop_length or win_length // 2 - return _np_spec._extract_spectrogram( - audio, - self.sample_rate, - n_fft=stft_cfg.n_fft, - win_length=stft_cfg.win_length, - hop_length=stft_cfg.hop_length, - window_fn=stft_cfg.window_fn, - power=stft_cfg.power, - center=stft_cfg.center, - pad_mode=stft_cfg.pad_mode, - normalized=stft_cfg.normalized, - pad=stft_cfg.pad, - periodic=stft_cfg.periodic, - preemphasis=spectrogram_config.preemphasis, - remove_dc_offset=spectrogram_config.remove_dc_offset, - ) + window = _np_spec.window_function(win_length, name=stft_cfg.window_fn, periodic=stft_cfg.periodic) + needs_manual_framing = (spectrogram_config.preemphasis is not None) or spectrogram_config.remove_dc_offset + window, frame_length = _np_spec._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + + frames, num_frames = _np_spec._frame_waveform(audio, frame_length, hop_length, n_fft, stft_cfg.center, stft_cfg.pad_mode) + compute_dtype = np.result_type(audio.dtype, window.dtype) + frames = frames.astype(compute_dtype, copy=False) + + frames = self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs) + + return _np_spec._windowed_fft(frames, window, n_fft, stft_cfg.power, stft_cfg.normalized) + + def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): + """Apply per-frame signal conditioning using the numpy backend.""" + return _np_spec._apply_frame_processing( + frames, + preemphasis=spectrogram_config.preemphasis, + remove_dc_offset=spectrogram_config.remove_dc_offset, + ) def _apply_mel_scale( self, @@ -252,7 +250,23 @@ def _pad_single(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": return F.pad(audio, pad_args, "constant", self.padding_value) - def _extract_spectrogram( + def _needs_manual_framing(self, spectrogram_config): + """Whether the STFT requires manual framing (unfold-based) instead of torch.stft. + + Manual framing is needed when per-frame processing must happen between + frame extraction and windowing (e.g. per-frame preemphasis, DC offset removal, + or left-aligned FFT padding). + + Override in model-specific processors that handle preemphasis at the + waveform level (in ``_pre_stft``) and don't need per-frame processing. 
+ """ + return ( + (spectrogram_config.preemphasis is not None) + or spectrogram_config.remove_dc_offset + or spectrogram_config.stft_config.left_align_fft + ) + + def _stft( self, audio: list["torch.Tensor"], # TODO: this can be either a audio or batch of audio and this should be documented *, @@ -260,7 +274,6 @@ def _extract_spectrogram( **kwargs, ) -> list["torch.Tensor"]: """Compute the (power) spectrogram via STFT using the torch backend.""" - stft_cfg = spectrogram_config.stft_config computation_dtype = ( getattr(torch, spectrogram_config.computation_dtype) @@ -268,28 +281,51 @@ def _extract_spectrogram( else None ) - magnitudes = _torch_spec._extract_spectrogram( - audio, - self.sample_rate, - n_fft=stft_cfg.n_fft, - win_length=stft_cfg.win_length, - hop_length=stft_cfg.hop_length, - window_fn=stft_cfg.window_fn, - wkwargs=stft_cfg.wkwargs, - power=stft_cfg.power, - center=stft_cfg.center, - pad_mode=stft_cfg.pad_mode, - normalized=stft_cfg.normalized, - pad=stft_cfg.pad, - periodic=stft_cfg.periodic, + n_fft = stft_cfg.n_fft + win_length = stft_cfg.win_length or n_fft + hop_length = stft_cfg.hop_length or win_length // 2 + + if computation_dtype is not None: + audio = audio.to(computation_dtype) + + needs_manual_framing = self._needs_manual_framing(spectrogram_config) + + window_wkwargs = {**(stft_cfg.wkwargs or {}), "dtype": audio.dtype} + window = _torch_spec.window_function(win_length, name=stft_cfg.window_fn, periodic=stft_cfg.periodic, wkwargs=window_wkwargs) + window = window.to(device=audio.device) + window, frame_length = _torch_spec._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + + if needs_manual_framing: + apply_fp = lambda frames: self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs) + magnitudes = _torch_spec._manual_stft( + audio, window, frame_length, hop_length, n_fft, + n_fft // 2 + 1, stft_cfg.power, stft_cfg.normalized, + stft_cfg.center, stft_cfg.pad_mode, + apply_frame_processing=apply_fp, + ) + else: + stft_out = _torch_spec._torch_stft( + audio, window, frame_length, hop_length, n_fft, + stft_cfg.normalized, stft_cfg.center, stft_cfg.pad_mode, + ) + magnitudes = self._compute_magnitudes(stft_out, stft_cfg.power) + + if computation_dtype is not None: + return magnitudes + return magnitudes.float() + + def _compute_magnitudes(self, stft_out, power): + """Convert complex STFT output to a real-valued magnitude spectrogram.""" + return stft_out.abs() ** power + + def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): + """Apply per-frame signal conditioning using the torch backend.""" + return _torch_spec._apply_frame_processing( + frames, preemphasis=spectrogram_config.preemphasis, remove_dc_offset=spectrogram_config.remove_dc_offset, - computation_dtype=computation_dtype, - left_align_fft=stft_cfg.left_align_fft, ) - return magnitudes - def _apply_mel_scale( self, features: list["torch.Tensor"], diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 21dfdfe15223..94eb292de95c 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -157,6 +157,7 @@ def _preprocess( # pad and truncate audio, audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) padded_length = audio[0].shape[-1] + self._audio_lengths = [end - start for start, end in audio_ranges] if do_extract_spectrogram: audio = self._to_batch(audio) if do_batch_spectrogram else audio @@ -323,30 
+324,76 @@ def extract_spectrogram(self, audio, *, spectrogram_config: SpectrogramConfig | return features - def _extract_spectrogram(self, *args, **kwargs): + # ── Spectrogram extraction pipeline ────────────────────────────────── + # + # The full feature-extraction pipeline executed by `extract_spectrogram`: + # + # 1. _extract_spectrogram (STFT → power/magnitude spectrogram) + # a. _pre_stft – waveform-level pre-processing (hook, no-op by default) + # b. _prepare_window_and_framing – build/pad window, decide frame length + # c. _frame_waveform – slice waveform into overlapping frames + # d. _apply_frame_processing – per-frame conditioning: dither, DC offset, preemphasis (hook) + # e. windowing + FFT + power + # 2. _apply_mel_scale (mel filterbank projection) + # 3. _normalize_magnitude (log / dB scaling, optional per-utterance norm) + # + # Backend subclasses (NumpyAudioBackend, TorchAudioBackend) implement the + # full pipeline. Model-specific processors can override individual hooks + # (_pre_stft, _apply_frame_processing) or the entire _extract_spectrogram + # when the base STFT path is insufficient (e.g., Parakeet's custom magnitude + # computation). + + def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): + """Orchestrate the STFT pipeline. + + Runs the sub-steps listed above in order. Override this only when the + pipeline ordering itself needs to change (e.g., Parakeet needs audio-length + detection before ``_pre_stft``). Otherwise, override individual hooks. """ - Compute the (power) spectrogram via STFT. + audio = self._pre_stft(audio, spectrogram_config=spectrogram_config, **kwargs) + return self._stft(audio, spectrogram_config=spectrogram_config, **kwargs) - Implemented by backend subclasses (e.g., ``TorchAudioBackend``). - """ - raise NotImplementedError + def _pre_stft(self, audio, *, spectrogram_config, **kwargs): + """Hook: waveform-level pre-processing before STFT. - def _apply_mel_scale(self, *args, **kwargs): + Called before framing. Default: no-op (returns audio unchanged). + Override for processing on the full waveform, e.g. length-aware + preemphasis with masking (Parakeet). """ - Apply mel filterbank to a spectrogram. + return audio - Implemented by backend subclasses (e.g., ``TorchAudioBackend``). + def _stft(self, audio, *, spectrogram_config, **kwargs): + """Compute the STFT and return a power/magnitude spectrogram. + + Implemented by backend subclasses. Internally runs: + window creation → padding → framing → ``_apply_frame_processing`` → + windowing → FFT → power. + + Override in model-specific processors that need a fully custom STFT + (e.g., Gemma3n's unfold-based STFT with extra-sample framing). """ raise NotImplementedError - def _normalize_magnitude(self, *args, **kwargs): - """ - Apply magnitude normalization (log, log10, or dB scaling) to spectrogram features. + def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): + """Hook: per-frame signal conditioning after frame extraction. - Implemented by backend subclasses (e.g., ``TorchAudioBackend``). + Called after framing, before windowing and FFT. Default backend + implementations apply dither, DC-offset removal, and standard + preemphasis. + + Override for non-standard frame processing, e.g. HTK-style + preemphasis (Gemma3n). 
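A rough sketch of the override pattern, using a hypothetical processor name;
it mirrors the waveform-level preemphasis approach Parakeet takes further down
in this patch (a real processor would also define its own ``spectrogram_config``)::

    import torch
    from transformers.audio_processing_backends import TorchAudioBackend

    class MyAudioProcessor(TorchAudioBackend):
        sample_rate = 16000

        def _pre_stft(self, audio, *, spectrogram_config, **kwargs):
            # waveform-level preemphasis instead of the per-frame default
            coeff = spectrogram_config.preemphasis
            if coeff is not None:
                audio = torch.cat(
                    [audio[..., :1], audio[..., 1:] - coeff * audio[..., :-1]], dim=-1
                )
            return audio

        def _needs_manual_framing(self, spectrogram_config):
            # preemphasis is already handled above, so the torch.stft fast path can be used
            return spectrogram_config.remove_dc_offset or spectrogram_config.stft_config.left_align_fft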
""" raise NotImplementedError + def _apply_mel_scale(self, *args, **kwargs): + """Apply mel filterbank to spectrogram features.""" + raise NotImplementedError + + def _normalize_magnitude(self, *args, **kwargs): + """Apply magnitude normalization (log, log10, or dB scaling) to spectrogram features.""" + raise NotImplementedError + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): raise NotImplementedError diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 7f1d1ae4eefb..ff6f9a5788a2 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -76,7 +76,6 @@ class StftConfig: pad_mode: str = "reflect" normalized: bool = False onesided: bool | None = None - pad: int = 0 periodic: bool = True left_align_fft: bool = False @@ -400,10 +399,11 @@ def hertz_to_mel(freq: float | np.ndarray, mel_scale: str = "htk") -> float | np elif mel_scale == "kaldi": return 1127.0 * np.log(1.0 + (freq / 700.0)) + f_sp = 200.0 / 3 min_log_hertz = 1000.0 - min_log_mel = 15.0 + min_log_mel = min_log_hertz / f_sp logstep = 27.0 / np.log(6.4) - mels = 3.0 * freq / 200.0 + mels = freq / f_sp if isinstance(freq, np.ndarray): log_region = freq >= min_log_hertz @@ -436,10 +436,11 @@ def mel_to_hertz(mels: float | np.ndarray, mel_scale: str = "htk") -> float | np elif mel_scale == "kaldi": return 700.0 * (np.exp(mels / 1127.0) - 1.0) + f_sp = 200.0 / 3 min_log_hertz = 1000.0 - min_log_mel = 15.0 + min_log_mel = min_log_hertz / f_sp logstep = np.log(6.4) / 27.0 - freq = 200.0 * mels / 3.0 + freq = f_sp * mels if isinstance(mels, np.ndarray): log_region = mels >= min_log_mel diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py index 4358af841b2d..d72ebf972457 100644 --- a/src/transformers/models/clap/audio_processing_clap.py +++ b/src/transformers/models/clap/audio_processing_clap.py @@ -177,7 +177,6 @@ def _preprocess( return_tensors, spectrogram_config=None, do_extract_spectrogram=None, - do_batch_spectrogram=True, **kwargs, ): # Use instance defaults when not explicitly provided (matching feature extractor behavior) @@ -237,12 +236,7 @@ def _preprocess( is_longer = [[longer] for longer in is_longer] input_features = {"audio_features": input_mel, "is_longer": is_longer} - input_features = BatchFeature(input_features) - - if return_tensors is not None: - input_features = input_features.convert_to_tensors(return_tensors) - - return input_features + return BatchFeature(input_features, tensor_type=return_tensors) __all__ = ["ClapAudioProcessor"] diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py index ab9ea1317d8d..7ea42246d746 100644 --- a/src/transformers/models/gemma3n/audio_processing_gemma3n.py +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -78,26 +78,27 @@ def __init__(self, per_bin_mean=None, per_bin_stddev=None, **kwargs): else: self.per_bin_stddev = None - def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): - stft_cfg = spectrogram_config.stft_config + def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): + """HTK-style preemphasis on frames extracted with an extra sample.""" preemphasis = spectrogram_config.preemphasis - - frame_size_for_unfold = stft_cfg.win_length + 1 - frames_to_process = _unfold(audio, dimension=-1, size=frame_size_for_unfold, step=stft_cfg.hop_length) - - # Preemphasis if preemphasis is not None and 
preemphasis > 0.0: if self.preemphasis_htk_flavor: - first_in_frame = frames_to_process[..., :1] * (1.0 - preemphasis) - rest_in_frame = frames_to_process[..., 1:-1] - preemphasis * frames_to_process[..., :-2] - frames = np.concatenate([first_in_frame, rest_in_frame], axis=-1) + first = frames[..., :1] * (1.0 - preemphasis) + rest = frames[..., 1:-1] - preemphasis * frames[..., :-2] + return np.concatenate([first, rest], axis=-1) else: - frames = frames_to_process[..., 1:] - preemphasis * frames_to_process[..., :-1] - else: - frames = frames_to_process[..., :-1] + return frames[..., 1:] - preemphasis * frames[..., :-1] + return frames[..., :-1] + + def _stft(self, audio, *, spectrogram_config, **kwargs): + stft_cfg = spectrogram_config.stft_config + + frame_size_for_unfold = stft_cfg.win_length + 1 + frames = _unfold(audio, dimension=-1, size=frame_size_for_unfold, step=stft_cfg.hop_length) - frames = frames * self.window # Broadcasting window + frames = self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs) + frames = frames * self.window stft = np.fft.rfft(frames, n=stft_cfg.n_fft, axis=-1) return np.abs(stft) diff --git a/src/transformers/models/lasr/audio_processing_lasr.py b/src/transformers/models/lasr/audio_processing_lasr.py index a795fbbb5325..3f0c9c92a21e 100644 --- a/src/transformers/models/lasr/audio_processing_lasr.py +++ b/src/transformers/models/lasr/audio_processing_lasr.py @@ -21,7 +21,6 @@ class LasrAudioProcessor(TorchAudioBackend): sample_rate = 16000 force_mono = True - add_channel_dim = True spectrogram_config = SpectrogramConfig( stft_config=StftConfig( n_fft=512, diff --git a/src/transformers/models/parakeet/audio_processing_parakeet.py b/src/transformers/models/parakeet/audio_processing_parakeet.py index 82a6becab471..f328bbe12ab8 100644 --- a/src/transformers/models/parakeet/audio_processing_parakeet.py +++ b/src/transformers/models/parakeet/audio_processing_parakeet.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import librosa -import torch - from ...audio_processing_backends import TorchAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig @@ -36,6 +33,7 @@ class ParakeetAudioProcessor(TorchAudioBackend): n_mels=80, f_min=0.0, norm="slaney", + mel_scale="slaney", ), preemphasis=0.97, log_mode="log", @@ -43,59 +41,81 @@ class ParakeetAudioProcessor(TorchAudioBackend): ) def _mel_filter_bank(self, spectrogram_config): - """Use librosa mel filters for exact numerical match with the feature extractor.""" + """Compute mel filters via numpy for exact numerical match with the feature extractor. + + The FE uses librosa which accumulates into a float32 array per-band. + Replicating that truncation pattern is needed for bit-exact results. 
+ """ + import numpy as np + import torch + + from ...audio_utils import hertz_to_mel, mel_to_hertz + stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - mel_filters = librosa.filters.mel( - sr=self.sample_rate, - n_fft=stft_cfg.n_fft, - n_mels=mel_cfg.n_mels, - fmin=mel_cfg.f_min, - fmax=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, - norm=mel_cfg.norm, - ) - # librosa returns (n_mels, freq); transpose to (freq, n_mels) for base class convention - return torch.from_numpy(mel_filters.T).to(torch.float32) - + n_fft = stft_cfg.n_fft + n_mels = mel_cfg.n_mels + f_min = mel_cfg.f_min + f_max = mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2 + + mel_min = hertz_to_mel(f_min, mel_scale=mel_cfg.mel_scale) + mel_max = hertz_to_mel(f_max, mel_scale=mel_cfg.mel_scale) + mel_pts = np.linspace(mel_min, mel_max, n_mels + 2) + filter_freqs = mel_to_hertz(mel_pts.copy(), mel_scale=mel_cfg.mel_scale) + fft_freqs = np.linspace(0, self.sample_rate / 2, 1 + n_fft // 2) + + fdiff = np.diff(filter_freqs) + ramps = np.subtract.outer(filter_freqs, fft_freqs) + + # Accumulate into f32 per-band to match librosa's truncation pattern + weights = np.zeros((n_mels, 1 + n_fft // 2), dtype=np.float32) + for i in range(n_mels): + lower = -ramps[i] / fdiff[i] + upper = ramps[i + 2] / fdiff[i + 1] + weights[i] = np.maximum(0, np.minimum(lower, upper)) + + if mel_cfg.norm == "slaney": + enorm = 2.0 / (filter_freqs[2 : n_mels + 2] - filter_freqs[:n_mels]) + weights *= enorm[:, np.newaxis] + + return torch.from_numpy(weights.T).to(torch.float32) + + def _compute_magnitudes(self, stft_out, power): + import torch + + magnitudes = torch.view_as_real(stft_out) + magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1)) + if power != 1.0: + magnitudes = magnitudes.pow(power) + return magnitudes + + def _needs_manual_framing(self, spectrogram_config): + # Preemphasis is handled waveform-level in _pre_stft; no per-frame processing needed. 
+ return spectrogram_config.remove_dc_offset or spectrogram_config.stft_config.left_align_fft + def _pre_stft(self, audio, *, spectrogram_config, **kwargs): + import torch + + if not isinstance(self._audio_lengths, torch.Tensor): + self._audio_lengths = torch.tensor(self._audio_lengths, device=audio.device) + preemphasis = spectrogram_config.preemphasis if preemphasis is not None: - timemask = torch.arange(audio.shape[-1], device=audio.device).unsqueeze(0) < self._audio_lengths.unsqueeze(1) audio = torch.cat( [audio[:, :1], audio[:, 1:] - preemphasis * audio[:, :-1]], dim=1 ) + timemask = torch.arange(audio.shape[-1], device=audio.device).unsqueeze(0) < self._audio_lengths.unsqueeze(1) audio = audio.masked_fill(~timemask, 0.0) return audio - def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): - # Detect audio lengths from zero-padded waveform for preemphasis masking and normalization - if audio.ndim == 2: - indices = torch.arange(audio.shape[-1], device=audio.device).expand_as(audio) - self._audio_lengths = indices.masked_fill(audio == 0, -1).max(dim=-1).values + 1 - - audio = self._pre_stft(audio, spectrogram_config=spectrogram_config, **kwargs) - - # Compute STFT matching the FE's magnitude computation for exact numerical match - stft_cfg = spectrogram_config.stft_config - window = torch.hann_window(stft_cfg.win_length, periodic=stft_cfg.periodic, device=audio.device) - stft = torch.stft( - audio, - stft_cfg.n_fft, - hop_length=stft_cfg.hop_length, - win_length=stft_cfg.win_length, - window=window, - return_complex=True, - pad_mode=stft_cfg.pad_mode, - ) - magnitudes = torch.view_as_real(stft) - magnitudes = torch.sqrt(magnitudes.pow(2).sum(-1)) - magnitudes = magnitudes.pow(2) - return magnitudes - def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): + import torch + return torch.matmul(self.mel_filters.T, features) def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): + import torch + # Match FE: log(mel_spec + guard_value) instead of log(clamp(mel_spec, guard_value)) features = torch.log(features + spectrogram_config.mel_floor) diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py index 1bdd232bb372..7667c2a21737 100644 --- a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py @@ -127,6 +127,8 @@ def _preprocess( truncation, pad_to_multiple_of, return_tensors, + spectrogram_config=None, + do_extract_spectrogram=None, **kwargs, ) -> BatchFeature: import torch diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py index 5ba746c95608..c734409952b4 100644 --- a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py @@ -15,7 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature @@ -48,27 +48,12 @@ class SeamlessM4tAudioProcessor(NumpyAudioBackend): waveform_scale=32768.0, ) - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.window = window_function(400, 
"povey", periodic=False) - def _extract_fbank_features(self, waveform): - waveform = np.squeeze(waveform) * (2**15) # Kaldi compliance: 16-bit signed integers - features = spectrogram( - waveform, - self.window, - frame_length=400, - hop_length=160, - fft_length=512, - power=2.0, - center=False, - preemphasis=0.97, - mel_filters=self.mel_filters, - log_mel="log", - mel_floor=1.192092955078125e-07, - remove_dc_offset=True, - ).T - return features + """Extract log-mel filterbank features for a single waveform using the base spectrogram pipeline.""" + waveform = np.squeeze(waveform) * self.spectrogram_config.waveform_scale + features = self.extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) + # extract_spectrogram returns list of (n_mels, time); transpose to (time, n_mels) + return features[0].T def feature_normalize(self, features): normalized = [] @@ -78,7 +63,18 @@ def feature_normalize(self, features): normalized.append((f - mean) / np.sqrt(var + 1e-7)) return normalized - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + def _preprocess( + self, + audio, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + spectrogram_config=None, + do_extract_spectrogram=None, + **kwargs, + ): # Extract features from raw (unpadded) audio, then pad at feature level features = [self._extract_fbank_features(waveform) for waveform in audio] features = self.feature_normalize(features) diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py index 3a075b43720b..4f91a50b1f2e 100644 --- a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -15,7 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig from ...feature_extraction_utils import BatchFeature from ...utils import is_speech_available @@ -57,31 +57,19 @@ def __init__(self, normalize_means=True, normalize_vars=True, **kwargs): super().__init__(**kwargs) self.normalize_means = normalize_means self.normalize_vars = normalize_vars - if not is_speech_available(): - self.window = window_function(400, "povey", periodic=False) def _extract_fbank_features(self, waveform): - waveform = waveform * (2**15) # Kaldi compliance + """Extract log-mel filterbank features for a single waveform.""" + waveform = waveform * self.spectrogram_config.waveform_scale if is_speech_available(): waveform_tensor = torch.from_numpy(waveform).unsqueeze(0) features = ta_kaldi.fbank(waveform_tensor, num_mel_bins=80, sample_frequency=self.sample_rate) return features.numpy() else: waveform = np.squeeze(waveform) - return spectrogram( - waveform, - self.window, - frame_length=400, - hop_length=160, - fft_length=512, - power=2.0, - center=False, - preemphasis=0.97, - mel_filters=self.mel_filters, - log_mel="log", - mel_floor=1.192092955078125e-07, - remove_dc_offset=True, - ).T + features = self.extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) + # extract_spectrogram returns list of (n_mels, time); transpose to (time, n_mels) + return features[0].T @staticmethod def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, 
padding_value=0.0): @@ -95,7 +83,18 @@ def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, p x[input_length:] = padding_value return x.astype(np.float32) - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): + def _preprocess( + self, + audio, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + spectrogram_config=None, + do_extract_spectrogram=None, + **kwargs, + ): # Extract features from raw (unpadded) audio, then pad at feature level features = [self._extract_fbank_features(waveform) for waveform in audio] lengths = [f.shape[0] for f in features] diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py index 0fbaec66b74c..df882f7c0805 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py @@ -20,7 +20,6 @@ class VibevoiceAcousticTokenizerAudioProcessor(TorchAudioBackend): sample_rate = 24000 force_mono = True - add_channel_dim = True target_dB_FS = -25 eps = 1e-6 @@ -34,6 +33,11 @@ def _process_audio(self, audio_el): audio_el = audio_el / (max_val + self.eps) return audio_el + def _preprocess(self, audio, **kwargs): + result = super()._preprocess(audio, **kwargs) + result["audio_values"] = result["audio_values"].unsqueeze(1) + return result + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32) for i, (start, end) in enumerate(audio_ranges): diff --git a/src/transformers/numpy_mel_spectrogram.py b/src/transformers/numpy_mel_spectrogram.py new file mode 100644 index 000000000000..cd90215b78a1 --- /dev/null +++ b/src/transformers/numpy_mel_spectrogram.py @@ -0,0 +1,413 @@ +"""NumPy implementation of mel spectrogram computation.""" + +import numpy as np +import librosa + + +# --- Frequency conversion utilities --- + +def hertz_to_mel(freq, mel_scale="htk"): + if mel_scale == "htk": + return 2595.0 * np.log10(1.0 + (freq / 700.0)) + elif mel_scale == "kaldi": + return 1127.0 * np.log(1.0 + (freq / 700.0)) + # slaney + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = 27.0 / np.log(6.4) + mels = 3.0 * freq / 200.0 + if isinstance(freq, np.ndarray): + log_region = freq >= min_log_hertz + mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep + elif freq >= min_log_hertz: + mels = min_log_mel + np.log(freq / min_log_hertz) * logstep + return mels + + +def mel_to_hertz(mels, mel_scale="htk"): + if mel_scale == "htk": + return 700.0 * (np.power(10, mels / 2595.0) - 1.0) + elif mel_scale == "kaldi": + return 700.0 * (np.exp(mels / 1127.0) - 1.0) + # slaney + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = np.log(6.4) / 27.0 + freq = 200.0 * mels / 3.0 + if isinstance(mels, np.ndarray): + log_region = mels >= min_log_mel + freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel)) + elif mels >= min_log_mel: + freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel)) + return freq + + +# --- Filter bank --- + +def _create_triangular_filter_bank(fft_freqs, filter_freqs): + filter_diff = np.diff(filter_freqs) + slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 
1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + return np.maximum(0, np.minimum(down_slopes, up_slopes)) + + +def mel_filter_bank( + num_frequency_bins, + num_mel_filters, + min_frequency, + max_frequency, + sampling_rate, + norm=None, + mel_scale="htk", + triangularize_in_mel_space=False, + frequency_bin_mode="rfft", +): + mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale) + mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale) + mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2) + filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale) + + n_fft = (num_frequency_bins - 1) * 2 + + if triangularize_in_mel_space: + fft_bin_width = sampling_rate / n_fft + fft_freqs = hertz_to_mel( + fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale + ) + filter_freqs = mel_freqs + elif frequency_bin_mode == "rfft": + fft_freqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate) + else: + fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins) + + mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) + + if norm == "slaney": + enorm = 2.0 / ( + filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters] + ) + mel_filters *= np.expand_dims(enorm, 0) + + return mel_filters + + +# --- Window --- + +def window_function(window_length, name="hann_window", periodic=True): + N = window_length + 1 if periodic else window_length + fac = np.linspace(-np.pi, np.pi, N) + if name in ("hann", "hann_window"): + w = 0.5 + 0.5 * np.cos(fac) + elif name in ("hamming", "hamming_window"): + w = 0.54 + 0.46 * np.cos(fac) + elif name == "boxcar": + w = np.ones(N) + elif name == "povey": + w = (0.5 + 0.5 * np.cos(fac)) ** 0.85 + else: + raise ValueError(f"Unknown window function '{name}'") + return w[:window_length] if periodic else w + + +# --- Sub-methods --- + +def _prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing): + if needs_manual_framing and win_length < n_fft: + frame_length = win_length + else: + if win_length < n_fft: + left_pad = (n_fft - win_length) // 2 + right_pad = n_fft - win_length - left_pad + window = np.pad(window, (left_pad, right_pad)) + frame_length = n_fft + return window, frame_length + + +def _frame_waveform(waveform, frame_length, hop_length, n_fft, center, pad_mode): + squeezed = waveform.ndim == 1 + if squeezed: + waveform = waveform[np.newaxis, :] + if center: + # Use librosa-compatible split-padding to match their STFT exactly + # This replicates librosa's optimization to avoid copying the entire signal + start_k = int(np.ceil(n_fft // 2 / hop_length)) + tail_k = (waveform.shape[-1] + n_fft // 2 - n_fft) // hop_length + 1 + + if tail_k <= start_k: + # Head and tail overlap, use simple full padding + waveform = np.pad(waveform, ((0, 0), (frame_length // 2, frame_length // 2)), mode=pad_mode) + num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length + frame_starts = np.arange(num_frames) * hop_length + frame_indices = frame_starts[:, np.newaxis] + np.arange(frame_length) + frames = waveform[:, frame_indices] # (batch, num_frames, frame_length) + else: + # Split padding: handle head and tail separately like librosa + # Pre-padding: left pad only + padding = [(0, 0) for _ in range(waveform.ndim)] + padding[-1] = (frame_length // 2, 0) + y_pre = np.pad( + waveform[..., : (start_k - 1) * hop_length - n_fft // 2 + n_fft + 1], + padding, + mode=pad_mode, + ) + y_frames_pre = librosa.util.frame(y_pre, frame_length=frame_length, 
hop_length=hop_length) + y_frames_pre = y_frames_pre[..., :start_k] + y_frames_pre = np.moveaxis(y_frames_pre, -2, -1) # (batch, frame_length, num_frames) -> (batch, num_frames, frame_length) + extra = y_frames_pre.shape[-2] + + # Post-padding: right pad only + padding[-1] = (0, frame_length // 2) + y_post = np.pad( + waveform[..., (tail_k) * hop_length - n_fft // 2 :], + padding, + mode=pad_mode, + ) + y_frames_post = librosa.util.frame(y_post, frame_length=frame_length, hop_length=hop_length) + y_frames_post = np.moveaxis(y_frames_post, -2, -1) # (batch, frame_length, num_frames) -> (batch, num_frames, frame_length) + extra += y_frames_post.shape[-2] + + # Middle: no padding + start = start_k * hop_length - n_fft // 2 + y_frames_middle = librosa.util.frame( + waveform[..., start:], frame_length=frame_length, hop_length=hop_length + ) + y_frames_middle = np.moveaxis(y_frames_middle, -2, -1) # (batch, frame_length, num_frames) -> (batch, num_frames, frame_length) + + # Total frames + num_frames = y_frames_pre.shape[-2] + y_frames_middle.shape[-2] + y_frames_post.shape[-2] + + # Concatenate frames + frames = np.concatenate([y_frames_pre, y_frames_middle, y_frames_post], axis=-2) + else: + # No centering: no padding + num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length + frame_starts = np.arange(num_frames) * hop_length + frame_indices = frame_starts[:, np.newaxis] + np.arange(frame_length) + frames = waveform[:, frame_indices] # (batch, num_frames, frame_length) + + if squeezed: + frames = frames.squeeze(0) + return frames, num_frames + + +def _apply_frame_processing(frames, *, dither=0.0, preemphasis=None, remove_dc_offset=False): + compute_dtype = frames.dtype + if dither != 0.0: + frames = frames + dither * np.random.randn(*frames.shape).astype(compute_dtype) + if remove_dc_offset: + frames = frames - frames.mean(axis=-1, keepdims=True) + if preemphasis is not None: + preemph_src = preemphasis * frames[..., :-1] + frames[..., 1:] = frames[..., 1:] - preemph_src + frames[..., 0] = frames[..., 0] * (1 - preemphasis) + return frames + + +def _windowed_fft(frames, window, fft_length, power, normalized): + """Apply window, compute FFT, and return power spectrogram of shape (..., freq, time).""" + frames = frames * window + spec = np.fft.rfft(frames, n=fft_length, axis=-1).astype(np.complex64) + if normalized: + spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype) + spec = np.abs(spec, dtype=np.float64) ** power + return np.moveaxis(spec, -1, -2) + + +def _apply_mel_scale( + spectrogram: np.ndarray, + mel_filters: np.ndarray, + mel_floor: float = 1e-10, +) -> np.ndarray: + """Apply mel filterbank to a spectrogram. + + Args: + spectrogram: Power spectrogram of shape (..., freq, time). + mel_filters: Mel filterbank of shape (freq, n_mels). + mel_floor: Minimum value for clamping. + + Returns: + Mel spectrogram of shape (..., n_mels, time). 
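A shape-only sketch of the projection, with illustrative sizes::

    import numpy as np

    power_spec = np.random.rand(257, 98)    # (freq, time)
    mel_filters = np.random.rand(257, 80)   # (freq, n_mels)
    mel = np.maximum(1e-10, mel_filters.T @ power_spec)  # (n_mels, time) == (80, 98)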
+ """ + # (n_mels, freq) @ (..., freq, time) -> (..., n_mels, time) + mel_spec = np.matmul(mel_filters.T, spectrogram) + return np.maximum(mel_floor, mel_spec) + + +# --- Main function --- + +def mel_spectrogram( + waveform: np.ndarray, + sampling_rate: int, + *, + n_fft: int = 400, + win_length: int | None = None, + hop_length: int | None = None, + window_fn: str = "hann_window", + wkwargs: dict | None = None, + power: float = 2.0, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + periodic: bool = True, + # mel scale kwargs + n_mels: int = 128, + f_min: float = 0.0, + f_max: float | None = None, + mel_scale: str = "htk", + norm: str | None = None, + triangularize_in_mel_space: bool = False, + # kaldi-specific kwargs + dither: float = 0.0, + preemphasis: float | None = None, + remove_dc_offset: bool = False, + mel_floor: float = 1e-10, +) -> np.ndarray: + """Compute mel spectrogram using NumPy. + + Args: + waveform: Input waveform of shape (..., time). + sampling_rate: Sample rate in Hz. + + Returns: + Mel spectrogram of shape (..., n_mels, time). + """ + if f_max is None: + f_max = sampling_rate / 2.0 + + # --- STFT --- + if win_length is None: + win_length = n_fft + if hop_length is None: + hop_length = win_length // 2 + window = window_function(win_length, name=window_fn, periodic=periodic) + + needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset + window, frame_length = _prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + + is_1d = waveform.ndim == 1 + if is_1d: + waveform = waveform[np.newaxis, :] + leading_shape = waveform.shape[:-1] + waveform = waveform.reshape(-1, waveform.shape[-1]) + frames, num_frames = _frame_waveform(waveform, frame_length, hop_length, n_fft, center, pad_mode) + compute_dtype = np.result_type(waveform.dtype, window.dtype) + frames = frames.astype(compute_dtype, copy=False) + frames = _apply_frame_processing(frames, dither=dither, preemphasis=preemphasis, remove_dc_offset=remove_dc_offset) + spectrogram = _windowed_fft(frames, window, n_fft, power, normalized) + + num_frequency_bins = n_fft // 2 + 1 + spectrogram = spectrogram.reshape(*leading_shape, num_frequency_bins, num_frames) + if is_1d: + spectrogram = spectrogram.squeeze(0) + + num_frequency_bins = spectrogram.shape[-2] + mel_fb = mel_filter_bank( + num_frequency_bins, n_mels, f_min, f_max, sampling_rate, + norm=norm, mel_scale=mel_scale, + triangularize_in_mel_space=triangularize_in_mel_space, + ) + + return _apply_mel_scale(spectrogram, mel_fb, mel_floor=mel_floor) + + +class MelSpectrogram: + """Cached mel spectrogram — precomputes window and mel filterbank. + + Same API and exact same results as the functional ``mel_spectrogram``, but + avoids recomputing the window and mel filterbank on every call. 
+ + Usage:: + + transform = MelSpectrogram(sampling_rate=16000, n_fft=1024, n_mels=80) + mel = transform(waveform) # fast repeated calls + """ + + def __init__( + self, + sampling_rate: int, + *, + n_fft: int = 400, + win_length: int | None = None, + hop_length: int | None = None, + window_fn: str = "hann_window", + wkwargs: dict | None = None, + power: float = 2.0, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + periodic: bool = True, + n_mels: int = 128, + f_min: float = 0.0, + f_max: float | None = None, + mel_scale: str = "htk", + norm: str | None = None, + triangularize_in_mel_space: bool = False, + dither: float = 0.0, + preemphasis: float | None = None, + remove_dc_offset: bool = False, + mel_floor: float = 1e-10, + ): + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.win_length = win_length if win_length is not None else n_fft + self.hop_length = hop_length if hop_length is not None else self.win_length // 2 + self.power = power + self.center = center + self.pad_mode = pad_mode + self.normalized = normalized + self.periodic = periodic + self.n_mels = n_mels + self.f_min = f_min + self.f_max = f_max if f_max is not None else sampling_rate / 2.0 + self.mel_floor = mel_floor + self.dither = dither + self.preemphasis = preemphasis + self.remove_dc_offset = remove_dc_offset + self.window_fn = window_fn + + # Precompute window + needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset + window = window_function(self.win_length, name=window_fn, periodic=periodic) + self._window, self._frame_length = _prepare_window_and_framing( + window, self.win_length, n_fft, needs_manual_framing, + ) + + # Precompute mel filterbank + num_frequency_bins = n_fft // 2 + 1 + self._mel_fb = mel_filter_bank( + num_frequency_bins, n_mels, self.f_min, self.f_max, sampling_rate, + norm=norm, mel_scale=mel_scale, + triangularize_in_mel_space=triangularize_in_mel_space, + ) + + def __call__(self, waveform: np.ndarray) -> np.ndarray: + """Compute mel spectrogram. + + Args: + waveform: Input of shape (..., time). + + Returns: + Mel spectrogram of shape (..., n_mels, time). 
+ """ + is_1d = waveform.ndim == 1 + if is_1d: + waveform = waveform[np.newaxis, :] + leading_shape = waveform.shape[:-1] + waveform = waveform.reshape(-1, waveform.shape[-1]) + frames, num_frames = _frame_waveform( + waveform, self._frame_length, self.hop_length, self.n_fft, self.center, self.pad_mode, + ) + compute_dtype = np.result_type(waveform.dtype, self._window.dtype) + frames = frames.astype(compute_dtype, copy=False) + frames = _apply_frame_processing( + frames, dither=self.dither, preemphasis=self.preemphasis, remove_dc_offset=self.remove_dc_offset, + ) + spectrogram = _windowed_fft(frames, self._window, self.n_fft, self.power, self.normalized) + + num_frequency_bins = self.n_fft // 2 + 1 + spectrogram = spectrogram.reshape(*leading_shape, num_frequency_bins, num_frames) + if is_1d: + spectrogram = spectrogram.squeeze(0) + + return _apply_mel_scale(spectrogram, self._mel_fb, mel_floor=self.mel_floor) diff --git a/src/transformers/torch_mel_spectrogram.py b/src/transformers/torch_mel_spectrogram.py index 8a59d0b61fd2..3d48f2b8192a 100644 --- a/src/transformers/torch_mel_spectrogram.py +++ b/src/transformers/torch_mel_spectrogram.py @@ -29,10 +29,11 @@ def hertz_to_mel(freq: torch.Tensor, mel_scale: str = "htk") -> torch.Tensor: elif mel_scale == "kaldi": return 1127.0 * torch.log(1.0 + freq / 700.0) # slaney + f_sp = 200.0 / 3 min_log_hertz = 1000.0 - min_log_mel = 15.0 + min_log_mel = min_log_hertz / f_sp logstep = 27.0 / torch.log(torch.tensor(6.4)) - mels = 3.0 * freq / 200.0 + mels = freq / f_sp log_region = freq >= min_log_hertz mels[log_region] = min_log_mel + torch.log(freq[log_region] / min_log_hertz) * logstep return mels @@ -218,54 +219,7 @@ def window_function(window_length, name="hann_window", periodic=True, wkwargs=No # --- Sub-methods --- -def _extract_spectrogram( - waveform: torch.Tensor, - sampling_rate: int, - *, - n_fft: int = 400, - win_length: int | None = None, - hop_length: int | None = None, - window_fn: str = "hann_window", - wkwargs: dict | None = None, - power: float = 2.0, - center: bool = True, - pad_mode: str = "reflect", - normalized: bool = False, - pad: int = 0, - periodic: bool = True, - dither: float = 0.0, - preemphasis: float | None = None, - remove_dc_offset: bool = False, - computation_dtype: "torch.dtype | None" = None, - left_align_fft: bool = False, -) -> torch.Tensor: - """Compute the (power) spectrogram via STFT. - - Args: - waveform: Input waveform of shape (..., time). - sampling_rate: Sample rate in Hz. - left_align_fft: If True, use manual framing with unfold(win_length) + zero-pad - right + rfft(n_fft). This left-aligns the window in the FFT buffer (kaldi - style), instead of center-padding it (torch.stft default). - - Returns: - Power spectrogram of shape (..., freq, time). 
- """ - if win_length is None: - win_length = n_fft - if hop_length is None: - hop_length = win_length // 2 - if computation_dtype is not None: - waveform = waveform.to(computation_dtype) - device = waveform.device - dtype = waveform.dtype - - needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset or left_align_fft - - window_wkwargs = {**(wkwargs or {}), "dtype": dtype} - window = window_function(win_length, name=window_fn, periodic=periodic, wkwargs=window_wkwargs) - window = window.to(device=device) - +def _prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing): if needs_manual_framing and win_length < n_fft: frame_length = win_length else: @@ -274,40 +228,20 @@ def _extract_spectrogram( right_pad = n_fft - win_length - left_pad window = torch.nn.functional.pad(window, (left_pad, right_pad)) frame_length = n_fft + return window, frame_length - fft_length = n_fft - num_frequency_bins = (fft_length // 2) + 1 - - is_1d = waveform.ndim == 1 - if is_1d: - waveform = waveform.unsqueeze(0) - - leading_shape = waveform.shape[:-1] - waveform = waveform.reshape(-1, waveform.shape[-1]) - - if pad > 0: - waveform = torch.nn.functional.pad(waveform, (pad, pad)) - - if needs_manual_framing: - result = _manual_stft( - waveform, window, frame_length, hop_length, fft_length, - num_frequency_bins, power, normalized, center, pad_mode, - dither, preemphasis, remove_dc_offset, - ) - else: - result = _torch_stft( - waveform, window, frame_length, hop_length, fft_length, - power, normalized, center, pad_mode, - ) - - result = result.reshape(*leading_shape, result.shape[-2], result.shape[-1]) - - if is_1d: - result = result.squeeze(0) - if computation_dtype is not None: - return result - return result.float() +def _apply_frame_processing(frames, *, dither=0.0, preemphasis=None, remove_dc_offset=False): + if dither != 0.0: + frames = frames + dither * torch.randn_like(frames) + if remove_dc_offset: + frames = frames - frames.mean(dim=-1, keepdim=True) + if preemphasis is not None: + frames = torch.cat([ + frames[..., :1] * (1 - preemphasis), + frames[..., 1:] - preemphasis * frames[..., :-1], + ], dim=-1) + return frames def _apply_mel_scale( @@ -332,9 +266,9 @@ def _apply_mel_scale( def _torch_stft( waveform, window, frame_length, hop_length, fft_length, - power, normalized, center, pad_mode, + normalized, center, pad_mode, ): - """Fast path using torch.stft. Returns power spectrogram of shape (batch, freq, time).""" + """Fast path using torch.stft. Returns complex STFT of shape (batch, freq, time).""" stft_out = torch.stft( waveform, n_fft=fft_length, @@ -348,13 +282,13 @@ def _torch_stft( ) if normalized: stft_out = stft_out / window.pow(2.0).sum().sqrt() - return stft_out.abs() ** power + return stft_out def _manual_stft( waveform, window, frame_length, hop_length, fft_length, num_frequency_bins, power, normalized, center, pad_mode, - dither, preemphasis, remove_dc_offset, + apply_frame_processing=None, ): """Manual framing STFT for kaldi-specific features. 
Returns power spectrogram of shape (batch, freq, time).""" if center: @@ -365,17 +299,8 @@ def _manual_stft( # Extract all frames at once: (batch, num_frames, frame_length) frames = waveform.unfold(-1, frame_length, hop_length) - if dither != 0.0: - frames = frames + dither * torch.randn_like(frames) - - if remove_dc_offset: - frames = frames - frames.mean(dim=-1, keepdim=True) - - if preemphasis is not None: - frames = torch.cat([ - frames[..., :1] * (1 - preemphasis), - frames[..., 1:] - preemphasis * frames[..., :-1], - ], dim=-1) + if apply_frame_processing is not None: + frames = apply_frame_processing(frames) frames = frames * window @@ -410,7 +335,6 @@ def mel_spectrogram( center: bool = True, pad_mode: str = "reflect", normalized: bool = False, - pad: int = 0, periodic: bool = True, # mel scale kwargs n_mels: int = 128, @@ -437,13 +361,45 @@ def mel_spectrogram( if f_max is None: f_max = sampling_rate / 2.0 - spectrogram = _extract_spectrogram( - waveform, sampling_rate, - n_fft=n_fft, win_length=win_length, hop_length=hop_length, - window_fn=window_fn, wkwargs=wkwargs, power=power, - center=center, pad_mode=pad_mode, normalized=normalized, pad=pad, periodic=periodic, - dither=dither, preemphasis=preemphasis, remove_dc_offset=remove_dc_offset, - ) + # --- STFT --- + if win_length is None: + win_length = n_fft + if hop_length is None: + hop_length = win_length // 2 + device = waveform.device + dtype = waveform.dtype + + needs_manual_framing = (dither != 0.0) or (preemphasis is not None) or remove_dc_offset + + window_wkwargs = {**(wkwargs or {}), "dtype": dtype} + window = window_function(win_length, name=window_fn, periodic=periodic, wkwargs=window_wkwargs) + window = window.to(device=device) + window, frame_length = _prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + + is_1d = waveform.ndim == 1 + if is_1d: + waveform = waveform.unsqueeze(0) + leading_shape = waveform.shape[:-1] + waveform = waveform.reshape(-1, waveform.shape[-1]) + if needs_manual_framing: + frame_proc = lambda f: _apply_frame_processing( + f, dither=dither, preemphasis=preemphasis, remove_dc_offset=remove_dc_offset, + ) + spectrogram = _manual_stft( + waveform, window, frame_length, hop_length, n_fft, + n_fft // 2 + 1, power, normalized, center, pad_mode, + apply_frame_processing=frame_proc, + ) + else: + spectrogram = _torch_stft( + waveform, window, frame_length, hop_length, n_fft, + power, normalized, center, pad_mode, + ) + + spectrogram = spectrogram.reshape(*leading_shape, spectrogram.shape[-2], spectrogram.shape[-1]) + if is_1d: + spectrogram = spectrogram.squeeze(0) + spectrogram = spectrogram.float() num_frequency_bins = spectrogram.shape[-2] mel_filters = mel_filter_bank_torch( @@ -481,7 +437,6 @@ def __init__( center: bool = True, pad_mode: str = "reflect", normalized: bool = False, - pad: int = 0, periodic: bool = True, n_mels: int = 128, f_min: float = 0.0, @@ -503,7 +458,6 @@ def __init__( self.center = center self.pad_mode = pad_mode self.normalized = normalized - self.pad = pad self.n_mels = n_mels self.f_min = f_min self.f_max = f_max if f_max is not None else sampling_rate / 2.0 @@ -516,14 +470,7 @@ def __init__( # Build window window = window_function(self.win_length, name=window_fn, periodic=periodic, wkwargs=wkwargs) - if self._needs_manual_framing and self.win_length < n_fft: - self._frame_length = self.win_length - else: - if self.win_length < n_fft: - left_pad = (n_fft - self.win_length) // 2 - right_pad = n_fft - self.win_length - left_pad - window = 
torch.nn.functional.pad(window, (left_pad, right_pad)) - self._frame_length = n_fft + window, self._frame_length = _prepare_window_and_framing(window, self.win_length, n_fft, self._needs_manual_framing) self.register_buffer("window", window) # Build mel filterbank @@ -551,15 +498,15 @@ def forward(self, waveform: torch.Tensor) -> torch.Tensor: leading_shape = waveform.shape[:-1] waveform = waveform.reshape(-1, waveform.shape[-1]) - if self.pad > 0: - waveform = torch.nn.functional.pad(waveform, (self.pad, self.pad)) - if self._needs_manual_framing: + frame_proc = lambda f: _apply_frame_processing( + f, dither=self.dither, preemphasis=self.preemphasis, remove_dc_offset=self.remove_dc_offset, + ) spec = _manual_stft( waveform, self.window, self._frame_length, self.hop_length, self.n_fft, self.n_fft // 2 + 1, self.power, self.normalized, - self.center, self.pad_mode, self.dither, self.preemphasis, - self.remove_dc_offset, + self.center, self.pad_mode, + apply_frame_processing=frame_proc, ) else: spec = _torch_stft( From 760ef655476d8e6b651cc930d276b3108d369feb Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Thu, 2 Apr 2026 16:17:23 +0200 Subject: [PATCH 23/28] udpates --- .gitignore | 4 + src/transformers/audio_processing_backends.py | 509 ++++++++++++++---- src/transformers/audio_processing_utils.py | 134 ++++- src/transformers/audio_utils.py | 4 +- ...rocessing_audio_spectrogram_transformer.py | 37 +- .../models/clap/audio_processing_clap.py | 260 +++------ .../models/clvp/audio_processing_clvp.py | 40 +- .../models/dac/audio_processing_dac.py | 7 - .../models/dia/audio_processing_dia.py | 7 - .../encodec/audio_processing_encodec.py | 7 - .../gemma3n/audio_processing_gemma3n.py | 28 +- .../audio_processing_granite_speech.py | 2 +- .../audio_processing_kyutai_speech_to_text.py | 41 +- .../models/lasr/audio_processing_lasr.py | 7 +- .../parakeet/audio_processing_parakeet.py | 49 +- .../audio_processing_phi4_multimodal.py | 145 +++-- .../audio_processing_speech_to_text.py | 2 +- .../univnet/audio_processing_univnet.py | 119 ++-- ...processing_vibevoice_acoustic_tokenizer.py | 6 - .../audio_processing_voxtral_realtime.py | 63 +-- .../whisper/audio_processing_whisper.py | 8 +- 21 files changed, 799 insertions(+), 680 deletions(-) diff --git a/.gitignore b/.gitignore index 75f5a9998310..6ed5479ab0c4 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,7 @@ tags # Cursor IDE files .cursor/ test-results/ +src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py +.gitignore +tests/test_wav2vec2_whisper.py +run_preprocessing_tests.sh diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 90b9b66b4c95..5f5795ddaa42 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -14,9 +14,11 @@ # limitations under the License. +import math + import numpy as np +import librosa -from . import numpy_mel_spectrogram as _np_spec from .audio_processing_utils import BaseAudioProcessor from .audio_utils import SpectrogramConfig, amplitude_to_db, power_to_db from .feature_extraction_utils import BatchFeature @@ -28,7 +30,93 @@ if is_torch_available(): import torch - from . 
import torch_mel_spectrogram as _torch_spec + + +# ── NumPy frequency conversion utilities ────────────────────────────── + +def _np_hertz_to_mel(freq, mel_scale="htk"): + if mel_scale == "htk": + return 2595.0 * np.log10(1.0 + (freq / 700.0)) + elif mel_scale == "kaldi": + return 1127.0 * np.log(1.0 + (freq / 700.0)) + # slaney + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = 27.0 / np.log(6.4) + mels = 3.0 * freq / 200.0 + if isinstance(freq, np.ndarray): + log_region = freq >= min_log_hertz + mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep + elif freq >= min_log_hertz: + mels = min_log_mel + np.log(freq / min_log_hertz) * logstep + return mels + + +def _np_mel_to_hertz(mels, mel_scale="htk"): + if mel_scale == "htk": + return 700.0 * (np.power(10, mels / 2595.0) - 1.0) + elif mel_scale == "kaldi": + return 700.0 * (np.exp(mels / 1127.0) - 1.0) + # slaney + min_log_hertz = 1000.0 + min_log_mel = 15.0 + logstep = np.log(6.4) / 27.0 + freq = 200.0 * mels / 3.0 + if isinstance(mels, np.ndarray): + log_region = mels >= min_log_mel + freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel)) + elif mels >= min_log_mel: + freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel)) + return freq + + +# ── Torch frequency conversion utilities ────────────────────────────── + +def _torch_hertz_to_mel_scalar(freq: float, mel_scale: str = "htk") -> float: + if mel_scale == "htk": + return 2595.0 * math.log10(1.0 + freq / 700.0) + elif mel_scale == "kaldi": + return 1127.0 * math.log(1.0 + freq / 700.0) + # slaney + f_sp = 200.0 / 3 + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - 0.0) / f_sp + logstep = math.log(6.4) / 27.0 + if freq >= min_log_hz: + return min_log_mel + math.log(freq / min_log_hz) / logstep + return (freq - 0.0) / f_sp + + +def _torch_hertz_to_mel(freq: "torch.Tensor", mel_scale: str = "htk") -> "torch.Tensor": + if mel_scale == "htk": + return 2595.0 * torch.log10(1.0 + freq / 700.0) + elif mel_scale == "kaldi": + return 1127.0 * torch.log(1.0 + freq / 700.0) + # slaney + f_sp = 200.0 / 3 + min_log_hertz = 1000.0 + min_log_mel = min_log_hertz / f_sp + logstep = 27.0 / torch.log(torch.tensor(6.4)) + mels = freq / f_sp + log_region = freq >= min_log_hertz + mels[log_region] = min_log_mel + torch.log(freq[log_region] / min_log_hertz) * logstep + return mels + + +def _torch_mel_to_hertz(mels: "torch.Tensor", mel_scale: str = "htk") -> "torch.Tensor": + if mel_scale == "htk": + return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + elif mel_scale == "kaldi": + return 700.0 * (torch.exp(mels / 1127.0) - 1.0) + # slaney + f_sp = 200.0 / 3 + min_log_hz = 1000.0 + min_log_mel = (min_log_hz - 0.0) / f_sp + logstep = math.log(6.4) / 27.0 + freq = 0.0 + f_sp * mels + log_region = mels >= min_log_mel + freq[log_region] = min_log_hz * torch.exp(logstep * (mels[log_region] - min_log_mel)) + return freq class NumpyAudioBackend(BaseAudioProcessor): @@ -80,38 +168,124 @@ def _pad_single(self, audio: np.ndarray, max_length: int) -> np.ndarray: return np.pad(audio, pad_width, mode="constant", constant_values=self.padding_value) - def _stft( - self, - audio: list[np.ndarray], - *, - spectrogram_config: SpectrogramConfig, - **kwargs, - ) -> list[np.ndarray]: - """Compute the (power) spectrogram via STFT using the numpy backend.""" - stft_cfg = spectrogram_config.stft_config - n_fft = stft_cfg.n_fft - win_length = stft_cfg.win_length or n_fft - hop_length = stft_cfg.hop_length or win_length // 2 + def _create_stft_window(self, 
win_length, stft_cfg, audio): + N = win_length + 1 if stft_cfg.periodic else win_length + fac = np.linspace(-np.pi, np.pi, N) + name = stft_cfg.window_fn + if name in ("hann", "hann_window"): + w = 0.5 + 0.5 * np.cos(fac) + elif name in ("hamming", "hamming_window"): + w = 0.54 + 0.46 * np.cos(fac) + elif name == "boxcar": + w = np.ones(N) + elif name == "povey": + w = (0.5 + 0.5 * np.cos(fac)) ** 0.85 + else: + raise ValueError(f"Unknown window function '{name}'") + return w[:win_length] if stft_cfg.periodic else w + + def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_framing): + if needs_manual_framing and win_length < n_fft: + frame_length = win_length + else: + if win_length < n_fft: + left_pad = (n_fft - win_length) // 2 + right_pad = n_fft - win_length - left_pad + window = np.pad(window, (left_pad, right_pad)) + frame_length = n_fft + return window, frame_length + + def _frame_waveform(self, waveform, frame_length, hop_length, n_fft, center, pad_mode): + squeezed = waveform.ndim == 1 + if squeezed: + waveform = waveform[np.newaxis, :] + if center: + start_k = int(np.ceil(n_fft // 2 / hop_length)) + tail_k = (waveform.shape[-1] + n_fft // 2 - n_fft) // hop_length + 1 + + if tail_k <= start_k: + waveform = np.pad(waveform, ((0, 0), (frame_length // 2, frame_length // 2)), mode=pad_mode) + num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length + frame_starts = np.arange(num_frames) * hop_length + frame_indices = frame_starts[:, np.newaxis] + np.arange(frame_length) + frames = waveform[:, frame_indices] + else: + padding = [(0, 0) for _ in range(waveform.ndim)] + padding[-1] = (frame_length // 2, 0) + y_pre = np.pad( + waveform[..., : (start_k - 1) * hop_length - n_fft // 2 + n_fft + 1], + padding, + mode=pad_mode, + ) + y_frames_pre = librosa.util.frame(y_pre, frame_length=frame_length, hop_length=hop_length) + y_frames_pre = y_frames_pre[..., :start_k] + y_frames_pre = np.moveaxis(y_frames_pre, -2, -1) + extra = y_frames_pre.shape[-2] + + padding[-1] = (0, frame_length // 2) + y_post = np.pad( + waveform[..., (tail_k) * hop_length - n_fft // 2 :], + padding, + mode=pad_mode, + ) + y_frames_post = librosa.util.frame(y_post, frame_length=frame_length, hop_length=hop_length) + y_frames_post = np.moveaxis(y_frames_post, -2, -1) + extra += y_frames_post.shape[-2] + + start = start_k * hop_length - n_fft // 2 + y_frames_middle = librosa.util.frame( + waveform[..., start:], frame_length=frame_length, hop_length=hop_length + ) + y_frames_middle = np.moveaxis(y_frames_middle, -2, -1) + + num_frames = y_frames_pre.shape[-2] + y_frames_middle.shape[-2] + y_frames_post.shape[-2] + frames = np.concatenate([y_frames_pre, y_frames_middle, y_frames_post], axis=-2) + else: + num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length + frame_starts = np.arange(num_frames) * hop_length + frame_indices = frame_starts[:, np.newaxis] + np.arange(frame_length) + frames = waveform[:, frame_indices] - window = _np_spec.window_function(win_length, name=stft_cfg.window_fn, periodic=stft_cfg.periodic) - needs_manual_framing = (spectrogram_config.preemphasis is not None) or spectrogram_config.remove_dc_offset - window, frame_length = _np_spec._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + if squeezed: + frames = frames.squeeze(0) + return frames, num_frames - frames, num_frames = _np_spec._frame_waveform(audio, frame_length, hop_length, n_fft, stft_cfg.center, stft_cfg.pad_mode) + def _frame_audio(self, audio, window, frame_length, 
hop_length, n_fft, stft_cfg): + frames, _ = self._frame_waveform(audio, frame_length, hop_length, n_fft, stft_cfg.center, stft_cfg.pad_mode) compute_dtype = np.result_type(audio.dtype, window.dtype) - frames = frames.astype(compute_dtype, copy=False) + return frames.astype(compute_dtype, copy=False) - frames = self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs) + def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg): + frames = frames * window + spec = np.fft.rfft(frames, n=n_fft, axis=-1).astype(np.complex64) + if stft_cfg.normalized: + spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype) + return np.moveaxis(spec, -1, -2) - return _np_spec._windowed_fft(frames, window, n_fft, stft_cfg.power, stft_cfg.normalized) + def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg): + frames, _ = self._frame_waveform(audio, frame_length, hop_length, n_fft, stft_cfg.center, stft_cfg.pad_mode) + compute_dtype = np.result_type(audio.dtype, window.dtype) + frames = frames.astype(compute_dtype, copy=False) + frames = frames * window + spec = np.fft.rfft(frames, n=n_fft, axis=-1).astype(np.complex64) + if stft_cfg.normalized: + spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype) + return np.moveaxis(spec, -1, -2) + + def _compute_magnitudes(self, stft_out, power): + return np.abs(stft_out, dtype=np.float64) ** power def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): """Apply per-frame signal conditioning using the numpy backend.""" - return _np_spec._apply_frame_processing( - frames, - preemphasis=spectrogram_config.preemphasis, - remove_dc_offset=spectrogram_config.remove_dc_offset, - ) + compute_dtype = frames.dtype + if spectrogram_config.remove_dc_offset: + frames = frames - frames.mean(axis=-1, keepdims=True) + preemphasis = spectrogram_config.preemphasis + if preemphasis is not None: + preemph_src = preemphasis * frames[..., :-1] + frames[..., 1:] = frames[..., 1:] - preemph_src + frames[..., 0] = frames[..., 0] * (1 - preemphasis) + return frames def _apply_mel_scale( self, @@ -121,7 +295,11 @@ def _apply_mel_scale( **kwargs, ) -> list[np.ndarray]: """Apply mel filterbank to spectrogram features using the numpy backend.""" - return _np_spec._apply_mel_scale(features, self.mel_filters, mel_floor=spectrogram_config.mel_floor) + if spectrogram_config.mel_scale_config.matmul_order == "features_first": + mel_spec = np.matmul(features, self.mel_filters) + else: + mel_spec = np.matmul(self.mel_filters.T, features) + return np.maximum(spectrogram_config.mel_floor, mel_spec) def _normalize_magnitude( self, @@ -168,17 +346,42 @@ def _normalize_magnitude( def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - return _np_spec.mel_filter_bank( - num_frequency_bins=1 + stft_cfg.n_fft // 2, - num_mel_filters=mel_cfg.n_mels, - min_frequency=mel_cfg.f_min, - max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, - sampling_rate=self.sample_rate, - norm=mel_cfg.norm, - mel_scale=mel_cfg.mel_scale, - triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, - frequency_bin_mode=mel_cfg.frequency_bin_mode, - ) + num_frequency_bins = 1 + stft_cfg.n_fft // 2 + num_mel_filters = mel_cfg.n_mels + min_frequency = mel_cfg.f_min + max_frequency = mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2 + sampling_rate = self.sample_rate + + mel_min = 
_np_hertz_to_mel(min_frequency, mel_scale=mel_cfg.mel_scale) + mel_max = _np_hertz_to_mel(max_frequency, mel_scale=mel_cfg.mel_scale) + mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2) + filter_freqs = _np_mel_to_hertz(mel_freqs, mel_scale=mel_cfg.mel_scale) + + n_fft = (num_frequency_bins - 1) * 2 + + if mel_cfg.triangularize_in_mel_space: + fft_bin_width = sampling_rate / n_fft + fft_freqs = _np_hertz_to_mel( + fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_cfg.mel_scale + ) + filter_freqs = mel_freqs + elif mel_cfg.frequency_bin_mode == "rfft": + fft_freqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate) + else: + fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins) + + # Triangular filter bank + filter_diff = np.diff(filter_freqs) + slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + mel_filters = np.maximum(0, np.minimum(down_slopes, up_slopes)) + + if mel_cfg.norm == "slaney": + enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]) + mel_filters *= np.expand_dims(enorm, 0) + + return mel_filters def _to_batch(self, audio): return np.stack(audio) @@ -251,66 +454,70 @@ def _pad_single(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": return F.pad(audio, pad_args, "constant", self.padding_value) def _needs_manual_framing(self, spectrogram_config): - """Whether the STFT requires manual framing (unfold-based) instead of torch.stft. - - Manual framing is needed when per-frame processing must happen between - frame extraction and windowing (e.g. per-frame preemphasis, DC offset removal, - or left-aligned FFT padding). - - Override in model-specific processors that handle preemphasis at the - waveform level (in ``_pre_stft``) and don't need per-frame processing. 
- """ - return ( - (spectrogram_config.preemphasis is not None) - or spectrogram_config.remove_dc_offset - or spectrogram_config.stft_config.left_align_fft - ) - - def _stft( - self, - audio: list["torch.Tensor"], # TODO: this can be either a audio or batch of audio and this should be documented - *, - spectrogram_config: SpectrogramConfig, - **kwargs, - ) -> list["torch.Tensor"]: - """Compute the (power) spectrogram via STFT using the torch backend.""" - stft_cfg = spectrogram_config.stft_config - computation_dtype = ( - getattr(torch, spectrogram_config.computation_dtype) - if spectrogram_config.computation_dtype - else None - ) - - n_fft = stft_cfg.n_fft - win_length = stft_cfg.win_length or n_fft - hop_length = stft_cfg.hop_length or win_length // 2 - - if computation_dtype is not None: - audio = audio.to(computation_dtype) - - needs_manual_framing = self._needs_manual_framing(spectrogram_config) - - window_wkwargs = {**(stft_cfg.wkwargs or {}), "dtype": audio.dtype} - window = _torch_spec.window_function(win_length, name=stft_cfg.window_fn, periodic=stft_cfg.periodic, wkwargs=window_wkwargs) - window = window.to(device=audio.device) - window, frame_length = _torch_spec._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + """Extends the base check with ``left_align_fft`` which also requires manual framing.""" + return super()._needs_manual_framing(spectrogram_config) or spectrogram_config.stft_config.left_align_fft + + def _create_stft_window(self, win_length, stft_cfg, audio): + dtype = getattr(torch, stft_cfg.window_dtype) if stft_cfg.window_dtype else audio.dtype + wkwargs = {**(stft_cfg.wkwargs or {}), "dtype": dtype} + name = stft_cfg.window_fn + if name in ("hann", "hann_window"): + window = torch.hann_window(win_length, periodic=stft_cfg.periodic, **wkwargs) + elif name in ("hamming", "hamming_window"): + window = torch.hamming_window(win_length, periodic=stft_cfg.periodic, **wkwargs) + elif name == "boxcar": + window = torch.ones(win_length) + elif name == "povey": + window = torch.hann_window(win_length, periodic=stft_cfg.periodic, **wkwargs).pow(0.85) + else: + raise ValueError(f"Unknown window function '{name}'") + return window.to(device=audio.device) - if needs_manual_framing: - apply_fp = lambda frames: self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs) - magnitudes = _torch_spec._manual_stft( - audio, window, frame_length, hop_length, n_fft, - n_fft // 2 + 1, stft_cfg.power, stft_cfg.normalized, - stft_cfg.center, stft_cfg.pad_mode, - apply_frame_processing=apply_fp, - ) + def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_framing): + if needs_manual_framing and win_length < n_fft: + frame_length = win_length else: - stft_out = _torch_spec._torch_stft( - audio, window, frame_length, hop_length, n_fft, - stft_cfg.normalized, stft_cfg.center, stft_cfg.pad_mode, + if win_length < n_fft: + left_pad = (n_fft - win_length) // 2 + right_pad = n_fft - win_length - left_pad + window = torch.nn.functional.pad(window, (left_pad, right_pad)) + frame_length = n_fft + return window, frame_length + + def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg): + if stft_cfg.center: + audio = torch.nn.functional.pad( + audio, (frame_length // 2, frame_length // 2), mode=stft_cfg.pad_mode ) - magnitudes = self._compute_magnitudes(stft_out, stft_cfg.power) + return audio.unfold(-1, frame_length, hop_length) + + def _window_and_fft(self, frames, window, frame_length, n_fft, 
stft_cfg): + frames = frames * window + if frame_length < n_fft: + frames = torch.nn.functional.pad(frames, (0, n_fft - frame_length)) + spec = torch.fft.rfft(frames, n=n_fft) + if stft_cfg.normalized: + spec = spec / window.pow(2.0).sum().sqrt() + return spec.transpose(-2, -1) + + def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg): + stft_out = torch.stft( + audio, + n_fft=n_fft, + hop_length=hop_length, + win_length=frame_length, + window=window, + center=stft_cfg.center, + pad_mode=stft_cfg.pad_mode, + normalized=False, + return_complex=True, + ) + if stft_cfg.normalized: + stft_out = stft_out / window.pow(2.0).sum().sqrt() + return stft_out - if computation_dtype is not None: + def _cast_stft_output(self, magnitudes, spectrogram_config): + if spectrogram_config.computation_dtype: return magnitudes return magnitudes.float() @@ -320,11 +527,15 @@ def _compute_magnitudes(self, stft_out, power): def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): """Apply per-frame signal conditioning using the torch backend.""" - return _torch_spec._apply_frame_processing( - frames, - preemphasis=spectrogram_config.preemphasis, - remove_dc_offset=spectrogram_config.remove_dc_offset, - ) + if spectrogram_config.remove_dc_offset: + frames = frames - frames.mean(dim=-1, keepdim=True) + preemphasis = spectrogram_config.preemphasis + if preemphasis is not None: + frames = torch.cat([ + frames[..., :1] * (1 - preemphasis), + frames[..., 1:] - preemphasis * frames[..., :-1], + ], dim=-1) + return frames def _apply_mel_scale( self, @@ -334,7 +545,12 @@ def _apply_mel_scale( **kwargs, ) -> list["torch.Tensor"]: """Apply mel filterbank to spectrogram features using the torch backend.""" - return _torch_spec._apply_mel_scale(features, self.mel_filters, mel_floor=spectrogram_config.mel_floor) + mel_filters = self.mel_filters.to(device=features.device) + if spectrogram_config.mel_scale_config.matmul_order == "features_first": + mel_spec = torch.matmul(features.transpose(-2, -1), mel_filters) + else: + mel_spec = torch.matmul(mel_filters.T, features) + return torch.clamp(mel_spec, min=spectrogram_config.mel_floor) def _normalize_magnitude( self, @@ -395,19 +611,86 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config computation_dtype = getattr(torch, mel_cfg.computation_dtype) if mel_cfg.computation_dtype else None - mel_filters = _torch_spec.mel_filter_bank_torch( - num_frequency_bins=1 + stft_cfg.n_fft // 2, - num_mel_filters=mel_cfg.n_mels, - min_frequency=mel_cfg.f_min, - max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, - sampling_rate=self.sample_rate, - norm=mel_cfg.norm, - mel_scale=mel_cfg.mel_scale, - triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, - frequency_bin_mode=mel_cfg.frequency_bin_mode, - computation_dtype=computation_dtype, - bands_to_zero=mel_cfg.bands_to_zero, - ) + num_frequency_bins = 1 + stft_cfg.n_fft // 2 + num_mel_filters = mel_cfg.n_mels + min_frequency = mel_cfg.f_min + max_frequency = mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2 + sampling_rate = self.sample_rate + + if mel_cfg.triangularize_in_mel_space and mel_cfg.bands_to_zero == 0: + # Kaldi-exact path: matches torchaudio.compliance.kaldi.get_mel_banks. 
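+            # Mel points are spaced uniformly between mel(f_min) and mel(f_max) with
+            # mel(f) = 1127 * ln(1 + f / 700); filter i rises from left_mel[i] to its
+            # peak at center_mel[i] and falls back to zero at right_mel[i], evaluated
+            # at the FFT bin centers, with the final (Nyquist) bin zero-padded.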
+ n_fft = (num_frequency_bins - 1) * 2 + num_fft_bins = n_fft // 2 + fft_bin_width = sampling_rate / n_fft + + mel_low = 1127.0 * math.log(1.0 + min_frequency / 700.0) + mel_high = 1127.0 * math.log(1.0 + max_frequency / 700.0) + mel_delta = (mel_high - mel_low) / (num_mel_filters + 1) + + bin_idx = torch.arange(num_mel_filters).unsqueeze(1) + left_mel = mel_low + bin_idx * mel_delta + center_mel = mel_low + (bin_idx + 1.0) * mel_delta + right_mel = mel_low + (bin_idx + 2.0) * mel_delta + + mel = 1127.0 * (1.0 + fft_bin_width * torch.arange(num_fft_bins) / 700.0).log() + mel = mel.unsqueeze(0) + + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + banks = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + banks = torch.nn.functional.pad(banks, (0, 1), mode="constant", value=0) + + mel_filters = banks.T + elif mel_cfg.triangularize_in_mel_space: + # Kaldi-style with bands_to_zero > 0 + n_fft = (num_frequency_bins - 1) * 2 + mel_min = _torch_hertz_to_mel_scalar(min_frequency, mel_scale=mel_cfg.mel_scale) + mel_max = _torch_hertz_to_mel_scalar(max_frequency, mel_scale=mel_cfg.mel_scale) + mel_delta = (mel_max - mel_min) / (num_mel_filters + 1) + bin_idx = torch.arange(num_mel_filters, dtype=computation_dtype).unsqueeze(1) + left_mel = mel_min + bin_idx * mel_delta + center_mel = mel_min + (bin_idx + 1.0) * mel_delta + right_mel = mel_min + (bin_idx + 2.0) * mel_delta + + fft_bin_width = sampling_rate / n_fft + hz_freqs = fft_bin_width * torch.arange(mel_cfg.bands_to_zero, num_frequency_bins, dtype=computation_dtype) + mel = _torch_hertz_to_mel(hz_freqs, mel_scale=mel_cfg.mel_scale).unsqueeze(0) + + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + mel_filters = torch.max(torch.zeros(1, dtype=computation_dtype), torch.min(up_slope, down_slope)) + + mel_filters = mel_filters.T + if mel_cfg.bands_to_zero > 0: + mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, mel_cfg.bands_to_zero, 0)) + else: + n_fft = (num_frequency_bins - 1) * 2 + mel_min = _torch_hertz_to_mel_scalar(min_frequency, mel_scale=mel_cfg.mel_scale) + mel_max = _torch_hertz_to_mel_scalar(max_frequency, mel_scale=mel_cfg.mel_scale) + mel_freqs = torch.linspace(mel_min, mel_max, num_mel_filters + 2, dtype=computation_dtype) + filter_freqs = _torch_mel_to_hertz(mel_freqs, mel_scale=mel_cfg.mel_scale) + + if mel_cfg.frequency_bin_mode == "rfft": + fft_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate) + else: + fft_freqs = torch.linspace(0, sampling_rate // 2, num_frequency_bins) + if computation_dtype is not None: + fft_freqs = fft_freqs.to(computation_dtype) + + # Triangular filter bank + filter_diff = filter_freqs[1:] - filter_freqs[:-1] + slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + mel_filters = torch.clamp(torch.minimum(down_slopes, up_slopes), min=0) + + if mel_cfg.norm == "slaney": + enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]) + mel_filters = mel_filters * enorm.unsqueeze(0) + + if mel_cfg.bands_to_zero > 0: + mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, mel_cfg.bands_to_zero, 0)) + # When computation_dtype is set only on the mel config (not on the # spectrogram config), the filters were computed in high precision for # accuracy but the spectrogram will be in the default dtype — cast back. 
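A minimal sketch (hypothetical values) of how the ``MelScaleConfig`` fields select the three
filter-bank branches above:

    # triangularize_in_mel_space=True, bands_to_zero == 0 -> kaldi-exact banks
    # triangularize_in_mel_space=True, bands_to_zero  > 0 -> kaldi-style banks, low bins zeroed
    # triangularize_in_mel_space=False                     -> triangular banks in hertz space
    #                                                         (htk/kaldi/slaney scale, optional slaney norm)
    mel_cfg = MelScaleConfig(
        n_mels=80,
        f_min=20.0,
        f_max=7600.0,
        mel_scale="kaldi",
        triangularize_in_mel_space=True,
        bands_to_zero=0,
    )
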
diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 94eb292de95c..c1110bb993de 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -60,7 +60,7 @@ class BaseAudioProcessor(AudioProcessingMixin): truncation = None pad_to_multiple_of = None - return_attention_mask = True # TODO: we should either get a more appropriate name, either always return input mask + return_padding_mask = True spectrogram_config = None do_extract_spectrogram = None @@ -157,16 +157,15 @@ def _preprocess( # pad and truncate audio, audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) padded_length = audio[0].shape[-1] - self._audio_lengths = [end - start for start, end in audio_ranges] if do_extract_spectrogram: audio = self._to_batch(audio) if do_batch_spectrogram else audio - feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config) + feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, audio_ranges=audio_ranges) output = {"audio_features": feature} else: output = {"audio_values": self._to_batch(audio)} - if self.return_attention_mask: + if self.return_padding_mask: output.update(self._get_mask( audio_ranges, padded_length, do_extract_spectrogram=do_extract_spectrogram, spectrogram_config=spectrogram_config )) @@ -329,48 +328,127 @@ def extract_spectrogram(self, audio, *, spectrogram_config: SpectrogramConfig | # The full feature-extraction pipeline executed by `extract_spectrogram`: # # 1. _extract_spectrogram (STFT → power/magnitude spectrogram) - # a. _pre_stft – waveform-level pre-processing (hook, no-op by default) - # b. _prepare_window_and_framing – build/pad window, decide frame length - # c. _frame_waveform – slice waveform into overlapping frames - # d. _apply_frame_processing – per-frame conditioning: dither, DC offset, preemphasis (hook) - # e. windowing + FFT + power + # a. _stft – orchestrates steps b–g (overridable for fully custom STFTs) + # b. _needs_manual_framing – decide framing strategy (hook) + # c. _create_stft_window – create the STFT window (backend) + # d. _prepare_window_and_framing– pad/reshape window, decide frame length (backend) + # e. manual path (needs_manual_framing=True): + # _frame_audio – center pad + frame extraction (backend) + # _apply_frame_processing – per-frame conditioning (hook) + # _window_and_fft – window + zero-pad + FFT + normalize → complex (backend) + # native path (needs_manual_framing=False): + # _native_stft – native STFT returning complex output (backend) + # f. _compute_magnitudes – complex → real magnitudes (backend, shared by both paths) + # g. _cast_stft_output – cast output dtype (hook, no-op by default) # 2. _apply_mel_scale (mel filterbank projection) # 3. _normalize_magnitude (log / dB scaling, optional per-utterance norm) # # Backend subclasses (NumpyAudioBackend, TorchAudioBackend) implement the # full pipeline. Model-specific processors can override individual hooks - # (_pre_stft, _apply_frame_processing) or the entire _extract_spectrogram - # when the base STFT path is insufficient (e.g., Parakeet's custom magnitude - # computation). + # (_apply_frame_processing) or the entire _stft when the base STFT path + # is insufficient. 
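+    #
+    # As an illustration (hypothetical model name), a processor that only needs a
+    # custom magnitude computation can override that single hook and inherit the
+    # rest of the pipeline, e.g.:
+    #
+    #     class MyModelAudioProcessor(TorchAudioBackend):
+    #         def _compute_magnitudes(self, stft_out, power):
+    #             return stft_out.abs() ** power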
+ # + # ``audio_ranges`` is passed through as a kwarg from ``_preprocess`` so that + # model-specific overrides (e.g., Parakeet waveform-level preemphasis, + # Phi4 boundary masking) can access original audio lengths without stashing + # state on ``self``. def _extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): """Orchestrate the STFT pipeline. Runs the sub-steps listed above in order. Override this only when the - pipeline ordering itself needs to change (e.g., Parakeet needs audio-length - detection before ``_pre_stft``). Otherwise, override individual hooks. + pipeline ordering itself needs to change. Otherwise, override individual hooks. """ - audio = self._pre_stft(audio, spectrogram_config=spectrogram_config, **kwargs) return self._stft(audio, spectrogram_config=spectrogram_config, **kwargs) - def _pre_stft(self, audio, *, spectrogram_config, **kwargs): - """Hook: waveform-level pre-processing before STFT. + def _stft(self, audio, *, spectrogram_config, **kwargs): + """Compute the STFT and return a power/magnitude spectrogram. - Called before framing. Default: no-op (returns audio unchanged). - Override for processing on the full waveform, e.g. length-aware - preemphasis with masking (Parakeet). + Orchestrates the sub-steps listed in the pipeline documentation above. + Backend subclasses implement the individual leaf methods; model-specific + processors can override this entirely for a fully custom STFT + (e.g., Gemma3n's unfold-based STFT with extra-sample framing). """ - return audio + stft_cfg = spectrogram_config.stft_config + n_fft = stft_cfg.n_fft + win_length = stft_cfg.win_length or n_fft + hop_length = stft_cfg.hop_length or win_length // 2 + needs_manual_framing = self._needs_manual_framing(spectrogram_config) + + if spectrogram_config.computation_dtype: + dtype_str = spectrogram_config.computation_dtype + if isinstance(audio, np.ndarray): + audio = audio.astype(dtype_str) + else: + import torch + audio = audio.to(getattr(torch, dtype_str)) + window = self._create_stft_window(win_length, stft_cfg, audio) + window, frame_length = self._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + + if needs_manual_framing: + frames = self._frame_audio(audio, window, frame_length, hop_length, n_fft, stft_cfg) + frames = self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs) + stft_out = self._window_and_fft(frames, window, frame_length, n_fft, stft_cfg) + else: + stft_out = self._native_stft(audio, window, frame_length, hop_length, n_fft, stft_cfg) - def _stft(self, audio, *, spectrogram_config, **kwargs): - """Compute the STFT and return a power/magnitude spectrogram. + magnitudes = self._compute_magnitudes(stft_out, stft_cfg.power) + return self._cast_stft_output(magnitudes, spectrogram_config) - Implemented by backend subclasses. Internally runs: - window creation → padding → framing → ``_apply_frame_processing`` → - windowing → FFT → power. + def _create_stft_window(self, win_length, stft_cfg, audio): + """Create the STFT window. Implemented by backend subclasses.""" + raise NotImplementedError - Override in model-specific processors that need a fully custom STFT - (e.g., Gemma3n's unfold-based STFT with extra-sample framing). + def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_framing): + """Pad/reshape window and determine frame length. 
Implemented by backend subclasses.""" + raise NotImplementedError + + def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg): + """Extract overlapping frames from the audio signal. + + Handles center padding and dtype promotion. Returns frames of shape + (..., num_frames, frame_length). Implemented by backend subclasses. + """ + raise NotImplementedError + + def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg): + """Apply window, zero-pad, FFT, and normalize. Returns complex STFT of shape (..., freq, time). + Implemented by backend subclasses.""" + raise NotImplementedError + + def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg): + """Native STFT (e.g. torch.stft). Returns complex output. Implemented by backend subclasses.""" + raise NotImplementedError + + def _compute_magnitudes(self, stft_out, power): + """Convert complex STFT output to a real-valued magnitude spectrogram. + Implemented by backend subclasses. Overridable for custom magnitude computation (e.g. Parakeet).""" + raise NotImplementedError + + def _cast_stft_output(self, magnitudes, spectrogram_config): + """Cast STFT output to the desired output dtype. Default: no-op.""" + return magnitudes + + def _needs_manual_framing(self, spectrogram_config): + """Whether the STFT requires manual framing (unfold-based) instead of a native STFT. + + Manual framing is needed when per-frame processing must happen between + frame extraction and windowing (e.g. per-frame preemphasis, DC offset removal, + or left-aligned FFT padding). + + Override in model-specific processors that handle preemphasis at the + waveform level (in ``_stft``) and don't need per-frame processing. + """ + return ( + (spectrogram_config.preemphasis is not None) + or spectrogram_config.remove_dc_offset + ) + + def _compute_magnitudes(self, stft_out, power): + """Convert complex STFT output to a real-valued magnitude spectrogram. + + Only used in the non-manual-framing STFT path. Override for + non-standard magnitude computation (e.g. Parakeet's view_as_real path). 
""" raise NotImplementedError diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index ff6f9a5788a2..8a16df98207e 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -78,6 +78,7 @@ class StftConfig: onesided: bool | None = None periodic: bool = True left_align_fft: bool = False + window_dtype: str | None = None def to_dict(self) -> dict: return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} @@ -105,6 +106,7 @@ class MelScaleConfig: frequency_bin_mode: str = "rfft" computation_dtype: str | None = None bands_to_zero: int = 0 + matmul_order: str = "filters_first" def to_dict(self) -> dict: return {f.name: getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None} @@ -123,7 +125,6 @@ class SpectrogramConfig: mel_scale_config: MelScaleConfig | None = None log_mode: str = "log10" chunk_length: int | None = None - global_log_mel_max: float | None = None preemphasis: float | None = None remove_dc_offset: bool = False mel_floor: float = 1e-10 @@ -165,7 +166,6 @@ def from_dict(cls, d: dict) -> "SpectrogramConfig": mel_scale_config=mel_scale_config, log_mode=d.get("log_mode", "log10"), chunk_length=d.get("chunk_length"), - global_log_mel_max=d.get("global_log_mel_max"), preemphasis=d.get("preemphasis"), remove_dc_offset=d.get("remove_dc_offset", False), mel_floor=d.get("mel_floor", 1e-10), diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py index c7dce5ad462d..ca6cb558e542 100644 --- a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py @@ -23,7 +23,8 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend if not is_torch_available() else TorchAudioBackend): sample_rate = 16000 force_mono = True - return_attention_mask = False + return_padding_mask = False + padding = False max_length_frames = 1024 do_normalize = True @@ -55,18 +56,22 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend if not is_torc mel_floor=1.192092955078125e-07, ) - def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): - if isinstance(audio, np.ndarray) and audio.ndim > 1: - audio = [audio[i] for i in range(audio.shape[0])] - elif hasattr(audio, 'dim') and audio.dim() > 1: - audio = [audio[i] for i in range(audio.shape[0])] - elif not isinstance(audio, list): - audio = [audio] + def _preprocess( + self, + audio, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + spectrogram_config=None, + do_extract_spectrogram=None, + **kwargs, + ): + # Extract mel spectrogram features from raw audio using the base spectrogram pipeline + features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) - if spectrogram_config is None: - spectrogram_config = self.spectrogram_config - features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) - # (n_mels, frames) -> (frames, n_mels) + # extract_spectrogram returns list of (n_mels, frames); transpose to (frames, n_mels) features = [f.T for f in features] # Pad or truncate to max_length_frames @@ -84,12 +89,8 @@ def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): if self.do_normalize: 
padded = [(f - self.ast_mean) / (self.ast_std * 2) for f in padded] - return np.stack(padded, axis=0) - - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # AST does all processing at the feature level in extract_spectrogram - features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) - return BatchFeature({"audio_values": features}, tensor_type=return_tensors) + stacked = np.stack(padded, axis=0) + return BatchFeature({"audio_values": stacked}, tensor_type=return_tensors) __all__ = ["AudioSpectrogramTransformerAudioProcessor"] diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py index d72ebf972457..e18b8e542007 100644 --- a/src/transformers/models/clap/audio_processing_clap.py +++ b/src/transformers/models/clap/audio_processing_clap.py @@ -15,102 +15,83 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function -from ...feature_extraction_utils import BatchFeature +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig +from ...utils import PaddingStrategy class ClapAudioProcessor(NumpyAudioBackend): sample_rate = 48000 force_mono = True - max_length_s = 10 + max_length = 480000 truncation_mode = "rand_trunc" # "fusion" or "rand_trunc" - padding_mode = "repeatpad" # "repeatpad", "repeat", or "pad" - - spectrogram_config = SpectrogramConfig( - stft_config=StftConfig( - n_fft=1024, - hop_length=480, - power=2.0, - ), - mel_scale_config=MelScaleConfig( - n_mels=64, - f_min=50, - f_max=14000, - mel_scale="slaney", - norm="slaney", - ), - log_mode="dB", - ) - - # Fusion mode uses a different mel filter bank (htk scale, no norm) - spectrogram_config_fusion = SpectrogramConfig( - stft_config=StftConfig( - n_fft=1024, - hop_length=480, - power=2.0, - ), - mel_scale_config=MelScaleConfig( - n_mels=64, - f_min=0, - f_max=14000, - mel_scale="htk", - ), - log_mode="dB", - ) + + _mel_configs = { + "rand_trunc": MelScaleConfig(n_mels=64, f_min=50, f_max=14000, mel_scale="slaney", norm="slaney", frequency_bin_mode="linspace"), + "fusion": MelScaleConfig(n_mels=64, f_min=50, f_max=14000, mel_scale="htk", frequency_bin_mode="linspace"), + } def __init__(self, **kwargs): + truncation_mode = kwargs.pop("truncation_mode", self.truncation_mode) + self.truncation_mode = truncation_mode + self.spectrogram_config = SpectrogramConfig( + stft_config=StftConfig(n_fft=1024, hop_length=480, power=2.0), + mel_scale_config=self._mel_configs[truncation_mode], + log_mode="dB", + ) super().__init__(**kwargs) - self.nb_max_samples = self.max_length_s * self.sample_rate - self.mel_filters_fusion = self._mel_filter_bank(self.spectrogram_config_fusion) - - def _pad_single_clap(self, audio: np.ndarray, max_length: int, padding_mode: str) -> np.ndarray: - """ - CLAP-specific padding: handles "repeat" and "repeatpad" modes. - This is separate from the standard _pad_single used by the base class. 
- """ - current_length = audio.shape[-1] - if current_length >= max_length: - return audio - - if padding_mode == "repeat": - # Repeat the audio enough times to cover max_length - n_repeat = int(max_length / current_length) - audio = np.tile(audio, n_repeat + 1)[:max_length] - return audio - elif padding_mode == "repeatpad": - # Repeat then pad with zeros - n_repeat = int(max_length / current_length) - audio = np.tile(audio, n_repeat) - remaining = max_length - audio.shape[-1] - if remaining > 0: - audio = np.pad(audio, (0, remaining), mode="constant", constant_values=0) - return audio - else: - # For other modes, use standard padding via parent's _pad_single - return super()._pad_single(audio, max_length) - - def _extract_single_mel(self, waveform, spectrogram_config=None): - """Extract mel spectrogram for a single waveform using audio_utils.spectrogram.""" - if spectrogram_config is None: - spectrogram_config = self.spectrogram_config - stft_cfg = spectrogram_config.stft_config - - # Use the correct mel filters for this config - if spectrogram_config is self.spectrogram_config_fusion: - mel_filters = self.mel_filters_fusion + # rand_trunc: base class truncates via pad() → _truncate_single (random offset) + # fusion: no pre-truncation; full mel is extracted then chunked + self.truncation = truncation_mode == "rand_trunc" + + def _get_padding_strategies(self, padding=False, max_length=None): + # CLAP always pads to max_length, not to the longest in the batch + if padding is True and max_length is not None: + return PaddingStrategy.MAX_LENGTH + return super()._get_padding_strategies(padding=padding, max_length=max_length) + + def pad(self, audio, *args, **kwargs): + self._is_longer_flags = [] + return super().pad(audio, *args, **kwargs) + + def _truncate_single(self, audio_el, max_length): + """Random-offset truncation for rand_trunc mode, also tracks which samples were longer.""" + self._is_longer_flags.append(audio_el.shape[-1] > max_length) + if audio_el.shape[-1] > max_length: + idx = np.random.randint(0, audio_el.shape[-1] - max_length + 1) + return audio_el[..., idx : idx + max_length] + return audio_el + + def extract_spectrogram(self, audio, *, spectrogram_config=None, audio_ranges=None, **kwargs): + """Extract mel spectrogram and shape output (1 view for rand_trunc, 4 for fusion).""" + is_fusion = self.truncation_mode == "fusion" + chunk_frames = self.max_length // self.spectrogram_config.stft_config.hop_length + 1 + + if isinstance(audio, np.ndarray) and audio.ndim == 2: + waveforms = list(audio) + elif isinstance(audio, np.ndarray) and audio.ndim == 1: + waveforms = [audio] else: - mel_filters = self.mel_filters - - log_mel_spectrogram = spectrogram( - waveform, - window_function(stft_cfg.n_fft, "hann"), - frame_length=stft_cfg.n_fft, - hop_length=stft_cfg.hop_length, - power=2.0, - mel_filters=mel_filters, - log_mel="dB", - ) - return log_mel_spectrogram.T + waveforms = audio + + mels = [] + is_longer = [] + for waveform in waveforms: + mel = super().extract_spectrogram(waveform, spectrogram_config=self.spectrogram_config).T # (time, n_mels) + total_frames = mel.shape[0] + + if is_fusion and total_frames > chunk_frames: + mels.append(self._random_mel_fusion(mel, total_frames, chunk_frames)) + is_longer.append(True) + elif is_fusion: + mels.append(np.stack([mel, mel, mel, mel], axis=0)) + is_longer.append(False) + else: + mels.append(mel[np.newaxis]) + is_longer.append(False) + + if is_fusion: + self._is_longer_flags = is_longer + return mels def _random_mel_fusion(self, mel, 
total_frames, chunk_frames): import torch @@ -135,108 +116,13 @@ def _random_mel_fusion(self, mel, total_frames, chunk_frames): mel_shrink = mel_shrink[0][0].numpy() return np.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], axis=0) - def _get_input_mel(self, waveform, max_length, truncation): - hop_length = self.spectrogram_config.stft_config.hop_length - - if waveform.shape[0] > max_length: - if truncation == "rand_trunc": - longer = True - overflow = len(waveform) - max_length - idx = np.random.randint(0, overflow + 1) - waveform = waveform[idx : idx + max_length] - input_mel = self._extract_single_mel(waveform)[None, :] - elif truncation == "fusion": - mel = self._extract_single_mel(waveform, spectrogram_config=self.spectrogram_config_fusion) - chunk_frames = max_length // hop_length + 1 - total_frames = mel.shape[0] - if chunk_frames == total_frames: - input_mel = np.stack([mel, mel, mel, mel], axis=0) - longer = False - else: - input_mel = self._random_mel_fusion(mel, total_frames, chunk_frames) - longer = True - else: - raise NotImplementedError(f"data_truncating {truncation} not implemented") - else: - longer = False - if truncation == "fusion": - input_mel = self._extract_single_mel(waveform, spectrogram_config=self.spectrogram_config_fusion) - input_mel = np.stack([input_mel, input_mel, input_mel, input_mel], axis=0) - else: - input_mel = self._extract_single_mel(waveform)[None, :] - - return input_mel, longer - - def _preprocess( - self, - audio, - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - spectrogram_config=None, - do_extract_spectrogram=None, - **kwargs, - ): - # Use instance defaults when not explicitly provided (matching feature extractor behavior) - truncation_mode = self.truncation_mode if truncation is None else truncation - # For padding: use instance default only when not provided (None or False) - # When padding=True is passed, use it directly (feature extractor behavior) - if padding is None or padding is False: - padding_mode = self.padding_mode - else: - padding_mode = padding - nb_max_samples = max_length if isinstance(max_length, int) and max_length > 0 else self.nb_max_samples - - # Handle truncation: only apply if boolean truncation=True OR if using CLAP-specific string modes - # Note: CLAP's _get_input_mel handles truncation internally based on truncation_mode - # We only do pre-truncation here for standard boolean truncation=True case - if truncation is True: - if nb_max_samples is None: - raise ValueError("When setting `truncation=True`, make sure that `max_length` is defined.") - trunc_length = nb_max_samples - if pad_to_multiple_of is not None and (trunc_length % pad_to_multiple_of != 0): - trunc_length = ((trunc_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - audio = [self._truncate_single(audio_el, max_length=trunc_length) for audio_el in audio] - - # Handle padding: CLAP-specific modes ("repeat", "repeatpad") vs standard modes - if padding_mode in ("repeat", "repeatpad"): - # Use CLAP's custom _pad_single_clap which handles repeat/repeatpad - audio = [self._pad_single_clap(audio_el, max_length=nb_max_samples, padding_mode=padding_mode) for audio_el in audio] - elif padding is not False and padding_mode is not False: - # Use standard padding flow for "longest", "max_length", True, etc. 
- from ...utils import PaddingStrategy - if padding_mode is True and nb_max_samples is not None: - # When padding=True and we have a max length, use MAX_LENGTH strategy - # (matching feature extractor behavior that pads to max_length) - padding_strategy = PaddingStrategy.MAX_LENGTH - elif isinstance(padding_mode, str) and padding_mode not in ("longest", "max_length", "do_not_pad"): - padding_strategy = PaddingStrategy.LONGEST # Default to longest for unknown string values - else: - padding_strategy = padding_mode - audio, _audio_ranges = self.pad(audio, padding_strategy, nb_max_samples, truncation=False, pad_to_multiple_of=pad_to_multiple_of) - - # Process each waveform through CLAP's mel extraction (handles truncation internally) - padded_inputs = [ - self._get_input_mel(np.squeeze(waveform), nb_max_samples, truncation_mode) - for waveform in audio - ] - - input_mel = [] - is_longer = [] - for mel, longer in padded_inputs: - input_mel.append(mel) - is_longer.append(longer) - - if truncation_mode == "fusion" and sum(is_longer) == 0: - rand_idx = np.random.randint(0, len(input_mel)) + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + """Return CLAP's is_longer flag instead of a standard attention mask.""" + is_longer = getattr(self, "_is_longer_flags", None) or [False] * len(audio_ranges) + if self.truncation_mode == "fusion" and sum(is_longer) == 0: + rand_idx = np.random.randint(0, len(is_longer)) is_longer[rand_idx] = True - - is_longer = [[longer] for longer in is_longer] - - input_features = {"audio_features": input_mel, "is_longer": is_longer} - return BatchFeature(input_features, tensor_type=return_tensors) + return {"is_longer": [[longer] for longer in is_longer]} __all__ = ["ClapAudioProcessor"] diff --git a/src/transformers/models/clvp/audio_processing_clvp.py b/src/transformers/models/clvp/audio_processing_clvp.py index bb2d134fa13f..503c606ab2df 100644 --- a/src/transformers/models/clvp/audio_processing_clvp.py +++ b/src/transformers/models/clvp/audio_processing_clvp.py @@ -15,7 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig class ClvpAudioProcessor(NumpyAudioBackend): @@ -47,36 +47,14 @@ def __init__(self, mel_norms=None, **kwargs): super().__init__(**kwargs) self.mel_norms = mel_norms - def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): - if spectrogram_config is None: - spectrogram_config = self.spectrogram_config - - if isinstance(audio, np.ndarray) and audio.ndim > 1: - audio = [audio[i] for i in range(audio.shape[0])] - elif not isinstance(audio, list): - audio = [audio] - - stft_cfg = spectrogram_config.stft_config - features = [] - for waveform in audio: - waveform = np.squeeze(waveform) - log_spec = spectrogram( - waveform, - window_function(stft_cfg.n_fft, "hann"), - frame_length=stft_cfg.n_fft, - hop_length=stft_cfg.hop_length, - power=2.0, - mel_filters=self.mel_filters, - log_mel=None, - ) - log_spec = np.log(np.clip(log_spec, a_min=1e-5, a_max=None)) - - if self.mel_norms is not None: - log_spec = log_spec / np.array(self.mel_norms)[:, None] - - features.append(log_spec.astype(np.float32)) - - return np.stack(features, axis=0) if len(features) > 1 else features + def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): + # Compute log and mel_norms division in 
float64 before casting to float32 + # to match the legacy feature extractor's precision + mel_floor = spectrogram_config.mel_floor + features = np.log(np.maximum(mel_floor, features)) + if self.mel_norms is not None: + features = features / np.array(self.mel_norms)[:, None] + return features.astype(np.float32) def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): """CLVP uses raw-audio-level mask even for spectrogram output.""" diff --git a/src/transformers/models/dac/audio_processing_dac.py b/src/transformers/models/dac/audio_processing_dac.py index 80a8590c8c54..2f8548eee900 100644 --- a/src/transformers/models/dac/audio_processing_dac.py +++ b/src/transformers/models/dac/audio_processing_dac.py @@ -20,16 +20,9 @@ class DacAudioProcessor(NumpyAudioBackend): sample_rate = 16000 force_mono = True - add_channel_dim = True def _to_batch(self, audio): return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) - for i, (start, end) in enumerate(audio_ranges): - mask[i, start:end] = 1 - return {"audio_values_mask": mask} - __all__ = ["DacAudioProcessor"] diff --git a/src/transformers/models/dia/audio_processing_dia.py b/src/transformers/models/dia/audio_processing_dia.py index 5766acd746b5..9a2fbac08954 100644 --- a/src/transformers/models/dia/audio_processing_dia.py +++ b/src/transformers/models/dia/audio_processing_dia.py @@ -20,17 +20,10 @@ class DiaAudioProcessor(NumpyAudioBackend): sample_rate = 44100 force_mono = True - add_channel_dim = True pad_to_multiple_of = 512 def _to_batch(self, audio): return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) - for i, (start, end) in enumerate(audio_ranges): - mask[i, start:end] = 1 - return {"audio_values_mask": mask} - __all__ = ["DiaAudioProcessor"] diff --git a/src/transformers/models/encodec/audio_processing_encodec.py b/src/transformers/models/encodec/audio_processing_encodec.py index 4208cc5c1ec8..f52dedae59ec 100644 --- a/src/transformers/models/encodec/audio_processing_encodec.py +++ b/src/transformers/models/encodec/audio_processing_encodec.py @@ -20,16 +20,9 @@ class EncodecAudioProcessor(NumpyAudioBackend): sample_rate = 24000 force_mono = True - add_channel_dim = True def _to_batch(self, audio): return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) - for i, (start, end) in enumerate(audio_ranges): - mask[i, start:end] = 1 - return {"audio_values_mask": mask} - __all__ = ["EncodecAudioProcessor"] diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py index 7ea42246d746..e4359bcd62c3 100644 --- a/src/transformers/models/gemma3n/audio_processing_gemma3n.py +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -53,6 +53,7 @@ class Gemma3nAudioProcessor(NumpyAudioBackend): f_min=125.0, f_max=7600.0, mel_scale="htk", + matmul_order="features_first", ), mel_floor=1e-5, log_mode="log", @@ -62,7 +63,7 @@ class Gemma3nAudioProcessor(NumpyAudioBackend): def __init__(self, per_bin_mean=None, 
per_bin_stddev=None, **kwargs): super().__init__(**kwargs) - # Pre-compute window from stft_config + # Pre-compute window in float32 to match the upstream FE exactly win_length = self.spectrogram_config.stft_config.win_length hann_arange = np.arange(win_length, dtype=np.float32) self.window = (0.5 * (1 - np.cos(2 * np.pi * hann_arange / win_length))).astype(np.float32) @@ -91,6 +92,11 @@ def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): return frames[..., :-1] def _stft(self, audio, *, spectrogram_config, **kwargs): + """Unfold-based STFT with extra-sample framing for HTK preemphasis. + + Extracts frames of win_length+1 so that _apply_frame_processing can + reduce them to win_length after HTK preemphasis. Returns (batch, time, freq). + """ stft_cfg = spectrogram_config.stft_config frame_size_for_unfold = stft_cfg.win_length + 1 @@ -102,11 +108,6 @@ def _stft(self, audio, *, spectrogram_config, **kwargs): stft = np.fft.rfft(frames, n=stft_cfg.n_fft, axis=-1) return np.abs(stft) - def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): - """Apply mel filterbank. Features are in (batch, time, freq) format.""" - mel_spec = np.matmul(features, self.mel_filters) - return np.maximum(spectrogram_config.mel_floor, mel_spec) - def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): """Apply log compression and per-bin normalization.""" result = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) @@ -119,10 +120,19 @@ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): return result.astype(np.float32) def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): - """Frame count for unfold-based STFT (no centering).""" + """Frame count matching the FE's downsampled attention mask approach. + + The upstream FE computes the mask by slicing the sample-level attention + mask every hop_length steps, which yields ceil(audio_length / hop_length) + valid frames rather than the unfold-based count. 
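+
+        For example (illustrative numbers only, assuming hop_length=160 and
+        win_length=480): a 16000-sample input yields ceil(16000 / 160) = 100
+        valid frames here, whereas the unfold formula with
+        frame_size = win_length + 1 = 481 would give (16000 - 481) // 160 + 1 = 97.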
+ """ hop_length = spectrogram_config.stft_config.hop_length - frame_size = spectrogram_config.stft_config.win_length + 1 - return (audio_lengths - frame_size) // hop_length + 1 + if include_center_frame: + # For padded length we still use the unfold formula to get total frames + frame_size = spectrogram_config.stft_config.win_length + 1 + return (audio_lengths - frame_size) // hop_length + 1 + # Match FE: attention_mask[::hop_length] gives this many valid entries + return (audio_lengths + hop_length - 1) // hop_length __all__ = ["Gemma3nAudioProcessor"] diff --git a/src/transformers/models/granite_speech/audio_processing_granite_speech.py b/src/transformers/models/granite_speech/audio_processing_granite_speech.py index 099b0bcedb48..3d68a8dc60c5 100644 --- a/src/transformers/models/granite_speech/audio_processing_granite_speech.py +++ b/src/transformers/models/granite_speech/audio_processing_granite_speech.py @@ -94,7 +94,7 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of data = { "audio_features": features, "audio_embed_sizes": audio_embed_sizes, - "input_features_mask": input_features_mask, + "audio_features_mask": input_features_mask, } return BatchFeature(data=data, tensor_type=return_tensors) diff --git a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py index 09713556577b..3cbe3782de24 100644 --- a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py @@ -15,46 +15,37 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...feature_extraction_utils import BatchFeature class KyutaiSpeechToTextAudioProcessor(NumpyAudioBackend): sample_rate = 24000 force_mono = True - add_channel_dim = True - - def __init__(self, audio_delay_seconds=2.5, audio_silence_prefix_seconds=1.0, **kwargs): - self.audio_delay_seconds = audio_delay_seconds - self.audio_silence_prefix_seconds = audio_silence_prefix_seconds - super().__init__(**kwargs) + audio_delay_seconds = 2.5 + audio_silence_prefix_seconds = 1.0 def _to_batch(self, audio): - return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - - def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) - for i, (start, end) in enumerate(audio_ranges): - mask[i, start:end] = 1 - return {"audio_values_mask": mask} + return np.stack(audio)[:, np.newaxis, :] def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - # Pad audio to batch longest - audio, audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - padded_length = audio[0].shape[-1] - - stacked = self._to_batch(audio) - mask_dict = self._get_mask(audio_ranges, padded_length, do_extract_spectrogram=False, spectrogram_config=None) - audio_values_mask = mask_dict["audio_values_mask"] + kwargs.pop("do_extract_spectrogram", None) + result = super()._preprocess( + audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, + do_extract_spectrogram=False, **kwargs, + ) # Add silence prefix (left) and delay (right) padding pad_left = int(self.audio_silence_prefix_seconds * self.sample_rate) pad_right = int((self.audio_delay_seconds + 1.0) * self.sample_rate) if pad_left > 0 or 
pad_right > 0: - stacked = np.pad(stacked, [(0, 0), (0, 0), (pad_left, pad_right)], mode="constant", constant_values=0.0) - audio_values_mask = np.pad(audio_values_mask, [(0, 0), (pad_left, pad_right)], mode="constant", constant_values=0) - - return BatchFeature({"audio_values": stacked, "audio_values_mask": audio_values_mask}, tensor_type=return_tensors) + result["audio_values"] = np.pad( + result["audio_values"], [(0, 0), (0, 0), (pad_left, pad_right)], mode="constant", constant_values=0.0, + ) + result["audio_values_mask"] = np.pad( + result["audio_values_mask"], [(0, 0), (pad_left, pad_right)], mode="constant", constant_values=0, + ) + + return result __all__ = ["KyutaiSpeechToTextAudioProcessor"] diff --git a/src/transformers/models/lasr/audio_processing_lasr.py b/src/transformers/models/lasr/audio_processing_lasr.py index 3f0c9c92a21e..a1b581628988 100644 --- a/src/transformers/models/lasr/audio_processing_lasr.py +++ b/src/transformers/models/lasr/audio_processing_lasr.py @@ -39,18 +39,13 @@ class LasrAudioProcessor(TorchAudioBackend): triangularize_in_mel_space=True, bands_to_zero=1, computation_dtype="float64", + matmul_order="features_first", ), log_mode="log", mel_floor=1e-5, computation_dtype="float64", ) - def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): - # LASR uses (time, freq) @ (freq, mels) -> (time, mels) ordering, - # matching the upstream FE's unfold-based output layout. - mel_spec = torch.matmul(features.transpose(-2, -1), self.mel_filters.to(device=features.device, dtype=features.dtype)) - return torch.clamp(mel_spec, min=spectrogram_config.mel_floor) - def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): stft_cfg = spectrogram_config.stft_config win_length = stft_cfg.win_length or stft_cfg.n_fft diff --git a/src/transformers/models/parakeet/audio_processing_parakeet.py b/src/transformers/models/parakeet/audio_processing_parakeet.py index f328bbe12ab8..a26d936075ae 100644 --- a/src/transformers/models/parakeet/audio_processing_parakeet.py +++ b/src/transformers/models/parakeet/audio_processing_parakeet.py @@ -90,30 +90,34 @@ def _compute_magnitudes(self, stft_out, power): return magnitudes def _needs_manual_framing(self, spectrogram_config): - # Preemphasis is handled waveform-level in _pre_stft; no per-frame processing needed. + # Preemphasis is handled waveform-level in _stft; no per-frame processing needed. 
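+        # (waveform-level preemphasis keeps the first sample unchanged and computes
+        #  y[n] = x[n] - preemphasis * x[n - 1] for n >= 1; see _stft below)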
return spectrogram_config.remove_dc_offset or spectrogram_config.stft_config.left_align_fft - def _pre_stft(self, audio, *, spectrogram_config, **kwargs): + def _stft(self, audio, *, spectrogram_config, audio_ranges=None, **kwargs): import torch - if not isinstance(self._audio_lengths, torch.Tensor): - self._audio_lengths = torch.tensor(self._audio_lengths, device=audio.device) + audio_lengths = torch.tensor( + [end - start for start, end in audio_ranges], device=audio.device + ) if audio_ranges is not None else None + # Waveform-level preemphasis with masking to zero out padding preemphasis = spectrogram_config.preemphasis if preemphasis is not None: audio = torch.cat( [audio[:, :1], audio[:, 1:] - preemphasis * audio[:, :-1]], dim=1 ) - timemask = torch.arange(audio.shape[-1], device=audio.device).unsqueeze(0) < self._audio_lengths.unsqueeze(1) - audio = audio.masked_fill(~timemask, 0.0) - return audio + if audio_lengths is not None: + timemask = torch.arange(audio.shape[-1], device=audio.device).unsqueeze(0) < audio_lengths.unsqueeze(1) + audio = audio.masked_fill(~timemask, 0.0) + + return super()._stft(audio, spectrogram_config=spectrogram_config, **kwargs) def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): import torch return torch.matmul(self.mel_filters.T, features) - def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): + def _normalize_magnitude(self, features, *, spectrogram_config, audio_ranges=None, **kwargs): import torch # Match FE: log(mel_spec + guard_value) instead of log(clamp(mel_spec, guard_value)) @@ -123,20 +127,21 @@ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): features = features.permute(0, 2, 1) # Per-utterance normalization - stft_cfg = spectrogram_config.stft_config - audio_lengths = self._audio_lengths - features_lengths = torch.floor_divide( - audio_lengths + stft_cfg.n_fft // 2 * 2 - stft_cfg.n_fft, stft_cfg.hop_length - ) - attention_mask = torch.arange(features.shape[1])[None, :] < features_lengths[:, None] - mask = attention_mask.unsqueeze(-1) - mel_masked = features * mask - mean = mel_masked.sum(dim=1) / features_lengths.unsqueeze(-1) - mean = mean.unsqueeze(1) - variance = ((mel_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1) - std = torch.sqrt(variance).unsqueeze(1) - features = (features - mean) / (std + 1e-5) - features *= mask + if audio_ranges is not None: + stft_cfg = spectrogram_config.stft_config + audio_lengths = torch.tensor([end - start for start, end in audio_ranges]) + features_lengths = torch.floor_divide( + audio_lengths + stft_cfg.n_fft // 2 * 2 - stft_cfg.n_fft, stft_cfg.hop_length + ) + attention_mask = torch.arange(features.shape[1])[None, :] < features_lengths[:, None] + mask = attention_mask.unsqueeze(-1) + mel_masked = features * mask + mean = mel_masked.sum(dim=1) / features_lengths.unsqueeze(-1) + mean = mean.unsqueeze(1) + variance = ((mel_masked - mean) ** 2 * mask).sum(dim=1) / (features_lengths - 1).unsqueeze(-1) + std = torch.sqrt(variance).unsqueeze(1) + features = (features - mean) / (std + 1e-5) + features *= mask return features diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py index 7667c2a21737..18d40f1a5c82 100644 --- a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py @@ -12,6 +12,8 @@ # See the 
License for the specific language governing permissions and # limitations under the License. +import torch + from spectrograms import numpy_mel_spectrogram as _np_spec from ...audio_processing_backends import TorchAudioBackend @@ -22,30 +24,34 @@ class Phi4MultimodalAudioProcessor(TorchAudioBackend): sample_rate = 16000 force_mono = True - preemphasis = 0.97 - n_fft = 512 - hop_length = 160 - win_length = 400 - n_mels = 80 - mel_min_frequency = 0 - mel_max_frequency = 7690 audio_compression_rate = 8 audio_downsample_rate = 1 audio_feat_stride = 1 spectrogram_config = SpectrogramConfig( - stft_config=StftConfig(n_fft=512), + stft_config=StftConfig( + n_fft=512, + win_length=400, + hop_length=160, + window_fn="hamming_window", + periodic=False, + center=False, + power=2.0, + window_dtype="float64", + ), + preemphasis=0.97, mel_scale_config=MelScaleConfig( n_mels=80, f_min=0, f_max=7690, mel_scale="kaldi", triangularize_in_mel_space=True, + matmul_order="features_first", ), + mel_floor=1.0, + log_mode="log", ) def _mel_filter_bank(self, spectrogram_config): - import torch - stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config mel_filters_np = _np_spec.mel_filter_bank( @@ -60,55 +66,50 @@ def _mel_filter_bank(self, spectrogram_config): ) return torch.from_numpy(mel_filters_np).to(torch.float32) - def extract_spectrogram(self, audio, **kwargs): - import torch - - waveform = torch.stack(audio) # (batch, length) - batch_size = waveform.shape[0] - audio_lengths = kwargs.get("audio_lengths") - - fft_window = torch.hamming_window(self.win_length, periodic=False, dtype=torch.float64) - frames = waveform.unfold(-1, self.win_length, self.hop_length) - + def _apply_frame_processing(self, frames, *, spectrogram_config, audio_ranges=None, **kwargs): # Mask frames that overlap the boundary between real audio and padding - if batch_size > 1 and audio_lengths is not None: - frames = frames.clone() - to_mask_batch_idxs = torch.arange(batch_size)[audio_lengths != audio_lengths.max()] - if to_mask_batch_idxs.numel() > 0: - batch_idxs_down = (audio_lengths[to_mask_batch_idxs] - self.win_length) // self.hop_length + 1 - batch_idxs_up = (audio_lengths[to_mask_batch_idxs] // self.hop_length) - 1 - offset_idx = batch_idxs_down.min() - max_idx = batch_idxs_up.max() - - mask = torch.arange(max_idx - offset_idx).expand(to_mask_batch_idxs.shape[0], -1) - mask = ((batch_idxs_down - offset_idx).unsqueeze(1) <= mask) & ( - mask < (batch_idxs_up - offset_idx).unsqueeze(1) - ) - mask = mask.unsqueeze(-1).expand(-1, -1, self.win_length) - masked_frames = frames[to_mask_batch_idxs, offset_idx:max_idx].masked_fill_(mask, 0) - frames[to_mask_batch_idxs, offset_idx:max_idx] = masked_frames - - # Pre-emphasis on frames with scaling - frames_prev = torch.roll(frames, 1, dims=-1) - frames_prev[:, :, 0] = frames_prev[:, :, 1] - frames = (frames - self.preemphasis * frames_prev) * 32768 - - # FFT - S = torch.fft.rfft(fft_window * frames.view(-1, self.win_length), n=self.n_fft, dim=1) - S = S.view(frames.shape[0], -1, S.shape[-1]) - S = S.to(torch.complex64) - - spec_power = torch.abs(S) ** 2 - - # Mel filterbank + log - mel_filters = self.mel_filters.to(torch.float32) - log_spec = torch.clamp(spec_power @ mel_filters, min=1.0) - log_spec = torch.log(log_spec) + stft_cfg = spectrogram_config.stft_config + win_length = stft_cfg.win_length or stft_cfg.n_fft + hop_length = stft_cfg.hop_length or win_length // 2 + batch_size = frames.shape[0] + + if audio_ranges is not None and batch_size > 1: + 
audio_lengths_t = torch.tensor([end - start for start, end in audio_ranges]) + to_mask_idxs = torch.arange(batch_size)[audio_lengths_t != audio_lengths_t.max()] + if to_mask_idxs.numel() > 0: + frames = frames.clone() + down = (audio_lengths_t[to_mask_idxs] - win_length) // hop_length + 1 + up = audio_lengths_t[to_mask_idxs] // hop_length - 1 + offset = down.min() + max_idx = up.max() + + mask_range = torch.arange(max_idx - offset).expand(to_mask_idxs.shape[0], -1) + mask = ((down - offset).unsqueeze(1) <= mask_range) & (mask_range < (up - offset).unsqueeze(1)) + mask = mask.unsqueeze(-1).expand(-1, -1, win_length) + + masked_frames = frames[to_mask_idxs, offset:max_idx].masked_fill_(mask, 0) + frames[to_mask_idxs, offset:max_idx] = masked_frames - return [log_spec[i] for i in range(batch_size)] + frames_prev = torch.roll(frames, 1, dims=-1) + frames_prev[..., 0] = frames_prev[..., 1] + return (frames - spectrogram_config.preemphasis * frames_prev) * 32768 + + def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg): + frames = frames * window + if frame_length < n_fft: + frames = torch.nn.functional.pad(frames, (0, n_fft - frame_length)) + # Cast to complex64 before abs() to match the FE's precision path + spec = torch.fft.rfft(frames, n=n_fft).to(torch.complex64) + if stft_cfg.normalized: + spec = spec / window.pow(2.0).sum().sqrt() + return spec.transpose(-2, -1) + + def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): + win_length = spectrogram_config.stft_config.win_length or spectrogram_config.stft_config.n_fft + hop_length = spectrogram_config.stft_config.hop_length or win_length // 2 + return (audio_lengths - win_length) // hop_length + 1 def _compute_audio_embed_size(self, audio_frames): - integer = audio_frames // self.audio_compression_rate remainder = audio_frames % self.audio_compression_rate result = integer + (remainder > 0).to(integer.dtype) @@ -131,30 +132,22 @@ def _preprocess( do_extract_spectrogram=None, **kwargs, ) -> BatchFeature: - import torch - - # Capture original lengths before padding - audio_lengths = torch.tensor([a.shape[-1] for a in audio]) - - # Pad and truncate - audio, _audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - - # Extract spectrogram - features = self.extract_spectrogram(audio, audio_lengths=audio_lengths) + output = super()._preprocess( + audio, + padding, + max_length, + truncation, + pad_to_multiple_of, + return_tensors, + spectrogram_config=spectrogram_config, + do_extract_spectrogram=do_extract_spectrogram, + **kwargs, + ) - # Compute audio_embed_sizes - feature_lengths = (audio_lengths - self.win_length) // self.hop_length + 1 + feature_lengths = output["audio_features_mask"].sum(dim=-1) feature_lengths = feature_lengths * self.audio_feat_stride - audio_embed_sizes = self._compute_audio_embed_size(feature_lengths) - - data = {"audio_features": features, "audio_embed_sizes": audio_embed_sizes} - - # Attention mask for batched inputs with different lengths - if len(audio_lengths) > 1: - feature_attention_mask = torch.arange(0, feature_lengths.max())[None, :] < feature_lengths[:, None] - data["audio_attention_mask"] = feature_attention_mask + output["audio_embed_sizes"] = self._compute_audio_embed_size(feature_lengths) - output = BatchFeature(data, tensor_type=return_tensors) return output diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py 
index 4f91a50b1f2e..9c66ea182404 100644 --- a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -116,7 +116,7 @@ def _preprocess( stacked = np.stack(normalized, axis=0) data = {"audio_features": stacked} - if self.return_attention_mask: + if self.return_padding_mask: attention_mask = np.zeros((len(lengths), max_len), dtype=np.int32) for i, length in enumerate(lengths): attention_mask[i, :length] = 1 diff --git a/src/transformers/models/univnet/audio_processing_univnet.py b/src/transformers/models/univnet/audio_processing_univnet.py index 65a25c85eeb7..633ce121aca5 100644 --- a/src/transformers/models/univnet/audio_processing_univnet.py +++ b/src/transformers/models/univnet/audio_processing_univnet.py @@ -15,8 +15,7 @@ import numpy as np from ...audio_processing_backends import NumpyAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, spectrogram, window_function -from ...feature_extraction_utils import BatchFeature +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig class UnivNetAudioProcessor(NumpyAudioBackend): @@ -35,7 +34,14 @@ class UnivNetAudioProcessor(NumpyAudioBackend): normalize_max = 2.3143386840820312 max_length_s = 10 spectrogram_config = SpectrogramConfig( - stft_config=StftConfig(n_fft=1024), + stft_config=StftConfig( + n_fft=1024, + hop_length=256, + center=False, + window_fn="hann", + periodic=True, + power=1.0, + ), mel_scale_config=MelScaleConfig( n_mels=100, f_min=0.0, @@ -43,87 +49,50 @@ class UnivNetAudioProcessor(NumpyAudioBackend): mel_scale="slaney", norm="slaney", ), + log_mode="log", + mel_floor=1e-5, ) def __init__(self, **kwargs): super().__init__(**kwargs) self.num_max_samples = self.max_length_s * self.sample_rate - self.window = window_function(self.n_fft, "hann", periodic=True) - def mel_spectrogram(self, waveform): - # Reflect-pad waveform + def _stft(self, audio, *, spectrogram_config, **kwargs): + # UnivNet uses reflect padding with (n_fft - hop_length) / 2 instead of center padding pad_amount = int((self.n_fft - self.hop_length) / 2) - waveform = np.pad(waveform, (pad_amount, pad_amount), mode="reflect") - - # Complex spectrogram - complex_spec = spectrogram( - waveform, - window=self.window, - frame_length=self.n_fft, - hop_length=self.hop_length, - fft_length=self.n_fft, - power=None, - center=False, - mel_filters=None, - mel_floor=None, - ) - - # Custom amplitude spectrogram: sqrt(real^2 + imag^2 + mel_floor) - amplitude_spec = np.sqrt(np.real(complex_spec) ** 2 + np.imag(complex_spec) ** 2 + self.mel_floor) - - # Apply mel filter bank - mel_spec = np.matmul(self.mel_filters.T, amplitude_spec) - - # Log compression - log_mel = np.log(np.clip(mel_spec, a_min=self.compression_clip_val, a_max=None) * self.compression_factor) - - return log_mel.T # (frames, n_mels) - - def normalize(self, spectrogram_data): - return 2 * ((spectrogram_data - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1 - - def extract_spectrogram(self, audio, *, spectrogram_config): - features = [] - for waveform in audio: - waveform = np.squeeze(waveform) - mel = self.mel_spectrogram(waveform) - if self.do_normalize: - mel = self.normalize(mel) - features.append(mel.astype(np.float32)) + if audio.ndim > 1: + audio = np.pad(audio, ((0, 0), (pad_amount, pad_amount)), mode="reflect") + else: + audio = np.pad(audio, (pad_amount, pad_amount), mode="reflect") + return super()._stft(audio, 
spectrogram_config=spectrogram_config, **kwargs) + + def _compute_magnitudes(self, stft_out, power): + # UnivNet adds mel_floor inside the sqrt: sqrt(real² + imag² + mel_floor) + return np.sqrt(np.real(stft_out) ** 2 + np.imag(stft_out) ** 2 + self.mel_floor) + + def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): + # UnivNet applies mel filterbank without a floor + return np.matmul(self.mel_filters.T, features) + + def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): + features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) + if self.do_normalize: + features = 2 * ((features - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1 return features - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, generator=None, **kwargs): - # Pad raw audio - if padding: - audio, _audio_ranges = self.pad(audio, padding=True, max_length=max_length) - - # Extract mel spectrograms - features = self.extract_spectrogram(audio, spectrogram_config=None) - - # Pad features - max_feat_len = max(f.shape[0] for f in features) - padded = [] - for f in features: - if f.shape[0] < max_feat_len: - pad_amount = max_feat_len - f.shape[0] - f = np.pad(f, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) - padded.append(f) - - output_key = "audio_features" - stacked = np.stack(padded, axis=0) - - # Generate noise sequence matching the FE - if generator is None: - generator = np.random.default_rng() - noise = [ - generator.standard_normal((f.shape[0], 64), dtype=np.float32) - for f in padded - ] - - return BatchFeature( - data={output_key: stacked, "noise_sequence": noise}, - tensor_type=return_tensors, - ) + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + # UnivNet uses waveform-level padding mask even when extracting spectrograms + mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return {"audio_features_mask": mask} + + def extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): + features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) + # Transpose from (..., n_mels, frames) to (..., frames, n_mels) + if isinstance(features, list): + return [np.swapaxes(f, -2, -1) for f in features] + return np.swapaxes(features, -2, -1) __all__ = ["UnivNetAudioProcessor"] diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py index df882f7c0805..1895e49fdcbe 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py @@ -38,11 +38,5 @@ def _preprocess(self, audio, **kwargs): result["audio_values"] = result["audio_values"].unsqueeze(1) return result - def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32) - for i, (start, end) in enumerate(audio_ranges): - mask[i, start:end] = 1 - return {"audio_values_mask": mask} - __all__ = ["VibevoiceAcousticTokenizerAudioProcessor"] diff --git 
a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py index 1554b3dcfbb1..2427ce9b36dc 100644 --- a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py @@ -14,8 +14,6 @@ import torch -from spectrograms import numpy_mel_spectrogram as _np_spec - from ...audio_processing_backends import TorchAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig @@ -33,67 +31,28 @@ class VoxtralRealtimeAudioProcessor(TorchAudioBackend): n_mels=128, mel_scale="slaney", norm="slaney", + computation_dtype="float64", ), log_mode="log10", - global_log_mel_max=1.5, ) + global_log_mel_max = 1.5 - def extract_spectrogram(self, audio, *, spectrogram_config=None, **kwargs): - if spectrogram_config is None: - spectrogram_config = self.spectrogram_config - - stft_cfg = spectrogram_config.stft_config - global_log_mel_max = spectrogram_config.global_log_mel_max + def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): + features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) + features = features[..., :-1] - if isinstance(audio, list): - waveform = torch.stack(audio) + if self.global_log_mel_max is not None: + spec_max = torch.tensor(self.global_log_mel_max, device=features.device, dtype=features.dtype) else: - waveform = audio - - device = waveform.device - window = torch.hann_window(stft_cfg.n_fft, device=device) - stft = torch.stft( - waveform, stft_cfg.n_fft, stft_cfg.hop_length, - window=window, return_complex=True, center=True, - ) - magnitudes = stft[..., :-1].abs() ** 2 - - mel_filters = self.mel_filters.to(device, torch.float32) - mel_spec = mel_filters.T @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - - processed = [] - for i in range(log_spec.shape[0]): - spec = log_spec[i] - if global_log_mel_max is not None: - spec_max = torch.tensor(global_log_mel_max, device=spec.device, dtype=spec.dtype) - else: - spec_max = spec.max() - spec = torch.maximum(spec, spec_max - 8.0) - spec = (spec + 4.0) / 4.0 - processed.append(spec) - return processed + spec_max = features.amax(dim=(-2, -1), keepdim=True) + features = torch.maximum(features, spec_max - 8.0) + features = (features + 4.0) / 4.0 + return features def _get_features_lengths(self, audio_lengths, spectrogram_config, include_center_frame=False): stft_cfg = spectrogram_config.stft_config win_length = stft_cfg.win_length or stft_cfg.n_fft return (audio_lengths - win_length) // stft_cfg.hop_length + 1 - def _mel_filter_bank(self, spectrogram_config): - stft_cfg = spectrogram_config.stft_config - mel_cfg = spectrogram_config.mel_scale_config - mel_filters_np = _np_spec.mel_filter_bank( - num_frequency_bins=1 + stft_cfg.n_fft // 2, - num_mel_filters=mel_cfg.n_mels, - min_frequency=mel_cfg.f_min, - max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, - sampling_rate=self.sample_rate, - norm=mel_cfg.norm, - mel_scale=mel_cfg.mel_scale, - triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, - ) - return torch.from_numpy(mel_filters_np).to(torch.float32) - __all__ = ["VoxtralRealtimeAudioProcessor"] diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index 63bc1058fc3e..b3aaacd4afa1 100644 --- 
a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -21,7 +21,7 @@ class WhisperAudioProcessor(TorchAudioBackend): sample_rate = 16000 force_mono = True - return_attention_mask = False + return_padding_mask = False truncation = True max_length = 480000 # 30 seconds at 16000 Hz spectrogram_config = SpectrogramConfig( @@ -49,11 +49,5 @@ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): return features - def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): - # Override to match WhisperFeatureExtractor's mel transformation order for numerical compatibility. - stacked = torch.stack(features) if isinstance(features, list) else features - mel_spec = torch.matmul(self.mel_filters.T, stacked) - return torch.clamp(mel_spec, min=spectrogram_config.mel_floor) - __all__ = ["WhisperAudioProcessor"] From c7940d55959a681cd19faee64d7361ad84dfe408 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Fri, 3 Apr 2026 16:01:54 +0200 Subject: [PATCH 24/28] another round of updates --- src/transformers/audio_processing_backends.py | 126 +++++++++++++++++- src/transformers/audio_processing_utils.py | 92 ++++++++++--- src/transformers/audio_utils.py | 2 + ...rocessing_audio_spectrogram_transformer.py | 52 ++------ .../models/clvp/audio_processing_clvp.py | 9 +- .../models/dac/audio_processing_dac.py | 6 +- .../models/dia/audio_processing_dia.py | 6 +- .../encodec/audio_processing_encodec.py | 6 +- .../audio_processing_granite_speech.py | 93 ++++++------- .../audio_processing_kyutai_speech_to_text.py | 22 +-- .../audio_processing_musicgen_melody.py | 32 +---- .../audio_processing_phi4_multimodal.py | 33 +---- .../audio_processing_seamless_m4t.py | 78 ++++------- .../audio_processing_speech_to_text.py | 74 +++------- .../univnet/audio_processing_univnet.py | 16 +-- ...processing_vibevoice_acoustic_tokenizer.py | 6 +- .../audio_processing_voxtral_realtime.py | 2 +- .../whisper/audio_processing_whisper.py | 2 +- 18 files changed, 317 insertions(+), 340 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 5f5795ddaa42..620e94748f39 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -22,7 +22,7 @@ from .audio_processing_utils import BaseAudioProcessor from .audio_utils import SpectrogramConfig, amplitude_to_db, power_to_db from .feature_extraction_utils import BatchFeature -from .utils import is_torch_available, logging +from .utils import PaddingStrategy, is_torch_available, logging logger = logging.get_logger(__name__) @@ -383,11 +383,76 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): return mel_filters + def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): + padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) + + if truncation and max_length is not None: + features = [f[:max_length] for f in features] + + actual_lengths = [f.shape[0] for f in features] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(actual_lengths) + padding_strategy = PaddingStrategy.MAX_LENGTH + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is 
not None: + padded = [] + for f in features: + if f.shape[0] < max_length: + pad_width = [(0, max_length - f.shape[0])] + [(0, 0)] * (f.ndim - 1) + f = np.pad(f, pad_width, mode="constant", constant_values=self.padding_value) + padded.append(f) + features = padded + + feature_ranges = [(0, length) for length in actual_lengths] + return features, feature_ranges + + def _stack_features(self, features): + return np.stack(features) + + def _get_feature_mask(self, feature_ranges, padded_length): + mask = np.zeros((len(feature_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(feature_ranges): + mask[i, start:end] = 1 + return {"audio_features_mask": mask} + + def _kaldi_fbank(self, waveform, num_mel_bins, sample_frequency=None, **kwargs): + """Extract kaldi-compatible fbank features for a single waveform. + + Uses torchaudio when available, falls back to the base spectrogram pipeline. + Returns a numpy array of shape (time, num_mel_bins). + """ + from .utils import is_speech_available + + if sample_frequency is None: + sample_frequency = self.sample_rate + + if is_speech_available(): + import torch + import torchaudio.compliance.kaldi as ta_kaldi + + waveform_tensor = torch.from_numpy(np.asarray(waveform)).unsqueeze(0) + fbank = ta_kaldi.fbank( + waveform_tensor, num_mel_bins=num_mel_bins, sample_frequency=sample_frequency, **kwargs + ) + return fbank.numpy() + else: + waveform = np.squeeze(waveform) + features = self.extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) + return features[0].T + def _to_batch(self, audio): - return np.stack(audio) + batch = np.stack(audio) + if self.add_channel_dim: + batch = batch[:, np.newaxis, :] + return batch def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - if do_extract_spectrogram: + use_audio_mask = self.mask_level == "audio" + if do_extract_spectrogram and not use_audio_mask: spec_cfg = spectrogram_config or self.spectrogram_config audio_lengths = np.array([end - start for start, end in audio_ranges]) features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) @@ -398,7 +463,8 @@ def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectro mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) for i, (start, end) in enumerate(audio_ranges): mask[i, start:end] = 1 - return {"audio_values_mask": mask} + key = "audio_features_mask" if do_extract_spectrogram else "audio_values_mask" + return {key: mask} class TorchAudioBackend(BaseAudioProcessor): @@ -605,6 +671,9 @@ def _normalize_magnitude( else: raise ValueError(f"Unknown log_mel option: {log_mel}") + if spectrogram_config.skip_last_frame: + result = result[..., :-1] + return result def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): @@ -698,11 +767,53 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): mel_filters = mel_filters.to(torch.get_default_dtype()) return mel_filters + def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): + padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) + + if truncation and max_length is not None: + features = [f[:max_length] for f in features] + + actual_lengths = [f.shape[0] for f in features] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(actual_lengths) + padding_strategy = PaddingStrategy.MAX_LENGTH + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 
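+            # Round up to the next multiple, e.g. max_length=97 with pad_to_multiple_of=32 becomes 128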
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None: + padded = [] + for f in features: + if f.shape[0] < max_length: + pad_amount = max_length - f.shape[0] + # Pad last dim=0 (time axis): F.pad takes innermost dims first + pad_args = [0, 0] * (f.ndim - 1) + [0, pad_amount] + f = torch.nn.functional.pad(f, pad_args, "constant", self.padding_value) + padded.append(f) + features = padded + + feature_ranges = [(0, length) for length in actual_lengths] + return features, feature_ranges + + def _stack_features(self, features): + return torch.stack(features) + + def _get_feature_mask(self, feature_ranges, padded_length): + mask = torch.zeros((len(feature_ranges), padded_length), dtype=torch.int32) + for i, (start, end) in enumerate(feature_ranges): + mask[i, start:end] = 1 + return {"audio_features_mask": mask} + def _to_batch(self, audio): - return torch.stack(audio) + batch = torch.stack(audio) + if self.add_channel_dim: + batch = batch.unsqueeze(1) + return batch def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - if do_extract_spectrogram: + use_audio_mask = self.mask_level == "audio" + if do_extract_spectrogram and not use_audio_mask: spec_cfg = spectrogram_config or self.spectrogram_config audio_lengths = torch.tensor([end - start for start, end in audio_ranges]) features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) @@ -713,4 +824,5 @@ def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectro mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32) for i, (start, end) in enumerate(audio_ranges): mask[i, start:end] = 1 - return {"audio_values_mask": mask} + key = "audio_features_mask" if do_extract_spectrogram else "audio_values_mask" + return {key: mask} diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index c1110bb993de..6395ed204416 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -51,6 +51,7 @@ class BaseAudioProcessor(AudioProcessingMixin): # global defaults sample_rate: int = None force_mono: bool = None + add_channel_dim: bool = False # padding defaults padding = True @@ -61,6 +62,7 @@ class BaseAudioProcessor(AudioProcessingMixin): pad_to_multiple_of = None return_padding_mask = True + mask_level = None # None = auto (features for spectrogram, audio for raw), "audio" = always audio-level spectrogram_config = None do_extract_spectrogram = None @@ -151,27 +153,75 @@ def _preprocess( return_tensors, spectrogram_config=None, do_extract_spectrogram=None, - do_batch_spectrogram=True, + do_batch_spectrogram=None, **kwargs, ) -> BatchFeature: - # pad and truncate - audio, audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) - padded_length = audio[0].shape[-1] - - if do_extract_spectrogram: - audio = self._to_batch(audio) if do_batch_spectrogram else audio - feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, audio_ranges=audio_ranges) - output = {"audio_features": feature} + if do_batch_spectrogram is None: + do_batch_spectrogram = getattr(self, "do_batch_spectrogram", True) + if do_extract_spectrogram and not do_batch_spectrogram: + # Per-waveform extraction path: extract → postprocess → pad features → mask + features = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) + feature_lengths 
= [f.shape[0] for f in features] + features = self._postprocess_features(features, feature_lengths) + features, feature_ranges = self._pad_features( + features, padding, max_length, truncation, pad_to_multiple_of + ) + output = {"audio_features": self._stack_features(features)} + if self.return_padding_mask: + padded_length = features[0].shape[0] + output.update(self._get_feature_mask(feature_ranges, padded_length)) + output = self._postprocess_output(output, feature_ranges=feature_ranges, **kwargs) else: - output = {"audio_values": self._to_batch(audio)} + # Standard path: pad audio → optionally batch → extract/passthrough + audio, audio_ranges = self.pad(audio, padding, max_length, truncation, pad_to_multiple_of) + padded_length = audio[0].shape[-1] + + if do_extract_spectrogram: + audio = self._to_batch(audio) if do_batch_spectrogram else audio + feature = self.extract_spectrogram(audio, spectrogram_config=spectrogram_config, audio_ranges=audio_ranges, **kwargs) + output = {"audio_features": feature} + else: + output = {"audio_values": self._to_batch(audio)} - if self.return_padding_mask: - output.update(self._get_mask( - audio_ranges, padded_length, do_extract_spectrogram=do_extract_spectrogram, spectrogram_config=spectrogram_config - )) + if self.return_padding_mask: + output.update(self._get_mask( + audio_ranges, padded_length, do_extract_spectrogram=do_extract_spectrogram, spectrogram_config=spectrogram_config + )) + output = self._postprocess_output(output, audio_ranges=audio_ranges, **kwargs) return BatchFeature(data=output, tensor_type=return_tensors) + def _postprocess_features(self, features, feature_lengths): + """Hook: per-utterance feature processing after extraction, before feature-level padding. + + Override for normalization that must happen on unpadded features + (e.g., SeamlessM4t mean/variance normalization). + """ + return features + + def _postprocess_output(self, output, audio_ranges=None, feature_ranges=None, **kwargs): + """Hook: augment or modify the output dict after main processing. + + Override to add custom fields (e.g., audio_embed_sizes) or + post-hoc normalization on the stacked/batched output. + """ + return output + + def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): + """Pad a list of 2D feature arrays along the time axis (axis 0). + Implemented by backend subclasses.""" + raise NotImplementedError + + def _stack_features(self, features): + """Stack a list of feature arrays/tensors into a batch. + Implemented by backend subclasses.""" + raise NotImplementedError + + def _get_feature_mask(self, feature_ranges, padded_length): + """Build attention mask dict from feature_ranges. + Implemented by backend subclasses.""" + raise NotImplementedError + def _prepare_audio_like_inputs(self, audio: AudioInput, *args, sample_rate: int | None = None, **kwargs) -> list: """ Prepare audio-like inputs for processing by structuring and then converting each @@ -284,14 +334,20 @@ def _pad_single(self, audio, max_length: int) -> AudioInput: raise NotImplementedError def extract_spectrogram(self, audio, *, spectrogram_config: SpectrogramConfig | None = None, **kwargs): - # TODO: it might be a bit unclear to have extract_spectrogram and _extract_spectrogram methods. """ - Both the numpy and torch backends implement this method in a batched/ sequential manner. - Is is batched by default, but can be set to be sequential. + Extract spectrogram features from audio. 
+ + Both the numpy and torch backends implement this method in a batched/sequential manner. + It is batched by default, but can be set to be sequential. This can extract just a spectrogram or a Mel spectrogram if a mel config is provided. Any extra kwargs whose names match ``SpectrogramConfig`` fields will override the corresponding value on the config for this call. + + Note: Models that bypass the base STFT pipeline entirely (e.g., GraniteSpeech + using torchaudio.transforms.MelSpectrogram, or MusicgenMelody using chroma + features) can set ``do_extract_spectrogram=True`` without providing a + ``spectrogram_config``. They must override this method completely. """ if spectrogram_config is None: spectrogram_config = self.spectrogram_config @@ -382,6 +438,8 @@ def _stft(self, audio, *, spectrogram_config, **kwargs): else: import torch audio = audio.to(getattr(torch, dtype_str)) + if spectrogram_config.waveform_scale is not None: + audio = audio * spectrogram_config.waveform_scale window = self._create_stft_window(win_length, stft_cfg, audio) window, frame_length = self._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 8a16df98207e..0db0cd2fc5e4 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -130,6 +130,7 @@ class SpectrogramConfig: mel_floor: float = 1e-10 waveform_scale: float | None = None computation_dtype: str | None = None + skip_last_frame: bool = False def __getitem__(self, key): if hasattr(self, key): @@ -170,6 +171,7 @@ def from_dict(cls, d: dict) -> "SpectrogramConfig": remove_dc_offset=d.get("remove_dc_offset", False), mel_floor=d.get("mel_floor", 1e-10), waveform_scale=d.get("waveform_scale"), + skip_last_frame=d.get("skip_last_frame", False), ) diff --git a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py index ca6cb558e542..7d9ba2cddec7 100644 --- a/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/audio_processing_audio_spectrogram_transformer.py @@ -14,17 +14,14 @@ import numpy as np -from ...audio_processing_backends import NumpyAudioBackend, TorchAudioBackend +from ...audio_processing_backends import NumpyAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig -from ...feature_extraction_utils import BatchFeature -from ...utils import is_torch_available - -class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend if not is_torch_available() else TorchAudioBackend): +class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend): sample_rate = 16000 force_mono = True return_padding_mask = False - padding = False + do_batch_spectrogram = False max_length_frames = 1024 do_normalize = True @@ -56,41 +53,20 @@ class AudioSpectrogramTransformerAudioProcessor(NumpyAudioBackend if not is_torc mel_floor=1.192092955078125e-07, ) - def _preprocess( - self, - audio, - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - spectrogram_config=None, - do_extract_spectrogram=None, - **kwargs, - ): - # Extract mel spectrogram features from raw audio using the base spectrogram pipeline - features = self.extract_spectrogram(audio, spectrogram_config=self.spectrogram_config) - - # extract_spectrogram returns 
list of (n_mels, frames); transpose to (frames, n_mels) - features = [f.T for f in features] + def extract_spectrogram(self, audio, **kwargs): + return [self._kaldi_fbank(waveform, num_mel_bins=128, window_type="hanning") for waveform in audio] - # Pad or truncate to max_length_frames - padded = [] - for fbank in features: - n_frames = fbank.shape[0] - if n_frames < self.max_length_frames: - pad_amount = self.max_length_frames - n_frames - fbank = np.pad(fbank, ((0, pad_amount), (0, 0)), mode="constant", constant_values=0.0) - elif n_frames > self.max_length_frames: - fbank = fbank[: self.max_length_frames, :] - padded.append(fbank) + def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): + # Always pad/truncate to max_length_frames regardless of caller's padding args + return super()._pad_features(features, "max_length", self.max_length_frames, True, pad_to_multiple_of) - # Normalize with AudioSet stats + def _postprocess_output(self, output, **kwargs): + # Rename to audio_values (AST convention) and apply AudioSet normalization + features = output.pop("audio_features") if self.do_normalize: - padded = [(f - self.ast_mean) / (self.ast_std * 2) for f in padded] - - stacked = np.stack(padded, axis=0) - return BatchFeature({"audio_values": stacked}, tensor_type=return_tensors) + features = (features - self.ast_mean) / (self.ast_std * 2) + output["audio_values"] = features + return output __all__ = ["AudioSpectrogramTransformerAudioProcessor"] diff --git a/src/transformers/models/clvp/audio_processing_clvp.py b/src/transformers/models/clvp/audio_processing_clvp.py index 503c606ab2df..7a57b890f86e 100644 --- a/src/transformers/models/clvp/audio_processing_clvp.py +++ b/src/transformers/models/clvp/audio_processing_clvp.py @@ -23,6 +23,7 @@ class ClvpAudioProcessor(NumpyAudioBackend): force_mono = True max_length = 132300 # 6 seconds at 22050 Hz truncation = True + mask_level = "audio" spectrogram_config = SpectrogramConfig( stft_config=StftConfig( @@ -56,12 +57,4 @@ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): features = features / np.array(self.mel_norms)[:, None] return features.astype(np.float32) - def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - """CLVP uses raw-audio-level mask even for spectrogram output.""" - mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) - for i, (start, end) in enumerate(audio_ranges): - mask[i, start:end] = 1 - return {"audio_features_mask": mask} - - __all__ = ["ClvpAudioProcessor"] diff --git a/src/transformers/models/dac/audio_processing_dac.py b/src/transformers/models/dac/audio_processing_dac.py index 2f8548eee900..f0a27bd57555 100644 --- a/src/transformers/models/dac/audio_processing_dac.py +++ b/src/transformers/models/dac/audio_processing_dac.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import numpy as np - from ...audio_processing_backends import NumpyAudioBackend class DacAudioProcessor(NumpyAudioBackend): sample_rate = 16000 force_mono = True - - def _to_batch(self, audio): - return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + add_channel_dim = True __all__ = ["DacAudioProcessor"] diff --git a/src/transformers/models/dia/audio_processing_dia.py b/src/transformers/models/dia/audio_processing_dia.py index 9a2fbac08954..e1b7b0301e71 100644 --- a/src/transformers/models/dia/audio_processing_dia.py +++ b/src/transformers/models/dia/audio_processing_dia.py @@ -12,18 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - from ...audio_processing_backends import NumpyAudioBackend class DiaAudioProcessor(NumpyAudioBackend): sample_rate = 44100 force_mono = True + add_channel_dim = True pad_to_multiple_of = 512 - def _to_batch(self, audio): - return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) - __all__ = ["DiaAudioProcessor"] diff --git a/src/transformers/models/encodec/audio_processing_encodec.py b/src/transformers/models/encodec/audio_processing_encodec.py index f52dedae59ec..022a7e145313 100644 --- a/src/transformers/models/encodec/audio_processing_encodec.py +++ b/src/transformers/models/encodec/audio_processing_encodec.py @@ -12,17 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np - from ...audio_processing_backends import NumpyAudioBackend class EncodecAudioProcessor(NumpyAudioBackend): sample_rate = 24000 force_mono = True - - def _to_batch(self, audio): - return np.stack(audio)[:, np.newaxis, :] # (batch, 1, length) + add_channel_dim = True __all__ = ["EncodecAudioProcessor"] diff --git a/src/transformers/models/granite_speech/audio_processing_granite_speech.py b/src/transformers/models/granite_speech/audio_processing_granite_speech.py index 3d68a8dc60c5..98915a5afeb9 100644 --- a/src/transformers/models/granite_speech/audio_processing_granite_speech.py +++ b/src/transformers/models/granite_speech/audio_processing_granite_speech.py @@ -17,69 +17,55 @@ import torch from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig -from ...feature_extraction_utils import BatchFeature class GraniteSpeechAudioProcessor(TorchAudioBackend): sample_rate = 16000 force_mono = True + return_padding_mask = False + do_extract_spectrogram = True projector_window_size = 15 projector_downsample_rate = 5 - spectrogram_config = SpectrogramConfig( - stft_config=StftConfig( - n_fft=512, - win_length=400, - hop_length=160, - power=2.0, - ), - mel_scale_config=MelScaleConfig( - n_mels=80, - ), - log_mode="log10", - ) + n_fft = 512 + win_length = 400 + hop_length = 160 + n_mels = 80 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + import torchaudio + + self.mel_filters_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=self.sample_rate, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + n_mels=self.n_mels, + ) def extract_spectrogram(self, audio, **kwargs): - features = super().extract_spectrogram(audio, **kwargs) - - processed = [] - for f in features: - # f is (n_mels, frames) from base; transpose to (frames, n_mels) - f = f.T - - # Apply max-8 normalization matching the FE - mx = f.amax(dim=(-2, -1), keepdim=True) - f = torch.maximum(f, mx - 8.0) - f = f / 4.0 + 1.0 - + # Use 
torchaudio MelSpectrogram to match upstream FE exactly
+        melspec = self.mel_filters_transform.to(device=audio.device)
+        with torch.no_grad():
+            mel = melspec(audio.float())
+        logmel = mel.transpose(-1, -2).clip_(min=1e-10).log10_()
+        mx = logmel.amax(dim=(-2, -1), keepdim=True)
+        logmel = torch.maximum(logmel, mx - 8.0).div_(4).add_(1)
         # Remove last frame if odd
-        if f.shape[0] % 2 == 1:
-            f = f[:-1]
-
-            # Stack pairs of frames: (frames//2, n_mels*2)
-            f = f.reshape(-1, 2 * f.shape[-1])
-            processed.append(f)
-
-        return processed
-
-    def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors,
-                    spectrogram_config=None, do_extract_spectrogram=None, **kwargs):
-        hop_length = self.spectrogram_config.stft_config.hop_length
-
-        # Record original lengths before padding
-        audio_lengths = [a.shape[-1] for a in audio]
-
-        # Pad audio to longest in batch
-        audio, _audio_ranges = self.pad(audio, padding=True, max_length=max_length)
+        if logmel.shape[1] % 2 == 1:
+            logmel = logmel[:, :-1]
+        # Stack pairs of consecutive frames: (batch, frames // 2, 2 * n_mels)
+        features = logmel.reshape(audio.shape[0], -1, 2 * logmel.shape[-1])
+        return features
 
-        # Stack and extract spectrogram
-        audio_stacked = torch.stack(audio)
-        features = self.extract_spectrogram(audio_stacked, spectrogram_config=spectrogram_config)
+    def _postprocess_output(self, output, audio_ranges=None, **kwargs):
+        hop_length = self.hop_length
 
-        # Compute audio_embed_sizes matching the FE
+        # Compute audio_embed_sizes from original audio lengths
         effective_window_size = self.projector_window_size // self.projector_downsample_rate
         audio_embed_sizes = []
-        for raw_length in audio_lengths:
+        for start, end in audio_ranges:
+            raw_length = end - start
             mel_length = raw_length // hop_length + 1
             encoder_length = mel_length // 2
             nblocks = math.ceil(encoder_length / self.projector_window_size)
@@ -91,12 +77,9 @@ def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of
             audio_embed_sizes
         ).view(-1, 1)
 
-        data = {
-            "audio_features": features,
-            "audio_embed_sizes": audio_embed_sizes,
-            "audio_features_mask": input_features_mask,
-        }
-        return BatchFeature(data=data, tensor_type=return_tensors)
+        output["audio_embed_sizes"] = audio_embed_sizes
+        output["audio_features_mask"] = input_features_mask
+        return output
 
 
 __all__ = ["GraniteSpeechAudioProcessor"]
diff --git a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py
index 3cbe3782de24..a07b213a2c9d 100644
--- a/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py
+++ b/src/transformers/models/kyutai_speech_to_text/audio_processing_kyutai_speech_to_text.py
@@ -20,32 +20,24 @@ class KyutaiSpeechToTextAudioProcessor(NumpyAudioBackend):
     sample_rate = 24000
     force_mono = True
+    add_channel_dim = True
     audio_delay_seconds = 2.5
     audio_silence_prefix_seconds = 1.0
 
-    def _to_batch(self, audio):
-        return np.stack(audio)[:, np.newaxis, :]
-
-    def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs):
-        kwargs.pop("do_extract_spectrogram", None)
-        result = super()._preprocess(
-            audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors,
-            do_extract_spectrogram=False, **kwargs,
-        )
-
+    def _postprocess_output(self, output, **kwargs):
         # Add silence prefix (left) and delay (right) padding
         pad_left = int(self.audio_silence_prefix_seconds * self.sample_rate)
         pad_right = 
int((self.audio_delay_seconds + 1.0) * self.sample_rate) if pad_left > 0 or pad_right > 0: - result["audio_values"] = np.pad( - result["audio_values"], [(0, 0), (0, 0), (pad_left, pad_right)], mode="constant", constant_values=0.0, + output["audio_values"] = np.pad( + output["audio_values"], [(0, 0), (0, 0), (pad_left, pad_right)], mode="constant", constant_values=0.0, ) - result["audio_values_mask"] = np.pad( - result["audio_values_mask"], [(0, 0), (pad_left, pad_right)], mode="constant", constant_values=0, + output["audio_values_mask"] = np.pad( + output["audio_values_mask"], [(0, 0), (pad_left, pad_right)], mode="constant", constant_values=0, ) - return result + return output __all__ = ["KyutaiSpeechToTextAudioProcessor"] diff --git a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py index 0373dce62f86..1585ffae93d0 100644 --- a/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/audio_processing_musicgen_melody.py @@ -13,13 +13,14 @@ # limitations under the License. from ...audio_processing_backends import TorchAudioBackend -from ...feature_extraction_utils import BatchFeature from ...utils.import_utils import requires class MusicgenMelodyAudioProcessor(TorchAudioBackend): sample_rate = 32000 force_mono = True + do_extract_spectrogram = True + return_padding_mask = False n_fft = 16384 hop_length = 4096 n_chroma = 12 @@ -35,11 +36,11 @@ def __init__(self, **kwargs): librosa.filters.chroma(sr=self.sample_rate, n_fft=self.n_fft, tuning=0, n_chroma=self.n_chroma) ).float() - def extract_spectrogram(self, audio, *, spectrogram_config): + def extract_spectrogram(self, audio, **kwargs): import torch import torchaudio - waveform = torch.stack(audio, dim=0) + waveform = audio # Already a batched tensor from _to_batch device = waveform.device batch_size = waveform.shape[0] @@ -74,30 +75,7 @@ def extract_spectrogram(self, audio, *, spectrogram_config): norm_chroma[:] = 0 norm_chroma.scatter_(dim=-1, index=idx, value=1) - return [norm_chroma[i] for i in range(batch_size)] - - def _preprocess(self, audio, padding, max_length, truncation, pad_to_multiple_of, return_tensors, **kwargs): - import torch - - # Pad raw audio - if padding: - audio, _audio_ranges = self.pad(audio, padding=True, max_length=max_length) - - # Extract chroma features - features = self.extract_spectrogram(audio, spectrogram_config=None) - - # Pad features - max_feat_len = max(f.shape[0] for f in features) - padded = [] - for f in features: - if f.shape[0] < max_feat_len: - pad_amount = max_feat_len - f.shape[0] - f = torch.nn.functional.pad(f, (0, 0, 0, pad_amount), mode="constant", value=0.0) - padded.append(f) - - output_key = "audio_features" - stacked = torch.stack(padded, dim=0) - return BatchFeature(data={output_key: stacked}, tensor_type=return_tensors) + return norm_chroma __all__ = ["MusicgenMelodyAudioProcessor"] diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py index 18d40f1a5c82..787d28ffb401 100644 --- a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py @@ -14,11 +14,8 @@ import torch -from spectrograms import numpy_mel_spectrogram as _np_spec - from ...audio_processing_backends import TorchAudioBackend -from ...audio_utils 
import MelScaleConfig, SpectrogramConfig, StftConfig -from ...feature_extraction_utils import BatchFeature +from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig, mel_filter_bank class Phi4MultimodalAudioProcessor(TorchAudioBackend): @@ -54,7 +51,7 @@ class Phi4MultimodalAudioProcessor(TorchAudioBackend): def _mel_filter_bank(self, spectrogram_config): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - mel_filters_np = _np_spec.mel_filter_bank( + mel_filters_np = mel_filter_bank( num_frequency_bins=1 + stft_cfg.n_fft // 2, num_mel_filters=mel_cfg.n_mels, min_frequency=mel_cfg.f_min, @@ -120,34 +117,10 @@ def _compute_audio_embed_size(self, audio_frames): return result - def _preprocess( - self, - audio, - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - spectrogram_config=None, - do_extract_spectrogram=None, - **kwargs, - ) -> BatchFeature: - output = super()._preprocess( - audio, - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - spectrogram_config=spectrogram_config, - do_extract_spectrogram=do_extract_spectrogram, - **kwargs, - ) - + def _postprocess_output(self, output, **kwargs): feature_lengths = output["audio_features_mask"].sum(dim=-1) feature_lengths = feature_lengths * self.audio_feat_stride output["audio_embed_sizes"] = self._compute_audio_embed_size(feature_lengths) - return output diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py index c734409952b4..3b53af82d2a6 100644 --- a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py @@ -16,13 +16,14 @@ from ...audio_processing_backends import NumpyAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig -from ...feature_extraction_utils import BatchFeature class SeamlessM4tAudioProcessor(NumpyAudioBackend): sample_rate = 16000 force_mono = True + do_batch_spectrogram = False stride = 2 + pad_to_multiple_of = 2 # Align feature padding to stride spectrogram_config = SpectrogramConfig( stft_config=StftConfig( @@ -48,14 +49,17 @@ class SeamlessM4tAudioProcessor(NumpyAudioBackend): waveform_scale=32768.0, ) - def _extract_fbank_features(self, waveform): - """Extract log-mel filterbank features for a single waveform using the base spectrogram pipeline.""" - waveform = np.squeeze(waveform) * self.spectrogram_config.waveform_scale - features = self.extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) - # extract_spectrogram returns list of (n_mels, time); transpose to (time, n_mels) - return features[0].T - - def feature_normalize(self, features): + def extract_spectrogram(self, audio, **kwargs): + # Per-waveform fbank extraction returning (time, n_mels) + features = [] + for waveform in audio: + waveform = np.squeeze(waveform) + f = super().extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) + features.append(f[0].T) + return features + + def _postprocess_features(self, features, feature_lengths): + # Per-utterance mean/variance normalization (before padding) normalized = [] for f in features: mean = np.expand_dims(f.mean(axis=0), 0) @@ -63,55 +67,27 @@ def feature_normalize(self, features): normalized.append((f - mean) / np.sqrt(var + 1e-7)) return normalized - def _preprocess( - self, - audio, - padding, - max_length, - truncation, - pad_to_multiple_of, 
- return_tensors, - spectrogram_config=None, - do_extract_spectrogram=None, - **kwargs, - ): - # Extract features from raw (unpadded) audio, then pad at feature level - features = [self._extract_fbank_features(waveform) for waveform in audio] - features = self.feature_normalize(features) - - feature_lengths = [f.shape[0] for f in features] - - # Pad features to longest (pad_to_multiple_of stride) - max_len = max(feature_lengths) - if max_len % self.stride != 0: - max_len = ((max_len // self.stride) + 1) * self.stride - padded = [] - for f in features: - if f.shape[0] < max_len: - f = np.pad(f, ((0, max_len - f.shape[0]), (0, 0)), mode="constant", constant_values=0.0) - padded.append(f) - - stacked = np.stack(padded, axis=0) - batch_size, num_frames, num_channels = stacked.shape - - # Feature-level attention_mask - attention_mask = np.zeros((batch_size, num_frames), dtype=np.int32) - for i, length in enumerate(feature_lengths): - attention_mask[i, :length] = 1 + def _postprocess_output(self, output, feature_ranges=None, **kwargs): + features = output["audio_features"] # (batch, num_frames, num_channels) + batch_size, num_frames, num_channels = features.shape # Stride concatenation remainder = num_frames % self.stride if remainder != 0: - stacked = stacked[:, : num_frames - remainder, :] - attention_mask = attention_mask[:, : num_frames - remainder] + features = features[:, :num_frames - remainder, :] num_frames = num_frames - remainder - stacked = stacked.reshape(batch_size, num_frames // self.stride, num_channels * self.stride) - indices = np.arange(0, num_frames) - attention_mask = attention_mask[:, indices % self.stride == 1] + output["audio_features"] = features.reshape(batch_size, num_frames // self.stride, num_channels * self.stride) + + # Adjust mask for stride + if "audio_features_mask" in output: + mask = output["audio_features_mask"] + if remainder != 0: + mask = mask[:, :num_frames] + indices = np.arange(0, num_frames) + output["audio_features_mask"] = mask[:, indices % self.stride == 1] - data = {"audio_features": stacked, "audio_features_mask": attention_mask} - return BatchFeature(data=data, tensor_type=return_tensors) + return output __all__ = ["SeamlessM4tAudioProcessor"] diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py index 9c66ea182404..29d80e383f50 100644 --- a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -16,18 +16,11 @@ from ...audio_processing_backends import NumpyAudioBackend from ...audio_utils import MelScaleConfig, SpectrogramConfig, StftConfig -from ...feature_extraction_utils import BatchFeature -from ...utils import is_speech_available - - -if is_speech_available(): - import torch - import torchaudio.compliance.kaldi as ta_kaldi - class SpeechToTextAudioProcessor(NumpyAudioBackend): sample_rate = 16000 force_mono = True + do_batch_spectrogram = False spectrogram_config = SpectrogramConfig( stft_config=StftConfig( @@ -61,15 +54,11 @@ def __init__(self, normalize_means=True, normalize_vars=True, **kwargs): def _extract_fbank_features(self, waveform): """Extract log-mel filterbank features for a single waveform.""" waveform = waveform * self.spectrogram_config.waveform_scale - if is_speech_available(): - waveform_tensor = torch.from_numpy(waveform).unsqueeze(0) - features = ta_kaldi.fbank(waveform_tensor, num_mel_bins=80, 
sample_frequency=self.sample_rate) - return features.numpy() - else: - waveform = np.squeeze(waveform) - features = self.extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) - # extract_spectrogram returns list of (n_mels, time); transpose to (time, n_mels) - return features[0].T + return self._kaldi_fbank(waveform, num_mel_bins=80) + + def extract_spectrogram(self, audio, **kwargs): + # Per-waveform fbank extraction returning (time, n_mels) + return [self._extract_fbank_features(waveform) for waveform in audio] @staticmethod def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, padding_value=0.0): @@ -83,46 +72,17 @@ def utterance_cmvn(x, input_length, normalize_means=True, normalize_vars=True, p x[input_length:] = padding_value return x.astype(np.float32) - def _preprocess( - self, - audio, - padding, - max_length, - truncation, - pad_to_multiple_of, - return_tensors, - spectrogram_config=None, - do_extract_spectrogram=None, - **kwargs, - ): - # Extract features from raw (unpadded) audio, then pad at feature level - features = [self._extract_fbank_features(waveform) for waveform in audio] - lengths = [f.shape[0] for f in features] - - # Pad features to longest - max_len = max(lengths) - padded = [] - for f in features: - if f.shape[0] < max_len: - f = np.pad(f, ((0, max_len - f.shape[0]), (0, 0)), mode="constant", constant_values=0.0) - padded.append(f) - - # Utterance CMVN normalization - normalized = [ - self.utterance_cmvn(f, length, self.normalize_means, self.normalize_vars, self.padding_value) - for f, length in zip(padded, lengths) - ] - - stacked = np.stack(normalized, axis=0) - data = {"audio_features": stacked} - - if self.return_padding_mask: - attention_mask = np.zeros((len(lengths), max_len), dtype=np.int32) - for i, length in enumerate(lengths): - attention_mask[i, :length] = 1 - data["audio_features_mask"] = attention_mask - - return BatchFeature(data=data, tensor_type=return_tensors) + def _postprocess_output(self, output, feature_ranges=None, **kwargs): + # Apply utterance CMVN normalization on the padded, stacked features + features = output["audio_features"] # (batch, time, n_mels) + normalized = [] + for i, (start, end) in enumerate(feature_ranges): + length = end - start + normalized.append( + self.utterance_cmvn(features[i], length, self.normalize_means, self.normalize_vars, self.padding_value) + ) + output["audio_features"] = np.stack(normalized) + return output __all__ = ["SpeechToTextAudioProcessor"] diff --git a/src/transformers/models/univnet/audio_processing_univnet.py b/src/transformers/models/univnet/audio_processing_univnet.py index 633ce121aca5..58fd8e4ea32c 100644 --- a/src/transformers/models/univnet/audio_processing_univnet.py +++ b/src/transformers/models/univnet/audio_processing_univnet.py @@ -21,11 +21,7 @@ class UnivNetAudioProcessor(NumpyAudioBackend): sample_rate = 24000 force_mono = True - n_fft = 1024 - hop_length = 256 - n_mels = 100 - fmin = 0.0 - fmax = 12000.0 + mask_level = "audio" mel_floor = 1e-9 compression_clip_val = 1e-5 compression_factor = 1.0 @@ -59,7 +55,8 @@ def __init__(self, **kwargs): def _stft(self, audio, *, spectrogram_config, **kwargs): # UnivNet uses reflect padding with (n_fft - hop_length) / 2 instead of center padding - pad_amount = int((self.n_fft - self.hop_length) / 2) + stft_cfg = spectrogram_config.stft_config + pad_amount = int((stft_cfg.n_fft - stft_cfg.hop_length) / 2) if audio.ndim > 1: audio = np.pad(audio, ((0, 0), (pad_amount, pad_amount)), mode="reflect") else: 
@@ -80,13 +77,6 @@ def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): features = 2 * ((features - self.normalize_min) / (self.normalize_max - self.normalize_min)) - 1 return features - def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - # UnivNet uses waveform-level padding mask even when extracting spectrograms - mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) - for i, (start, end) in enumerate(audio_ranges): - mask[i, start:end] = 1 - return {"audio_features_mask": mask} - def extract_spectrogram(self, audio, *, spectrogram_config, **kwargs): features = super().extract_spectrogram(audio, spectrogram_config=spectrogram_config, **kwargs) # Transpose from (..., n_mels, frames) to (..., frames, n_mels) diff --git a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py index 1895e49fdcbe..866113b39b82 100644 --- a/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py +++ b/src/transformers/models/vibevoice_acoustic_tokenizer/audio_processing_vibevoice_acoustic_tokenizer.py @@ -20,6 +20,7 @@ class VibevoiceAcousticTokenizerAudioProcessor(TorchAudioBackend): sample_rate = 24000 force_mono = True + add_channel_dim = True target_dB_FS = -25 eps = 1e-6 @@ -33,10 +34,5 @@ def _process_audio(self, audio_el): audio_el = audio_el / (max_val + self.eps) return audio_el - def _preprocess(self, audio, **kwargs): - result = super()._preprocess(audio, **kwargs) - result["audio_values"] = result["audio_values"].unsqueeze(1) - return result - __all__ = ["VibevoiceAcousticTokenizerAudioProcessor"] diff --git a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py index 2427ce9b36dc..7fff372c318d 100644 --- a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py @@ -34,12 +34,12 @@ class VoxtralRealtimeAudioProcessor(TorchAudioBackend): computation_dtype="float64", ), log_mode="log10", + skip_last_frame=True, ) global_log_mel_max = 1.5 def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) - features = features[..., :-1] if self.global_log_mel_max is not None: spec_max = torch.tensor(self.global_log_mel_max, device=features.device, dtype=features.dtype) diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index b3aaacd4afa1..eb9f3dd570f5 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -37,11 +37,11 @@ class WhisperAudioProcessor(TorchAudioBackend): computation_dtype="float64", ), log_mode="log10", + skip_last_frame=True, ) def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) - features = features[..., :-1] # whisper skips last frame max_vals = features.amax(dim=(-2, -1), keepdim=True) features = torch.maximum(features, max_vals - 8.0) From 48912a9f19bef83e046c49d95675b4876e14d6db Mon Sep 17 00:00:00 2001 From: 
eustlb <94853470+eustlb@users.noreply.github.com> Date: Fri, 3 Apr 2026 22:37:36 +0200 Subject: [PATCH 25/28] some more updates --- src/transformers/audio_processing_backends.py | 105 ++-- src/transformers/audio_processing_utils.py | 9 +- src/transformers/audio_utils.py | 514 ------------------ .../models/clap/audio_processing_clap.py | 1 + .../models/clvp/audio_processing_clvp.py | 1 + .../gemma3n/audio_processing_gemma3n.py | 1 + .../parakeet/audio_processing_parakeet.py | 2 +- .../audio_processing_phi4_multimodal.py | 2 +- .../audio_processing_seamless_m4t.py | 5 +- .../audio_processing_speech_to_text.py | 4 +- .../univnet/audio_processing_univnet.py | 3 +- 11 files changed, 66 insertions(+), 581 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 620e94748f39..10155a4002d1 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -17,10 +17,9 @@ import math import numpy as np -import librosa from .audio_processing_utils import BaseAudioProcessor -from .audio_utils import SpectrogramConfig, amplitude_to_db, power_to_db +from .audio_utils import SpectrogramConfig, amplitude_to_db, mel_filter_bank, power_to_db from .feature_extraction_utils import BatchFeature from .utils import PaddingStrategy, is_torch_available, logging @@ -195,6 +194,14 @@ def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_fr frame_length = n_fft return window, frame_length + @staticmethod + def _np_frame(x, frame_length, hop_length): + """Create overlapping frames from a 1D array using stride tricks (replaces librosa.util.frame).""" + n_frames = 1 + (x.shape[-1] - frame_length) // hop_length + strides = x.strides[:-1] + (x.strides[-1] * hop_length, x.strides[-1]) + shape = x.shape[:-1] + (n_frames, frame_length) + return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + def _frame_waveform(self, waveform, frame_length, hop_length, n_fft, center, pad_mode): squeezed = waveform.ndim == 1 if squeezed: @@ -217,9 +224,8 @@ def _frame_waveform(self, waveform, frame_length, hop_length, n_fft, center, pad padding, mode=pad_mode, ) - y_frames_pre = librosa.util.frame(y_pre, frame_length=frame_length, hop_length=hop_length) - y_frames_pre = y_frames_pre[..., :start_k] - y_frames_pre = np.moveaxis(y_frames_pre, -2, -1) + y_frames_pre = self._np_frame(y_pre, frame_length, hop_length) + y_frames_pre = y_frames_pre[..., :start_k, :] extra = y_frames_pre.shape[-2] padding[-1] = (0, frame_length // 2) @@ -228,15 +234,13 @@ def _frame_waveform(self, waveform, frame_length, hop_length, n_fft, center, pad padding, mode=pad_mode, ) - y_frames_post = librosa.util.frame(y_post, frame_length=frame_length, hop_length=hop_length) - y_frames_post = np.moveaxis(y_frames_post, -2, -1) + y_frames_post = self._np_frame(y_post, frame_length, hop_length) extra += y_frames_post.shape[-2] start = start_k * hop_length - n_fft // 2 - y_frames_middle = librosa.util.frame( - waveform[..., start:], frame_length=frame_length, hop_length=hop_length + y_frames_middle = self._np_frame( + np.ascontiguousarray(waveform[..., start:]), frame_length, hop_length ) - y_frames_middle = np.moveaxis(y_frames_middle, -2, -1) num_frames = y_frames_pre.shape[-2] + y_frames_middle.shape[-2] + y_frames_post.shape[-2] frames = np.concatenate([y_frames_pre, y_frames_middle, y_frames_post], axis=-2) @@ -255,8 +259,11 @@ def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg) 
compute_dtype = np.result_type(audio.dtype, window.dtype) return frames.astype(compute_dtype, copy=False) - def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg): + def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None): frames = frames * window + # Always store FFT output as complex64, matching the upstream spectrogram() function. + # FFT is computed in float64 (numpy default), but the complex64 cast ensures consistent + # precision with librosa and the legacy FE code path. spec = np.fft.rfft(frames, n=n_fft, axis=-1).astype(np.complex64) if stft_cfg.normalized: spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype) @@ -272,8 +279,13 @@ def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg) spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype) return np.moveaxis(spec, -1, -2) - def _compute_magnitudes(self, stft_out, power): - return np.abs(stft_out, dtype=np.float64) ** power + def _compute_magnitudes(self, stft_out, power, spectrogram_config=None): + # When computation_dtype is set (e.g., "float64"), compute magnitudes in that dtype + # to match the upstream FE precision path: np.abs(complex64, dtype=float64) ** power. + # Otherwise, use the natural dtype (float32 from complex64) to match librosa. + if spectrogram_config and spectrogram_config.computation_dtype: + return np.abs(stft_out, dtype=np.float64) ** power + return np.abs(stft_out) ** power def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): """Apply per-frame signal conditioning using the numpy backend.""" @@ -295,10 +307,11 @@ def _apply_mel_scale( **kwargs, ) -> list[np.ndarray]: """Apply mel filterbank to spectrogram features using the numpy backend.""" + mel_filters = self.mel_filters.astype(features.dtype, copy=False) if spectrogram_config.mel_scale_config.matmul_order == "features_first": - mel_spec = np.matmul(features, self.mel_filters) + mel_spec = np.matmul(features, mel_filters) else: - mel_spec = np.matmul(self.mel_filters.T, features) + mel_spec = np.matmul(mel_filters.T, features) return np.maximum(spectrogram_config.mel_floor, mel_spec) def _normalize_magnitude( @@ -315,14 +328,14 @@ def _normalize_magnitude( """Apply magnitude normalization (log, log10, or dB scaling) to spectrogram features. Accepts a single or batched spectrogram (not a list). - Mirrors the normalization logic in `audio_utils.spectrogram()`. + Mirrors the normalization logic in the spectrogram pipeline. 
""" log_mel = spectrogram_config.log_mode mel_floor = spectrogram_config.mel_floor power = spectrogram_config.stft_config.power if log_mel is None: - return features + return features.astype(dtype) # Clamp to mel_floor before taking log result = np.maximum(mel_floor, features) @@ -346,42 +359,21 @@ def _normalize_magnitude( def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - num_frequency_bins = 1 + stft_cfg.n_fft // 2 - num_mel_filters = mel_cfg.n_mels - min_frequency = mel_cfg.f_min - max_frequency = mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2 - sampling_rate = self.sample_rate - - mel_min = _np_hertz_to_mel(min_frequency, mel_scale=mel_cfg.mel_scale) - mel_max = _np_hertz_to_mel(max_frequency, mel_scale=mel_cfg.mel_scale) - mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2) - filter_freqs = _np_mel_to_hertz(mel_freqs, mel_scale=mel_cfg.mel_scale) - - n_fft = (num_frequency_bins - 1) * 2 - - if mel_cfg.triangularize_in_mel_space: - fft_bin_width = sampling_rate / n_fft - fft_freqs = _np_hertz_to_mel( - fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_cfg.mel_scale - ) - filter_freqs = mel_freqs - elif mel_cfg.frequency_bin_mode == "rfft": - fft_freqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate) - else: - fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins) - - # Triangular filter bank - filter_diff = np.diff(filter_freqs) - slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1) - down_slopes = -slopes[:, :-2] / filter_diff[:-1] - up_slopes = slopes[:, 2:] / filter_diff[1:] - mel_filters = np.maximum(0, np.minimum(down_slopes, up_slopes)) - - if mel_cfg.norm == "slaney": - enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters]) - mel_filters *= np.expand_dims(enorm, 0) - - return mel_filters + filters = mel_filter_bank( + num_frequency_bins=1 + stft_cfg.n_fft // 2, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, + sampling_rate=self.sample_rate, + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, + ) + # Store as float32 to match librosa's precision path. Processors needing float64 + # set computation_dtype, which keeps filters in float64 for exact upstream FE matching. 
+ if not spectrogram_config.computation_dtype: + filters = filters.astype(np.float32) + return filters def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) @@ -557,7 +549,7 @@ def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg) ) return audio.unfold(-1, frame_length, hop_length) - def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg): + def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None): frames = frames * window if frame_length < n_fft: frames = torch.nn.functional.pad(frames, (0, n_fft - frame_length)) @@ -587,7 +579,7 @@ def _cast_stft_output(self, magnitudes, spectrogram_config): return magnitudes return magnitudes.float() - def _compute_magnitudes(self, stft_out, power): + def _compute_magnitudes(self, stft_out, power, spectrogram_config=None): """Convert complex STFT output to a real-valued magnitude spectrogram.""" return stft_out.abs() ** power @@ -615,7 +607,8 @@ def _apply_mel_scale( if spectrogram_config.mel_scale_config.matmul_order == "features_first": mel_spec = torch.matmul(features.transpose(-2, -1), mel_filters) else: - mel_spec = torch.matmul(mel_filters.T, features) + # Use F.linear to match torchaudio's MelScale implementation exactly + mel_spec = torch.nn.functional.linear(features.transpose(-2, -1), mel_filters.T).transpose(-2, -1) return torch.clamp(mel_spec, min=spectrogram_config.mel_floor) def _normalize_magnitude( diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 6395ed204416..0419d3cdc008 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -444,13 +444,14 @@ def _stft(self, audio, *, spectrogram_config, **kwargs): window, frame_length = self._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) if needs_manual_framing: + audio_dtype = audio.dtype frames = self._frame_audio(audio, window, frame_length, hop_length, n_fft, stft_cfg) frames = self._apply_frame_processing(frames, spectrogram_config=spectrogram_config, **kwargs) - stft_out = self._window_and_fft(frames, window, frame_length, n_fft, stft_cfg) + stft_out = self._window_and_fft(frames, window, frame_length, n_fft, stft_cfg, audio_dtype=audio_dtype) else: stft_out = self._native_stft(audio, window, frame_length, hop_length, n_fft, stft_cfg) - magnitudes = self._compute_magnitudes(stft_out, stft_cfg.power) + magnitudes = self._compute_magnitudes(stft_out, stft_cfg.power, spectrogram_config=spectrogram_config) return self._cast_stft_output(magnitudes, spectrogram_config) def _create_stft_window(self, win_length, stft_cfg, audio): @@ -478,7 +479,7 @@ def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg) """Native STFT (e.g. torch.stft). Returns complex output. Implemented by backend subclasses.""" raise NotImplementedError - def _compute_magnitudes(self, stft_out, power): + def _compute_magnitudes(self, stft_out, power, spectrogram_config=None): """Convert complex STFT output to a real-valued magnitude spectrogram. Implemented by backend subclasses. Overridable for custom magnitude computation (e.g. 
Parakeet).""" raise NotImplementedError @@ -502,7 +503,7 @@ def _needs_manual_framing(self, spectrogram_config): or spectrogram_config.remove_dc_offset ) - def _compute_magnitudes(self, stft_out, power): + def _compute_magnitudes(self, stft_out, power, spectrogram_config=None): """Convert complex STFT output to a real-valued magnitude spectrogram. Only used in the non-manual-framing STFT path. Override for diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index 0db0cd2fc5e4..e9d2da3707d3 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -741,428 +741,6 @@ def window_function( return padded_window -# Note: This method processes a single waveform. For batch processing, use spectrogram_batch(). -def spectrogram( - waveform: np.ndarray, - window: np.ndarray, - frame_length: int, - hop_length: int, - fft_length: int | None = None, - power: float | None = 1.0, - center: bool = True, - pad_mode: str = "reflect", - onesided: bool = True, - dither: float = 0.0, - preemphasis: float | None = None, - mel_filters: np.ndarray | None = None, - mel_floor: float = 1e-10, - log_mel: str | None = None, - reference: float = 1.0, - min_value: float = 1e-10, - db_range: float | None = None, - remove_dc_offset: bool = False, - dtype: np.dtype = np.float32, -) -> np.ndarray: - """ - Calculates a spectrogram over one waveform using the Short-Time Fourier Transform. - - This function can create the following kinds of spectrograms: - - - amplitude spectrogram (`power = 1.0`) - - power spectrogram (`power = 2.0`) - - complex-valued spectrogram (`power = None`) - - log spectrogram (use `log_mel` argument) - - mel spectrogram (provide `mel_filters`) - - log-mel spectrogram (provide `mel_filters` and `log_mel`) - - How this works: - - 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length - - hop_length` samples. - 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`. - 3. The DFT is taken of each windowed frame. - 4. The results are stacked into a spectrogram. - - We make a distinction between the following "blocks" of sample data, each of which may have a different lengths: - - - The analysis frame. This is the size of the time slices that the input waveform is split into. - - The window. Each analysis frame is multiplied by the window to avoid spectral leakage. - - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram. - - In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A - padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame, - typically the next power of two. - - Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and - `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms - can be constructed. - - Args: - waveform (`np.ndarray` of shape `(length,)`): - The input waveform. This must be a single real-valued, mono waveform. - window (`np.ndarray` of shape `(frame_length,)`): - The windowing function to apply, including zero-padding if necessary. The actual window length may be - shorter than `frame_length`, but we're assuming the array has already been zero-padded. - frame_length (`int`): - The length of the analysis frames in samples. 
With librosa this is always equal to `fft_length` but we also - allow smaller sizes. - hop_length (`int`): - The stride between successive analysis frames in samples. - fft_length (`int`, *optional*): - The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have. - For optimal speed, this should be a power of two. If `None`, uses `frame_length`. - power (`float`, *optional*, defaults to 1.0): - If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns - complex numbers. - center (`bool`, *optional*, defaults to `True`): - Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame - `t` will start at time `t * hop_length`. - pad_mode (`str`, *optional*, defaults to `"reflect"`): - Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"` - (pad with edge values), `"reflect"` (pads with mirrored values). - onesided (`bool`, *optional*, defaults to `True`): - If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1` - frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins. - dither (`float`, *optional*, defaults to 0.0): - Adds dithering. In other words, adds a small Gaussian noise to each frame. - E.g. use 4.0 to add dithering with a normal distribution centered - around 0.0 with standard deviation 4.0, 0.0 means no dithering. - Dithering has similar effect as `mel_floor`. It reduces the high log_mel_fbank - values for signals with hard-zero sections, when VAD cutoff is present in the signal. - preemphasis (`float`, *optional*) - Coefficient for a low-pass filter that applies pre-emphasis before the DFT. - mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*): - The mel filter bank. If supplied, applies a this filter bank to create a mel spectrogram. - mel_floor (`float`, *optional*, defaults to 1e-10): - Minimum value of mel frequency banks. - log_mel (`str`, *optional*): - How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take - the natural logarithm) `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be - used when `power` is not `None`. - reference (`float`, *optional*, defaults to 1.0): - Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set - the loudest part to 0 dB. Must be greater than zero. - min_value (`float`, *optional*, defaults to `1e-10`): - The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking - `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an - amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero. - db_range (`float`, *optional*): - Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the - peak value and the smallest value will never be more than 80 dB. Must be greater than zero. - remove_dc_offset (`bool`, *optional*): - Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in - order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters. - dtype (`np.dtype`, *optional*, defaults to `np.float32`): - Data type of the spectrogram tensor. 
If `power` is None, this argument is ignored and the dtype will be - `np.complex64`. - - Returns: - `nd.array` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape - `(num_mel_filters, length)` for a mel spectrogram. - """ - window_length = len(window) - - if fft_length is None: - fft_length = frame_length - - if frame_length > fft_length: - raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})") - - if window_length != frame_length: - raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})") - - if hop_length <= 0: - raise ValueError("hop_length must be greater than zero") - - if waveform.ndim != 1: - raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}") - - if np.iscomplexobj(waveform): - raise ValueError("Complex-valued input waveforms are not currently supported") - - if power is None and mel_filters is not None: - raise ValueError( - "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrogram." - "Specify `power` to fix this issue." - ) - - # center pad the waveform - if center: - padding = [(int(frame_length // 2), int(frame_length // 2))] - waveform = np.pad(waveform, padding, mode=pad_mode) - - # promote to float64, since np.fft uses float64 internally - waveform = waveform.astype(np.float64) - window = window.astype(np.float64) - - # split waveform into frames of frame_length size - num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length)) - - num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length - spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64) - - # rfft is faster than fft - fft_func = np.fft.rfft if onesided else np.fft.fft - buffer = np.zeros(fft_length) - - timestep = 0 - for frame_idx in range(num_frames): - buffer[:frame_length] = waveform[timestep : timestep + frame_length] - - if dither != 0.0: - buffer[:frame_length] += dither * np.random.randn(frame_length) - - if remove_dc_offset: - buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean() - - if preemphasis is not None: - buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1] - buffer[0] *= 1 - preemphasis - - buffer[:frame_length] *= window - - spectrogram[frame_idx] = fft_func(buffer) - timestep += hop_length - - # note: ** is much faster than np.power - if power is not None: - spectrogram = np.abs(spectrogram, dtype=np.float64) ** power - - spectrogram = spectrogram.T - - if mel_filters is not None: - spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram)) - - if power is not None and log_mel is not None: - if log_mel == "log": - spectrogram = np.log(spectrogram) - elif log_mel == "log10": - spectrogram = np.log10(spectrogram) - elif log_mel == "dB": - if power == 1.0: - spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range) - elif power == 2.0: - spectrogram = power_to_db(spectrogram, reference, min_value, db_range) - else: - raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}") - else: - raise ValueError(f"Unknown log_mel option: {log_mel}") - - spectrogram = np.asarray(spectrogram, dtype) - - return spectrogram - - -def spectrogram_batch( - waveform_list: list[np.ndarray], - window: np.ndarray, - frame_length: int, - hop_length: int, - fft_length: int | None = None, - power: float | None = 1.0, - center: bool = True, 
- pad_mode: str = "reflect", - onesided: bool = True, - dither: float = 0.0, - preemphasis: float | None = None, - mel_filters: np.ndarray | None = None, - mel_floor: float = 1e-10, - log_mel: str | None = None, - reference: float = 1.0, - min_value: float = 1e-10, - db_range: float | None = None, - remove_dc_offset: bool = False, - dtype: np.dtype = np.float32, -) -> list[np.ndarray]: - """ - Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing. - This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting. - - It supports generating various types of spectrograms: - - - amplitude spectrogram (`power = 1.0`) - - power spectrogram (`power = 2.0`) - - complex-valued spectrogram (`power = None`) - - log spectrogram (use `log_mel` argument) - - mel spectrogram (provide `mel_filters`) - - log-mel spectrogram (provide `mel_filters` and `log_mel`) - - How this works: - - 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length - - hop_length` samples. - 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`. - 3. The DFT is taken of each windowed frame. - 4. The results are stacked into a spectrogram. - - We make a distinction between the following "blocks" of sample data, each of which may have a different lengths: - - - The analysis frame. This is the size of the time slices that the input waveform is split into. - - The window. Each analysis frame is multiplied by the window to avoid spectral leakage. - - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram. - - In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A - padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame, - typically the next power of two. - - Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`. - - Args: - waveform_list (`list[np.ndarray]` with arrays of shape `(length,)`): - The list of input waveforms, each a single-channel (mono) signal. - window (`np.ndarray` of shape `(frame_length,)`): - The windowing function to apply, including zero-padding if necessary. - frame_length (`int`): - The length of each frame for analysis. - hop_length (`int`): - The step size between successive frames. - fft_length (`int`, *optional*): - The size of the FFT buffer, defining frequency bin resolution. - power (`float`, *optional*, defaults to 1.0): - Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex. - center (`bool`, *optional*, defaults to `True`): - Whether to center-pad the waveform frames. - pad_mode (`str`, *optional*, defaults to `"reflect"`): - The padding strategy when `center` is `True`. - onesided (`bool`, *optional*, defaults to `True`): - If True, returns a one-sided spectrogram for real input signals. - dither (`float`, *optional*, defaults to 0.0): - Adds dithering. In other words, adds a small Gaussian noise to each frame. - E.g. use 4.0 to add dithering with a normal distribution centered - around 0.0 with standard deviation 4.0, 0.0 means no dithering. - preemphasis (`float`, *optional*): - Applies a pre-emphasis filter to each frame. 
- mel_filters (`np.ndarray`, *optional*): - Mel filter bank for converting to mel spectrogram. - mel_floor (`float`, *optional*, defaults to 1e-10): - Floor value for mel spectrogram to avoid log(0). - log_mel (`str`, *optional*): - Specifies log scaling strategy; options are None, "log", "log10", "dB". - reference (`float`, *optional*, defaults to 1.0): - Reference value for dB conversion in log_mel. - min_value (`float`, *optional*, defaults to 1e-10): - Minimum floor value for log scale conversions. - db_range (`float`, *optional*): - Dynamic range for dB scale spectrograms. - remove_dc_offset (`bool`, *optional*): - Whether to remove the DC offset from each frame. - dtype (`np.dtype`, *optional*, defaults to `np.float32`): - Data type of the output spectrogram. - - Returns: - list[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform. - """ - window_length = len(window) - - if fft_length is None: - fft_length = frame_length - - if frame_length > fft_length: - raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})") - - if window_length != frame_length: - raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})") - - if hop_length <= 0: - raise ValueError("hop_length must be greater than zero") - - # Check the dimensions of the waveform , and if waveform is complex - for waveform in waveform_list: - if waveform.ndim != 1: - raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}") - if np.iscomplexobj(waveform): - raise ValueError("Complex-valued input waveforms are not currently supported") - # Center pad the waveform - if center: - padding = [(int(frame_length // 2), int(frame_length // 2))] - waveform_list = [ - np.pad( - waveform, - padding, - mode=pad_mode, - ) - for waveform in waveform_list - ] - original_waveform_lengths = [ - len(waveform) for waveform in waveform_list - ] # these lengths will be used to remove padding later - - # Batch pad the waveform - max_length = max(original_waveform_lengths) - padded_waveform_batch = np.array( - [ - np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0) - for waveform in waveform_list - ], - dtype=dtype, - ) - - # Promote to float64, since np.fft uses float64 internally - padded_waveform_batch = padded_waveform_batch.astype(np.float64) - window = window.astype(np.float64) - - # Split waveform into frames of frame_length size - num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length)) - # these lengths will be used to remove padding later - true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths] - num_batches = padded_waveform_batch.shape[0] - - num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length - spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64) - - # rfft is faster than fft - fft_func = np.fft.rfft if onesided else np.fft.fft - buffer = np.zeros((num_batches, fft_length)) - - for frame_idx in range(num_frames): - timestep = frame_idx * hop_length - buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length] - - if dither != 0.0: - buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape) - - if remove_dc_offset: - buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True) - - if preemphasis is not None: - buffer[:, 1:frame_length] -= 
preemphasis * buffer[:, : frame_length - 1] - buffer[:, 0] *= 1 - preemphasis - - buffer[:, :frame_length] *= window - - spectrogram[:, frame_idx] = fft_func(buffer) - - # Note: ** is much faster than np.power - if power is not None: - spectrogram = np.abs(spectrogram, dtype=np.float64) ** power - - # Apply mel filters if provided - if mel_filters is not None: - result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1])) - spectrogram = np.maximum(mel_floor, result) - - # Convert to log scale if specified - if power is not None and log_mel is not None: - if log_mel == "log": - spectrogram = np.log(spectrogram) - elif log_mel == "log10": - spectrogram = np.log10(spectrogram) - elif log_mel == "dB": - if power == 1.0: - spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range) - elif power == 2.0: - spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range) - else: - raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}") - else: - raise ValueError(f"Unknown log_mel option: {log_mel}") - - spectrogram = np.asarray(spectrogram, dtype) - - spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))] - - return spectrogram_list - def power_to_db( spectrogram: np.ndarray, @@ -1215,55 +793,6 @@ def power_to_db( return spectrogram -def power_to_db_batch( - spectrogram: np.ndarray, - reference: float = 1.0, - min_value: float = 1e-10, - db_range: float | None = None, -) -> np.ndarray: - """ - Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`, - using basic logarithm properties for numerical stability. - - This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram. - - Args: - spectrogram (`np.ndarray`): - The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape). - Note that a power spectrogram has the amplitudes squared! - reference (`float`, *optional*, defaults to 1.0): - Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set - the loudest part to 0 dB. Must be greater than zero. - min_value (`float`, *optional*, defaults to `1e-10`): - The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking - `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero. - db_range (`float`, *optional*): - Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the - peak value and the smallest value will never be more than 80 dB. Must be greater than zero. 
- - Returns: - `np.ndarray`: the batch of spectrograms in decibels - """ - if reference <= 0.0: - raise ValueError("reference must be greater than zero") - if min_value <= 0.0: - raise ValueError("min_value must be greater than zero") - - reference = max(min_value, reference) - - spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) - spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference)) - - if db_range is not None: - if db_range <= 0.0: - raise ValueError("db_range must be greater than zero") - # Apply db_range clipping per batch item - max_values = spectrogram.max(axis=(1, 2), keepdims=True) - spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None) - - return spectrogram - - def amplitude_to_db( spectrogram: np.ndarray, reference: float = 1.0, @@ -1313,46 +842,3 @@ def amplitude_to_db( return spectrogram -def amplitude_to_db_batch( - spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: float | None = None -) -> np.ndarray: - """ - Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`, - using basic logarithm properties for numerical stability. - - The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram. - - Args: - spectrogram (`np.ndarray`): - The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape). - reference (`float`, *optional*, defaults to 1.0): - Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set - the loudest part to 0 dB. Must be greater than zero. - min_value (`float`, *optional*, defaults to `1e-5`): - The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking - `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero. - db_range (`float`, *optional*): - Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the - peak value and the smallest value will never be more than 80 dB. Must be greater than zero. 
- - Returns: - `np.ndarray`: the batch of spectrograms in decibels - """ - if reference <= 0.0: - raise ValueError("reference must be greater than zero") - if min_value <= 0.0: - raise ValueError("min_value must be greater than zero") - - reference = max(min_value, reference) - - spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None) - spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference)) - - if db_range is not None: - if db_range <= 0.0: - raise ValueError("db_range must be greater than zero") - # Apply db_range clipping per batch item - max_values = spectrogram.max(axis=(1, 2), keepdims=True) - spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None) - - return spectrogram diff --git a/src/transformers/models/clap/audio_processing_clap.py b/src/transformers/models/clap/audio_processing_clap.py index e18b8e542007..20525c3bfce3 100644 --- a/src/transformers/models/clap/audio_processing_clap.py +++ b/src/transformers/models/clap/audio_processing_clap.py @@ -37,6 +37,7 @@ def __init__(self, **kwargs): stft_config=StftConfig(n_fft=1024, hop_length=480, power=2.0), mel_scale_config=self._mel_configs[truncation_mode], log_mode="dB", + computation_dtype="float64", ) super().__init__(**kwargs) # rand_trunc: base class truncates via pad() → _truncate_single (random offset) diff --git a/src/transformers/models/clvp/audio_processing_clvp.py b/src/transformers/models/clvp/audio_processing_clvp.py index 7a57b890f86e..6272c795d1fd 100644 --- a/src/transformers/models/clvp/audio_processing_clvp.py +++ b/src/transformers/models/clvp/audio_processing_clvp.py @@ -42,6 +42,7 @@ class ClvpAudioProcessor(NumpyAudioBackend): ), log_mode="log", mel_floor=1e-5, + computation_dtype="float64", ) def __init__(self, mel_norms=None, **kwargs): diff --git a/src/transformers/models/gemma3n/audio_processing_gemma3n.py b/src/transformers/models/gemma3n/audio_processing_gemma3n.py index e4359bcd62c3..23f63b8bdb19 100644 --- a/src/transformers/models/gemma3n/audio_processing_gemma3n.py +++ b/src/transformers/models/gemma3n/audio_processing_gemma3n.py @@ -58,6 +58,7 @@ class Gemma3nAudioProcessor(NumpyAudioBackend): mel_floor=1e-5, log_mode="log", preemphasis=0.97, + computation_dtype="float64", ) def __init__(self, per_bin_mean=None, per_bin_stddev=None, **kwargs): diff --git a/src/transformers/models/parakeet/audio_processing_parakeet.py b/src/transformers/models/parakeet/audio_processing_parakeet.py index a26d936075ae..5df813fabae5 100644 --- a/src/transformers/models/parakeet/audio_processing_parakeet.py +++ b/src/transformers/models/parakeet/audio_processing_parakeet.py @@ -80,7 +80,7 @@ def _mel_filter_bank(self, spectrogram_config): return torch.from_numpy(weights.T).to(torch.float32) - def _compute_magnitudes(self, stft_out, power): + def _compute_magnitudes(self, stft_out, power, spectrogram_config=None): import torch magnitudes = torch.view_as_real(stft_out) diff --git a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py index 787d28ffb401..a63321c9a346 100644 --- a/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/audio_processing_phi4_multimodal.py @@ -91,7 +91,7 @@ def _apply_frame_processing(self, frames, *, spectrogram_config, audio_ranges=No frames_prev[..., 0] = frames_prev[..., 1] return (frames - spectrogram_config.preemphasis * frames_prev) * 32768 - def _window_and_fft(self, frames, window, 
frame_length, n_fft, stft_cfg): + def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None): frames = frames * window if frame_length < n_fft: frames = torch.nn.functional.pad(frames, (0, n_fft - frame_length)) diff --git a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py index 3b53af82d2a6..de597178b446 100644 --- a/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/audio_processing_seamless_m4t.py @@ -46,14 +46,15 @@ class SeamlessM4tAudioProcessor(NumpyAudioBackend): preemphasis=0.97, remove_dc_offset=True, mel_floor=1.192092955078125e-07, - waveform_scale=32768.0, + computation_dtype="float64", ) + waveform_scale = 32768.0 def extract_spectrogram(self, audio, **kwargs): # Per-waveform fbank extraction returning (time, n_mels) features = [] for waveform in audio: - waveform = np.squeeze(waveform) + waveform = np.squeeze(waveform) * self.waveform_scale f = super().extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) features.append(f[0].T) return features diff --git a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py index 29d80e383f50..5f9717738982 100644 --- a/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py +++ b/src/transformers/models/speech_to_text/audio_processing_speech_to_text.py @@ -43,8 +43,8 @@ class SpeechToTextAudioProcessor(NumpyAudioBackend): preemphasis=0.97, remove_dc_offset=True, mel_floor=1.192092955078125e-07, - waveform_scale=32768.0, ) + waveform_scale = 32768.0 def __init__(self, normalize_means=True, normalize_vars=True, **kwargs): super().__init__(**kwargs) @@ -53,7 +53,7 @@ def __init__(self, normalize_means=True, normalize_vars=True, **kwargs): def _extract_fbank_features(self, waveform): """Extract log-mel filterbank features for a single waveform.""" - waveform = waveform * self.spectrogram_config.waveform_scale + waveform = waveform * self.waveform_scale return self._kaldi_fbank(waveform, num_mel_bins=80) def extract_spectrogram(self, audio, **kwargs): diff --git a/src/transformers/models/univnet/audio_processing_univnet.py b/src/transformers/models/univnet/audio_processing_univnet.py index 58fd8e4ea32c..6bd6e16b2af3 100644 --- a/src/transformers/models/univnet/audio_processing_univnet.py +++ b/src/transformers/models/univnet/audio_processing_univnet.py @@ -47,6 +47,7 @@ class UnivNetAudioProcessor(NumpyAudioBackend): ), log_mode="log", mel_floor=1e-5, + computation_dtype="float64", ) def __init__(self, **kwargs): @@ -63,7 +64,7 @@ def _stft(self, audio, *, spectrogram_config, **kwargs): audio = np.pad(audio, (pad_amount, pad_amount), mode="reflect") return super()._stft(audio, spectrogram_config=spectrogram_config, **kwargs) - def _compute_magnitudes(self, stft_out, power): + def _compute_magnitudes(self, stft_out, power, spectrogram_config=None): # UnivNet adds mel_floor inside the sqrt: sqrt(real² + imag² + mel_floor) return np.sqrt(np.real(stft_out) ** 2 + np.imag(stft_out) ** 2 + self.mel_floor) From 792a87147851c37e0d3b764c49a9715e47184df7 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Fri, 3 Apr 2026 23:00:25 +0200 Subject: [PATCH 26/28] the closert I get the further it gets --- src/transformers/audio_processing_backends.py | 11 +++++------ src/transformers/audio_utils.py | 16 
+++++++++++++++- .../audio_processing_voxtral_realtime.py | 4 ++++ .../models/whisper/audio_processing_whisper.py | 4 ++++ 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 10155a4002d1..4ec4ad7dfb47 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -359,7 +359,10 @@ def _normalize_magnitude( def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): stft_cfg = spectrogram_config.stft_config mel_cfg = spectrogram_config.mel_scale_config - filters = mel_filter_bank( + # Use float32 dtype for per-band rounding matching librosa. Processors requiring + # float64 precision set computation_dtype, which skips the dtype to keep float64. + filter_dtype = None if spectrogram_config.computation_dtype else np.float32 + return mel_filter_bank( num_frequency_bins=1 + stft_cfg.n_fft // 2, num_mel_filters=mel_cfg.n_mels, min_frequency=mel_cfg.f_min, @@ -368,12 +371,8 @@ def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): norm=mel_cfg.norm, mel_scale=mel_cfg.mel_scale, triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, + dtype=filter_dtype, ) - # Store as float32 to match librosa's precision path. Processors needing float64 - # set computation_dtype, which keeps filters in float64 for exact upstream FE matching. - if not spectrogram_config.computation_dtype: - filters = filters.astype(np.float32) - return filters def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) diff --git a/src/transformers/audio_utils.py b/src/transformers/audio_utils.py index e9d2da3707d3..14e70e12e45c 100644 --- a/src/transformers/audio_utils.py +++ b/src/transformers/audio_utils.py @@ -580,6 +580,7 @@ def mel_filter_bank( norm: str | None = None, mel_scale: str = "htk", triangularize_in_mel_space: bool = False, + dtype: np.dtype | None = None, ) -> np.ndarray: """ Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and @@ -648,7 +649,20 @@ def mel_filter_bank( # frequencies of FFT bins in Hz fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins) - mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) + if dtype is not None: + # Per-band computation matching librosa's precision path: compute slopes in float64, + # cast each band to dtype immediately. This replicates librosa's per-row assignment + # to a dtype-initialized array, which produces different rounding than computing all + # bands in float64 and casting at the end. 
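+        # For reference, the librosa pattern this mirrors is roughly the following
+        # (a hedged sketch only; names are illustrative, not copied from librosa):
+        #     weights = np.zeros((num_frequency_bins, num_mel_filters), dtype=np.float32)
+        #     for i in range(num_mel_filters):
+        #         weights[:, i] = np.maximum(0, np.minimum(lower_f64, upper_f64))
+        # i.e. each band is rounded to the target dtype on assignment, rather than
+        # building the whole bank in float64 and casting once at the end.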
+ filter_diff = np.diff(filter_freqs) + ramps = np.subtract.outer(filter_freqs, fft_freqs) # (num_mel_filters+2, num_frequency_bins) + mel_filters = np.zeros((num_frequency_bins, num_mel_filters), dtype=dtype) + for i in range(num_mel_filters): + lower = -ramps[i] / filter_diff[i] + upper = ramps[i + 2] / filter_diff[i + 1] + mel_filters[:, i] = np.maximum(0, np.minimum(lower, upper)).astype(dtype) + else: + mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs) if norm is not None and norm == "slaney": # Slaney-style mel is scaled to be approx constant energy per channel diff --git a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py index 7fff372c318d..59ff6ad89176 100644 --- a/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/audio_processing_voxtral_realtime.py @@ -38,6 +38,10 @@ class VoxtralRealtimeAudioProcessor(TorchAudioBackend): ) global_log_mel_max = 1.5 + def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): + mel_filters = self.mel_filters.to(device=features.device) + return torch.clamp(torch.matmul(mel_filters.T, features), min=spectrogram_config.mel_floor) + def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) diff --git a/src/transformers/models/whisper/audio_processing_whisper.py b/src/transformers/models/whisper/audio_processing_whisper.py index eb9f3dd570f5..0a7f5bffa9be 100644 --- a/src/transformers/models/whisper/audio_processing_whisper.py +++ b/src/transformers/models/whisper/audio_processing_whisper.py @@ -40,6 +40,10 @@ class WhisperAudioProcessor(TorchAudioBackend): skip_last_frame=True, ) + def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): + mel_filters = self.mel_filters.to(device=features.device) + return torch.clamp(torch.matmul(mel_filters.T, features), min=spectrogram_config.mel_floor) + def _normalize_magnitude(self, features, *, spectrogram_config, **kwargs): features = super()._normalize_magnitude(features, spectrogram_config=spectrogram_config, **kwargs) From 7ea32fecf28317ea363c2fcbc222621617008f02 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Fri, 3 Apr 2026 23:58:17 +0200 Subject: [PATCH 27/28] caching window --- src/transformers/audio_processing_utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/transformers/audio_processing_utils.py b/src/transformers/audio_processing_utils.py index 0419d3cdc008..33f4eaff143d 100644 --- a/src/transformers/audio_processing_utils.py +++ b/src/transformers/audio_processing_utils.py @@ -96,11 +96,11 @@ def __init__( for key, value in attributes.items(): setattr(self, key, value) - # Derive mel_filters from spectrogram_config if mel_scale_config is set - # TODO: maybe the mel spectrogram initialization should be lazy? 
- if self.spectrogram_config is not None and self.spectrogram_config.mel_scale_config is not None: - if not hasattr(self, "mel_filters"): + # Pre-compute mel filters from spectrogram_config + if self.spectrogram_config is not None: + if self.spectrogram_config.mel_scale_config is not None and not hasattr(self, "mel_filters"): self.mel_filters = self._mel_filter_bank(self.spectrogram_config) + self._cached_stft_window = None def __call__(self, audio: AudioInput, *args, **kwargs: Unpack[AudioKwargs]) -> BatchFeature: return self.preprocess(audio, *args, **kwargs) @@ -440,8 +440,15 @@ def _stft(self, audio, *, spectrogram_config, **kwargs): audio = audio.to(getattr(torch, dtype_str)) if spectrogram_config.waveform_scale is not None: audio = audio * spectrogram_config.waveform_scale - window = self._create_stft_window(win_length, stft_cfg, audio) - window, frame_length = self._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + + # Cache window on first call; reuse on subsequent calls with same config + if self._cached_stft_window is not None and spectrogram_config is self.spectrogram_config: + window, frame_length = self._cached_stft_window + else: + window = self._create_stft_window(win_length, stft_cfg, audio) + window, frame_length = self._prepare_window_and_framing(window, win_length, n_fft, needs_manual_framing) + if spectrogram_config is self.spectrogram_config: + self._cached_stft_window = (window, frame_length) if needs_manual_framing: audio_dtype = audio.dtype From 5c7bf4f0ce5f7279cb3fb9601839540fcb48a8d8 Mon Sep 17 00:00:00 2001 From: eustlb <94853470+eustlb@users.noreply.github.com> Date: Sat, 4 Apr 2026 00:16:33 +0200 Subject: [PATCH 28/28] cleaning backend --- src/transformers/audio_processing_backends.py | 786 ++++++++---------- 1 file changed, 334 insertions(+), 452 deletions(-) diff --git a/src/transformers/audio_processing_backends.py b/src/transformers/audio_processing_backends.py index 4ec4ad7dfb47..20a4c8a1f4c8 100644 --- a/src/transformers/audio_processing_backends.py +++ b/src/transformers/audio_processing_backends.py @@ -13,14 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- import math import numpy as np from .audio_processing_utils import BaseAudioProcessor from .audio_utils import SpectrogramConfig, amplitude_to_db, mel_filter_bank, power_to_db -from .feature_extraction_utils import BatchFeature from .utils import PaddingStrategy, is_torch_available, logging @@ -31,52 +29,14 @@ import torch -# ── NumPy frequency conversion utilities ────────────────────────────── - -def _np_hertz_to_mel(freq, mel_scale="htk"): - if mel_scale == "htk": - return 2595.0 * np.log10(1.0 + (freq / 700.0)) - elif mel_scale == "kaldi": - return 1127.0 * np.log(1.0 + (freq / 700.0)) - # slaney - min_log_hertz = 1000.0 - min_log_mel = 15.0 - logstep = 27.0 / np.log(6.4) - mels = 3.0 * freq / 200.0 - if isinstance(freq, np.ndarray): - log_region = freq >= min_log_hertz - mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep - elif freq >= min_log_hertz: - mels = min_log_mel + np.log(freq / min_log_hertz) * logstep - return mels - - -def _np_mel_to_hertz(mels, mel_scale="htk"): - if mel_scale == "htk": - return 700.0 * (np.power(10, mels / 2595.0) - 1.0) - elif mel_scale == "kaldi": - return 700.0 * (np.exp(mels / 1127.0) - 1.0) - # slaney - min_log_hertz = 1000.0 - min_log_mel = 15.0 - logstep = np.log(6.4) / 27.0 - freq = 200.0 * mels / 3.0 - if isinstance(mels, np.ndarray): - log_region = mels >= min_log_mel - freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel)) - elif mels >= min_log_mel: - freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel)) - return freq +# ── Torch frequency conversion utilities (used by TorchAudioBackend._mel_filter_bank) ── -# ── Torch frequency conversion utilities ────────────────────────────── - def _torch_hertz_to_mel_scalar(freq: float, mel_scale: str = "htk") -> float: if mel_scale == "htk": return 2595.0 * math.log10(1.0 + freq / 700.0) elif mel_scale == "kaldi": return 1127.0 * math.log(1.0 + freq / 700.0) - # slaney f_sp = 200.0 / 3 min_log_hz = 1000.0 min_log_mel = (min_log_hz - 0.0) / f_sp @@ -91,7 +51,6 @@ def _torch_hertz_to_mel(freq: "torch.Tensor", mel_scale: str = "htk") -> "torch. return 2595.0 * torch.log10(1.0 + freq / 700.0) elif mel_scale == "kaldi": return 1127.0 * torch.log(1.0 + freq / 700.0) - # slaney f_sp = 200.0 / 3 min_log_hertz = 1000.0 min_log_mel = min_log_hertz / f_sp @@ -107,7 +66,6 @@ def _torch_mel_to_hertz(mels: "torch.Tensor", mel_scale: str = "htk") -> "torch. return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) elif mel_scale == "kaldi": return 700.0 * (torch.exp(mels / 1127.0) - 1.0) - # slaney f_sp = 200.0 / 3 min_log_hz = 1000.0 min_log_mel = (min_log_hz - 0.0) / f_sp @@ -118,6 +76,22 @@ def _torch_mel_to_hertz(mels: "torch.Tensor", mel_scale: str = "htk") -> "torch. 
return freq +def _torch_triangular_filter_bank(fft_freqs, filter_freqs, computation_dtype=None): + """Compute triangular mel filter bank (shared by non-kaldi TorchAudioBackend paths).""" + num_mel_filters = len(filter_freqs) - 2 + filter_diff = filter_freqs[1:] - filter_freqs[:-1] + slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1) + down_slopes = -slopes[:, :-2] / filter_diff[:-1] + up_slopes = slopes[:, 2:] / filter_diff[1:] + zero = torch.zeros(1, dtype=computation_dtype) if computation_dtype else torch.zeros(1) + return torch.clamp(torch.minimum(down_slopes, up_slopes), min=0) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# NumpyAudioBackend +# ═══════════════════════════════════════════════════════════════════════════════ + + class NumpyAudioBackend(BaseAudioProcessor): """NumPy backend for portable CPU-only audio processing.""" @@ -125,18 +99,12 @@ class NumpyAudioBackend(BaseAudioProcessor): def backend(self) -> str: return "numpy" - def _process_audio(self, audio_el): - """ - Process a single raw audio input into a np.ndarray. + # ── Audio input processing ──────────────────────────────────────────── - Handles mono conversion (averaging channels) and numpy conversion. - Closely mirrors the torch backend logic: expects channel-first. - """ + def _process_audio(self, audio_el): if not isinstance(audio_el, np.ndarray): audio_el = np.asarray(audio_el) - if audio_el.ndim > 1: - # Expecting channel-first: (channels, samples) if self.force_mono and audio_el.shape[0] > 1: audio_el = audio_el.mean(axis=0) elif audio_el.shape[0] == 1: @@ -145,18 +113,12 @@ def _process_audio(self, audio_el): raise ValueError("Audio has more than one channel but force_mono is False") return audio_el + # ── Padding & batching ──────────────────────────────────────────────── + def _pad_single(self, audio: np.ndarray, max_length: int) -> np.ndarray: - """Pad a single audio array to a target length using np.pad.""" current_length = audio.shape[-1] if current_length >= max_length: return audio - - if self.padding_value is None: - raise ValueError( - "Asking to pad but the audio processor does not have a padding value. Please select a value to use" - " as `padding_value`. For example: `audio_processor.padding_value = 0.0`." 
- ) - pad_length = max_length - current_length if self.padding_side == "right": pad_width = [(0, 0)] * (audio.ndim - 1) + [(0, pad_length)] @@ -164,9 +126,60 @@ def _pad_single(self, audio: np.ndarray, max_length: int) -> np.ndarray: pad_width = [(0, 0)] * (audio.ndim - 1) + [(pad_length, 0)] else: raise ValueError(f"Invalid padding side: {self.padding_side}") - return np.pad(audio, pad_width, mode="constant", constant_values=self.padding_value) + def _to_batch(self, audio): + batch = np.stack(audio) + if self.add_channel_dim: + batch = batch[:, np.newaxis, :] + return batch + + def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): + padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) + if truncation and max_length is not None: + features = [f[:max_length] for f in features] + actual_lengths = [f.shape[0] for f in features] + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(actual_lengths) + padding_strategy = PaddingStrategy.MAX_LENGTH + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None: + features = [ + np.pad(f, [(0, max_length - f.shape[0])] + [(0, 0)] * (f.ndim - 1), + mode="constant", constant_values=self.padding_value) + if f.shape[0] < max_length else f + for f in features + ] + return features, [(0, length) for length in actual_lengths] + + def _stack_features(self, features): + return np.stack(features) + + # ── Masking ─────────────────────────────────────────────────────────── + + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + use_audio_mask = self.mask_level == "audio" + if do_extract_spectrogram and not use_audio_mask: + spec_cfg = spectrogram_config or self.spectrogram_config + audio_lengths = np.array([end - start for start, end in audio_ranges]) + features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) + n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True) + mask = (np.arange(n_features)[None, :] < features_lengths[:, None]).astype(np.int32) + return {"audio_features_mask": mask} + mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, start:end] = 1 + return {("audio_features_mask" if do_extract_spectrogram else "audio_values_mask"): mask} + + def _get_feature_mask(self, feature_ranges, padded_length): + mask = np.zeros((len(feature_ranges), padded_length), dtype=np.int32) + for i, (start, end) in enumerate(feature_ranges): + mask[i, start:end] = 1 + return {"audio_features_mask": mask} + + # ── STFT pipeline ───────────────────────────────────────────────────── + def _create_stft_window(self, win_length, stft_cfg, audio): N = win_length + 1 if stft_cfg.periodic else win_length fac = np.linspace(-np.pi, np.pi, N) @@ -185,18 +198,16 @@ def _create_stft_window(self, win_length, stft_cfg, audio): def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_framing): if needs_manual_framing and win_length < n_fft: - frame_length = win_length - else: - if win_length < n_fft: - left_pad = (n_fft - win_length) // 2 - right_pad = n_fft - win_length - left_pad - window = np.pad(window, (left_pad, right_pad)) - frame_length = n_fft - return window, frame_length + return window, win_length + if 
win_length < n_fft: + left_pad = (n_fft - win_length) // 2 + right_pad = n_fft - win_length - left_pad + window = np.pad(window, (left_pad, right_pad)) + return window, n_fft @staticmethod def _np_frame(x, frame_length, hop_length): - """Create overlapping frames from a 1D array using stride tricks (replaces librosa.util.frame).""" + """Create overlapping frames using stride tricks (replaces librosa.util.frame).""" n_frames = 1 + (x.shape[-1] - frame_length) // hop_length strides = x.strides[:-1] + (x.strides[-1] * hop_length, x.strides[-1]) shape = x.shape[:-1] + (n_frames, frame_length) @@ -206,49 +217,40 @@ def _frame_waveform(self, waveform, frame_length, hop_length, n_fft, center, pad squeezed = waveform.ndim == 1 if squeezed: waveform = waveform[np.newaxis, :] + if center: start_k = int(np.ceil(n_fft // 2 / hop_length)) tail_k = (waveform.shape[-1] + n_fft // 2 - n_fft) // hop_length + 1 if tail_k <= start_k: + # Short audio: simple center-pad and index-based framing waveform = np.pad(waveform, ((0, 0), (frame_length // 2, frame_length // 2)), mode=pad_mode) num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length frame_starts = np.arange(num_frames) * hop_length - frame_indices = frame_starts[:, np.newaxis] + np.arange(frame_length) - frames = waveform[:, frame_indices] + frames = waveform[:, frame_starts[:, np.newaxis] + np.arange(frame_length)] else: + # Long audio: split into pre (left-padded), middle (no pad), post (right-padded) + # to handle edge effects from center padding correctly padding = [(0, 0) for _ in range(waveform.ndim)] + padding[-1] = (frame_length // 2, 0) - y_pre = np.pad( - waveform[..., : (start_k - 1) * hop_length - n_fft // 2 + n_fft + 1], - padding, - mode=pad_mode, - ) - y_frames_pre = self._np_frame(y_pre, frame_length, hop_length) - y_frames_pre = y_frames_pre[..., :start_k, :] - extra = y_frames_pre.shape[-2] + y_pre = np.pad(waveform[..., : (start_k - 1) * hop_length - n_fft // 2 + n_fft + 1], padding, mode=pad_mode) + y_frames_pre = self._np_frame(y_pre, frame_length, hop_length)[..., :start_k, :] padding[-1] = (0, frame_length // 2) - y_post = np.pad( - waveform[..., (tail_k) * hop_length - n_fft // 2 :], - padding, - mode=pad_mode, - ) + y_post = np.pad(waveform[..., tail_k * hop_length - n_fft // 2 :], padding, mode=pad_mode) y_frames_post = self._np_frame(y_post, frame_length, hop_length) - extra += y_frames_post.shape[-2] start = start_k * hop_length - n_fft // 2 - y_frames_middle = self._np_frame( - np.ascontiguousarray(waveform[..., start:]), frame_length, hop_length - ) + y_frames_middle = self._np_frame(np.ascontiguousarray(waveform[..., start:]), frame_length, hop_length) num_frames = y_frames_pre.shape[-2] + y_frames_middle.shape[-2] + y_frames_post.shape[-2] frames = np.concatenate([y_frames_pre, y_frames_middle, y_frames_post], axis=-2) else: + # Non-centered: simple index-based framing num_frames = 1 + (waveform.shape[-1] - frame_length) // hop_length frame_starts = np.arange(num_frames) * hop_length - frame_indices = frame_starts[:, np.newaxis] + np.arange(frame_length) - frames = waveform[:, frame_indices] + frames = waveform[:, frame_starts[:, np.newaxis] + np.arange(frame_length)] if squeezed: frames = frames.squeeze(0) @@ -259,11 +261,18 @@ def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg) compute_dtype = np.result_type(audio.dtype, window.dtype) return frames.astype(compute_dtype, copy=False) + def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): + if 
spectrogram_config.remove_dc_offset: + frames = frames - frames.mean(axis=-1, keepdims=True) + preemphasis = spectrogram_config.preemphasis + if preemphasis is not None: + preemph_src = preemphasis * frames[..., :-1] + frames[..., 1:] = frames[..., 1:] - preemph_src + frames[..., 0] = frames[..., 0] * (1 - preemphasis) + return frames + def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None): frames = frames * window - # Always store FFT output as complex64, matching the upstream spectrogram() function. - # FFT is computed in float64 (numpy default), but the complex64 cast ensures consistent - # precision with librosa and the legacy FE code path. spec = np.fft.rfft(frames, n=n_fft, axis=-1).astype(np.complex64) if stft_cfg.normalized: spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype) @@ -272,41 +281,38 @@ def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_d def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg): frames, _ = self._frame_waveform(audio, frame_length, hop_length, n_fft, stft_cfg.center, stft_cfg.pad_mode) compute_dtype = np.result_type(audio.dtype, window.dtype) - frames = frames.astype(compute_dtype, copy=False) - frames = frames * window + frames = frames.astype(compute_dtype, copy=False) * window spec = np.fft.rfft(frames, n=n_fft, axis=-1).astype(np.complex64) if stft_cfg.normalized: spec = spec / np.sqrt(np.sum(window**2)).astype(spec.real.dtype) return np.moveaxis(spec, -1, -2) def _compute_magnitudes(self, stft_out, power, spectrogram_config=None): - # When computation_dtype is set (e.g., "float64"), compute magnitudes in that dtype - # to match the upstream FE precision path: np.abs(complex64, dtype=float64) ** power. - # Otherwise, use the natural dtype (float32 from complex64) to match librosa. 
+ # computation_dtype signals that upstream FE used float64 magnitudes if spectrogram_config and spectrogram_config.computation_dtype: return np.abs(stft_out, dtype=np.float64) ** power return np.abs(stft_out) ** power - def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): - """Apply per-frame signal conditioning using the numpy backend.""" - compute_dtype = frames.dtype - if spectrogram_config.remove_dc_offset: - frames = frames - frames.mean(axis=-1, keepdims=True) - preemphasis = spectrogram_config.preemphasis - if preemphasis is not None: - preemph_src = preemphasis * frames[..., :-1] - frames[..., 1:] = frames[..., 1:] - preemph_src - frames[..., 0] = frames[..., 0] * (1 - preemphasis) - return frames + # ── Mel scale & normalization ───────────────────────────────────────── - def _apply_mel_scale( - self, - features: list[np.ndarray], - *, - spectrogram_config: SpectrogramConfig, - **kwargs, - ) -> list[np.ndarray]: - """Apply mel filterbank to spectrogram features using the numpy backend.""" + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + # float32 dtype matches librosa's per-band rounding; computation_dtype keeps float64 + filter_dtype = None if spectrogram_config.computation_dtype else np.float32 + return mel_filter_bank( + num_frequency_bins=1 + stft_cfg.n_fft // 2, + num_mel_filters=mel_cfg.n_mels, + min_frequency=mel_cfg.f_min, + max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, + sampling_rate=self.sample_rate, + norm=mel_cfg.norm, + mel_scale=mel_cfg.mel_scale, + triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, + dtype=filter_dtype, + ) + + def _apply_mel_scale(self, features, *, spectrogram_config, **kwargs): mel_filters = self.mel_filters.astype(features.dtype, copy=False) if spectrogram_config.mel_scale_config.matmul_order == "features_first": mel_spec = np.matmul(features, mel_filters) @@ -314,30 +320,14 @@ def _apply_mel_scale( mel_spec = np.matmul(mel_filters.T, features) return np.maximum(spectrogram_config.mel_floor, mel_spec) - def _normalize_magnitude( - self, - features: np.ndarray, - *, - spectrogram_config: SpectrogramConfig, - reference: float = 1.0, - min_value: float = 1e-10, - db_range: float | None = None, - dtype: np.dtype = np.float32, - **kwargs, - ) -> np.ndarray: - """Apply magnitude normalization (log, log10, or dB scaling) to spectrogram features. - - Accepts a single or batched spectrogram (not a list). - Mirrors the normalization logic in the spectrogram pipeline. 
- """ + def _normalize_magnitude(self, features, *, spectrogram_config, + reference=1.0, min_value=1e-10, db_range=None, + dtype=np.float32, **kwargs): log_mel = spectrogram_config.log_mode - mel_floor = spectrogram_config.mel_floor - power = spectrogram_config.stft_config.power - if log_mel is None: return features.astype(dtype) - # Clamp to mel_floor before taking log + mel_floor = spectrogram_config.mel_floor result = np.maximum(mel_floor, features) if log_mel == "log": @@ -345,6 +335,7 @@ def _normalize_magnitude( elif log_mel == "log10": result = np.log10(result).astype(dtype) elif log_mel == "dB": + power = spectrogram_config.stft_config.power if power == 1.0: result = amplitude_to_db(result, reference, min_value, db_range).astype(dtype) elif power == 2.0: @@ -353,68 +344,14 @@ def _normalize_magnitude( raise ValueError(f"Cannot use log_mel option 'dB' with power {power}") else: raise ValueError(f"Unknown log_mel option: {log_mel}") - return result - def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): - stft_cfg = spectrogram_config.stft_config - mel_cfg = spectrogram_config.mel_scale_config - # Use float32 dtype for per-band rounding matching librosa. Processors requiring - # float64 precision set computation_dtype, which skips the dtype to keep float64. - filter_dtype = None if spectrogram_config.computation_dtype else np.float32 - return mel_filter_bank( - num_frequency_bins=1 + stft_cfg.n_fft // 2, - num_mel_filters=mel_cfg.n_mels, - min_frequency=mel_cfg.f_min, - max_frequency=mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2, - sampling_rate=self.sample_rate, - norm=mel_cfg.norm, - mel_scale=mel_cfg.mel_scale, - triangularize_in_mel_space=mel_cfg.triangularize_in_mel_space, - dtype=filter_dtype, - ) - - def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): - padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) - - if truncation and max_length is not None: - features = [f[:max_length] for f in features] - - actual_lengths = [f.shape[0] for f in features] - - if padding_strategy == PaddingStrategy.LONGEST: - max_length = max(actual_lengths) - padding_strategy = PaddingStrategy.MAX_LENGTH - - if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): - max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of - - if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None: - padded = [] - for f in features: - if f.shape[0] < max_length: - pad_width = [(0, max_length - f.shape[0])] + [(0, 0)] * (f.ndim - 1) - f = np.pad(f, pad_width, mode="constant", constant_values=self.padding_value) - padded.append(f) - features = padded - - feature_ranges = [(0, length) for length in actual_lengths] - return features, feature_ranges - - def _stack_features(self, features): - return np.stack(features) - - def _get_feature_mask(self, feature_ranges, padded_length): - mask = np.zeros((len(feature_ranges), padded_length), dtype=np.int32) - for i, (start, end) in enumerate(feature_ranges): - mask[i, start:end] = 1 - return {"audio_features_mask": mask} + # ── Kaldi fbank helper ──────────────────────────────────────────────── def _kaldi_fbank(self, waveform, num_mel_bins, sample_frequency=None, **kwargs): - """Extract kaldi-compatible fbank features for a single waveform. + """Extract kaldi-compatible fbank features using torchaudio (or fallback to base pipeline). 
- Uses torchaudio when available, falls back to the base spectrogram pipeline. - Returns a numpy array of shape (time, num_mel_bins). + Returns numpy array of shape (time, num_mel_bins). """ from .utils import is_speech_available @@ -426,36 +363,18 @@ def _kaldi_fbank(self, waveform, num_mel_bins, sample_frequency=None, **kwargs): import torchaudio.compliance.kaldi as ta_kaldi waveform_tensor = torch.from_numpy(np.asarray(waveform)).unsqueeze(0) - fbank = ta_kaldi.fbank( - waveform_tensor, num_mel_bins=num_mel_bins, sample_frequency=sample_frequency, **kwargs - ) + fbank = ta_kaldi.fbank(waveform_tensor, num_mel_bins=num_mel_bins, + sample_frequency=sample_frequency, **kwargs) return fbank.numpy() - else: - waveform = np.squeeze(waveform) - features = self.extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) - return features[0].T - def _to_batch(self, audio): - batch = np.stack(audio) - if self.add_channel_dim: - batch = batch[:, np.newaxis, :] - return batch + waveform = np.squeeze(waveform) + features = self.extract_spectrogram([waveform], spectrogram_config=self.spectrogram_config) + return features[0].T - def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): - use_audio_mask = self.mask_level == "audio" - if do_extract_spectrogram and not use_audio_mask: - spec_cfg = spectrogram_config or self.spectrogram_config - audio_lengths = np.array([end - start for start, end in audio_ranges]) - features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) - n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True) - mask = (np.arange(n_features)[None, :] < features_lengths[:, None]).astype(np.int32) - return {"audio_features_mask": mask} - else: - mask = np.zeros((len(audio_ranges), padded_length), dtype=np.int32) - for i, (start, end) in enumerate(audio_ranges): - mask[i, start:end] = 1 - key = "audio_features_mask" if do_extract_spectrogram else "audio_values_mask" - return {key: mask} + +# ═══════════════════════════════════════════════════════════════════════════════ +# TorchAudioBackend +# ═══════════════════════════════════════════════════════════════════════════════ class TorchAudioBackend(BaseAudioProcessor): @@ -465,53 +384,88 @@ class TorchAudioBackend(BaseAudioProcessor): def backend(self) -> str: return "torch" - def _process_audio(self, audio_el): - """ - Process a single raw audio input into a torch.Tensor. - - Handles mono conversion (averaging channels) and numpy-to-torch conversion. 
- """ - import torch + # ── Audio input processing ──────────────────────────────────────────── + def _process_audio(self, audio_el): if isinstance(audio_el, np.ndarray): audio_el = torch.from_numpy(audio_el) - if audio_el.ndim > 1: - # TODO: we would need to ensure somewhere audio is channel first if self.force_mono and audio_el.shape[0] > 1: audio_el = audio_el.mean(dim=0) elif audio_el.shape[0] == 1: audio_el = audio_el.squeeze(0) else: raise ValueError("Audio has more than one channel but force_mono is False") - return audio_el - def _pad_single(self, audio: "torch.Tensor", max_length: int) -> "torch.Tensor": - """Pad a single audio tensor to a target length using torch.nn.functional.pad.""" - import torch.nn.functional as F + # ── Padding & batching ──────────────────────────────────────────────── + def _pad_single(self, audio, max_length): current_length = audio.shape[-1] if current_length >= max_length: return audio - - if self.padding_value is None: - raise ValueError( - "Asking to pad but the audio processor does not have a padding value. Please select a value to use" - " as `padding_value`. For example: `audio_processor.padding_value = 0.0`." - ) - if self.padding_side == "right": pad_args = (0, max_length - current_length) elif self.padding_side == "left": pad_args = (max_length - current_length, 0) else: raise ValueError(f"Invalid padding side: {self.padding_side}") + return torch.nn.functional.pad(audio, pad_args, "constant", self.padding_value) + + def _to_batch(self, audio): + batch = torch.stack(audio) + if self.add_channel_dim: + batch = batch.unsqueeze(1) + return batch + + def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of): + padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length) + if truncation and max_length is not None: + features = [f[:max_length] for f in features] + actual_lengths = [f.shape[0] for f in features] + if padding_strategy == PaddingStrategy.LONGEST: + max_length = max(actual_lengths) + padding_strategy = PaddingStrategy.MAX_LENGTH + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None: + padded = [] + for f in features: + if f.shape[0] < max_length: + pad_args = [0, 0] * (f.ndim - 1) + [0, max_length - f.shape[0]] + f = torch.nn.functional.pad(f, pad_args, "constant", self.padding_value) + padded.append(f) + features = padded + return features, [(0, length) for length in actual_lengths] + + def _stack_features(self, features): + return torch.stack(features) + + # ── Masking ─────────────────────────────────────────────────────────── + + def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config): + use_audio_mask = self.mask_level == "audio" + if do_extract_spectrogram and not use_audio_mask: + spec_cfg = spectrogram_config or self.spectrogram_config + audio_lengths = torch.tensor([end - start for start, end in audio_ranges]) + features_lengths = self._get_features_lengths(audio_lengths, spec_cfg) + n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True) + mask = (torch.arange(n_features)[None, :] < features_lengths[:, None]).to(torch.int32) + return {"audio_features_mask": mask} + mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32) + for i, (start, end) in enumerate(audio_ranges): + mask[i, 
start:end] = 1 + return {("audio_features_mask" if do_extract_spectrogram else "audio_values_mask"): mask} + + def _get_feature_mask(self, feature_ranges, padded_length): + mask = torch.zeros((len(feature_ranges), padded_length), dtype=torch.int32) + for i, (start, end) in enumerate(feature_ranges): + mask[i, start:end] = 1 + return {"audio_features_mask": mask} - return F.pad(audio, pad_args, "constant", self.padding_value) + # ── STFT pipeline ───────────────────────────────────────────────────── def _needs_manual_framing(self, spectrogram_config): - """Extends the base check with ``left_align_fft`` which also requires manual framing.""" return super()._needs_manual_framing(spectrogram_config) or spectrogram_config.stft_config.left_align_fft def _create_stft_window(self, win_length, stft_cfg, audio): @@ -532,22 +486,29 @@ def _create_stft_window(self, win_length, stft_cfg, audio): def _prepare_window_and_framing(self, window, win_length, n_fft, needs_manual_framing): if needs_manual_framing and win_length < n_fft: - frame_length = win_length - else: - if win_length < n_fft: - left_pad = (n_fft - win_length) // 2 - right_pad = n_fft - win_length - left_pad - window = torch.nn.functional.pad(window, (left_pad, right_pad)) - frame_length = n_fft - return window, frame_length + return window, win_length + if win_length < n_fft: + left_pad = (n_fft - win_length) // 2 + right_pad = n_fft - win_length - left_pad + window = torch.nn.functional.pad(window, (left_pad, right_pad)) + return window, n_fft def _frame_audio(self, audio, window, frame_length, hop_length, n_fft, stft_cfg): if stft_cfg.center: - audio = torch.nn.functional.pad( - audio, (frame_length // 2, frame_length // 2), mode=stft_cfg.pad_mode - ) + audio = torch.nn.functional.pad(audio, (frame_length // 2, frame_length // 2), mode=stft_cfg.pad_mode) return audio.unfold(-1, frame_length, hop_length) + def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): + if spectrogram_config.remove_dc_offset: + frames = frames - frames.mean(dim=-1, keepdim=True) + preemphasis = spectrogram_config.preemphasis + if preemphasis is not None: + frames = torch.cat([ + frames[..., :1] * (1 - preemphasis), + frames[..., 1:] - preemphasis * frames[..., :-1], + ], dim=-1) + return frames + def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_dtype=None): frames = frames * window if frame_length < n_fft: @@ -559,15 +520,9 @@ def _window_and_fft(self, frames, window, frame_length, n_fft, stft_cfg, audio_d def _native_stft(self, audio, window, frame_length, hop_length, n_fft, stft_cfg): stft_out = torch.stft( - audio, - n_fft=n_fft, - hop_length=hop_length, - win_length=frame_length, - window=window, - center=stft_cfg.center, - pad_mode=stft_cfg.pad_mode, - normalized=False, - return_complex=True, + audio, n_fft=n_fft, hop_length=hop_length, win_length=frame_length, + window=window, center=stft_cfg.center, pad_mode=stft_cfg.pad_mode, + normalized=False, return_complex=True, ) if stft_cfg.normalized: stft_out = stft_out / window.pow(2.0).sum().sqrt() @@ -579,62 +534,141 @@ def _cast_stft_output(self, magnitudes, spectrogram_config): return magnitudes.float() def _compute_magnitudes(self, stft_out, power, spectrogram_config=None): - """Convert complex STFT output to a real-valued magnitude spectrogram.""" return stft_out.abs() ** power - def _apply_frame_processing(self, frames, *, spectrogram_config, **kwargs): - """Apply per-frame signal conditioning using the torch backend.""" - if 
spectrogram_config.remove_dc_offset: - frames = frames - frames.mean(dim=-1, keepdim=True) - preemphasis = spectrogram_config.preemphasis - if preemphasis is not None: - frames = torch.cat([ - frames[..., :1] * (1 - preemphasis), - frames[..., 1:] - preemphasis * frames[..., :-1], - ], dim=-1) - return frames + # ── Mel scale & normalization ───────────────────────────────────────── + + def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig): + stft_cfg = spectrogram_config.stft_config + mel_cfg = spectrogram_config.mel_scale_config + computation_dtype = getattr(torch, mel_cfg.computation_dtype) if mel_cfg.computation_dtype else None + num_frequency_bins = 1 + stft_cfg.n_fft // 2 + num_mel_filters = mel_cfg.n_mels + min_frequency = mel_cfg.f_min + max_frequency = mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2 + n_fft = (num_frequency_bins - 1) * 2 - def _apply_mel_scale( - self, - features: list["torch.Tensor"], - *, - spectrogram_config: SpectrogramConfig, - **kwargs, - ) -> list["torch.Tensor"]: - """Apply mel filterbank to spectrogram features using the torch backend.""" + if mel_cfg.triangularize_in_mel_space and mel_cfg.bands_to_zero == 0: + # Kaldi-exact path: matches torchaudio.compliance.kaldi.get_mel_banks + mel_filters = self._kaldi_exact_mel_banks( + num_mel_filters, num_frequency_bins, min_frequency, max_frequency, + self.sample_rate, n_fft, + ) + elif mel_cfg.triangularize_in_mel_space: + mel_filters = self._kaldi_mel_banks_with_zero_bands( + num_mel_filters, num_frequency_bins, min_frequency, max_frequency, + self.sample_rate, n_fft, mel_cfg, computation_dtype, + ) + else: + mel_filters = self._standard_mel_banks( + num_mel_filters, num_frequency_bins, min_frequency, max_frequency, + self.sample_rate, n_fft, mel_cfg, computation_dtype, + ) + + # Cast back when mel computation_dtype doesn't match spectrogram computation_dtype + if computation_dtype is not None and not spectrogram_config.computation_dtype: + mel_filters = mel_filters.to(torch.get_default_dtype()) + return mel_filters + + @staticmethod + def _kaldi_exact_mel_banks(num_mel_filters, num_frequency_bins, min_frequency, + max_frequency, sampling_rate, n_fft): + """Matches torchaudio.compliance.kaldi.get_mel_banks exactly.""" + num_fft_bins = n_fft // 2 + fft_bin_width = sampling_rate / n_fft + mel_low = 1127.0 * math.log(1.0 + min_frequency / 700.0) + mel_high = 1127.0 * math.log(1.0 + max_frequency / 700.0) + mel_delta = (mel_high - mel_low) / (num_mel_filters + 1) + + bin_idx = torch.arange(num_mel_filters).unsqueeze(1) + left_mel = mel_low + bin_idx * mel_delta + center_mel = mel_low + (bin_idx + 1.0) * mel_delta + right_mel = mel_low + (bin_idx + 2.0) * mel_delta + + mel = 1127.0 * (1.0 + fft_bin_width * torch.arange(num_fft_bins) / 700.0).log() + mel = mel.unsqueeze(0) + + up_slope = (mel - left_mel) / (center_mel - left_mel) + down_slope = (right_mel - mel) / (right_mel - center_mel) + banks = torch.max(torch.zeros(1), torch.min(up_slope, down_slope)) + banks = torch.nn.functional.pad(banks, (0, 1), mode="constant", value=0) + return banks.T + + @staticmethod + def _kaldi_mel_banks_with_zero_bands(num_mel_filters, num_frequency_bins, min_frequency, + max_frequency, sampling_rate, n_fft, mel_cfg, computation_dtype): + """Kaldi-style with bands_to_zero > 0.""" + mel_min = _torch_hertz_to_mel_scalar(min_frequency, mel_scale=mel_cfg.mel_scale) + mel_max = _torch_hertz_to_mel_scalar(max_frequency, mel_scale=mel_cfg.mel_scale) + mel_delta = (mel_max - mel_min) / (num_mel_filters + 1) + 
+        bin_idx = torch.arange(num_mel_filters, dtype=computation_dtype).unsqueeze(1)
+        left_mel = mel_min + bin_idx * mel_delta
+        center_mel = mel_min + (bin_idx + 1.0) * mel_delta
+        right_mel = mel_min + (bin_idx + 2.0) * mel_delta
+
+        fft_bin_width = sampling_rate / n_fft
+        hz_freqs = fft_bin_width * torch.arange(mel_cfg.bands_to_zero, num_frequency_bins, dtype=computation_dtype)
+        mel = _torch_hertz_to_mel(hz_freqs, mel_scale=mel_cfg.mel_scale).unsqueeze(0)
+
+        up_slope = (mel - left_mel) / (center_mel - left_mel)
+        down_slope = (right_mel - mel) / (right_mel - center_mel)
+        zero = torch.zeros(1, dtype=computation_dtype)
+        mel_filters = torch.max(zero, torch.min(up_slope, down_slope)).T
+        if mel_cfg.bands_to_zero > 0:
+            mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, mel_cfg.bands_to_zero, 0))
+        return mel_filters
+
+    @staticmethod
+    def _standard_mel_banks(num_mel_filters, num_frequency_bins, min_frequency,
+                            max_frequency, sampling_rate, n_fft, mel_cfg, computation_dtype):
+        """Standard (non-kaldi) triangular mel filter bank."""
+        mel_min = _torch_hertz_to_mel_scalar(min_frequency, mel_scale=mel_cfg.mel_scale)
+        mel_max = _torch_hertz_to_mel_scalar(max_frequency, mel_scale=mel_cfg.mel_scale)
+        mel_freqs = torch.linspace(mel_min, mel_max, num_mel_filters + 2, dtype=computation_dtype)
+        filter_freqs = _torch_mel_to_hertz(mel_freqs, mel_scale=mel_cfg.mel_scale)
+
+        if mel_cfg.frequency_bin_mode == "rfft":
+            fft_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate)
+        else:
+            fft_freqs = torch.linspace(0, sampling_rate // 2, num_frequency_bins)
+        if computation_dtype is not None:
+            fft_freqs = fft_freqs.to(computation_dtype)
+
+        filter_diff = filter_freqs[1:] - filter_freqs[:-1]
+        slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1)
+        down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+        up_slopes = slopes[:, 2:] / filter_diff[1:]
+        mel_filters = torch.clamp(torch.minimum(down_slopes, up_slopes), min=0)
+
+        if mel_cfg.norm == "slaney":
+            enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
+            mel_filters = mel_filters * enorm.unsqueeze(0)
+
+        if mel_cfg.bands_to_zero > 0:
+            mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, mel_cfg.bands_to_zero, 0))
+        return mel_filters
+
+    def _apply_mel_scale(self, features: "torch.Tensor", *, spectrogram_config: SpectrogramConfig, **kwargs) -> "torch.Tensor":
+        """Apply the mel filterbank to batched spectrogram features (torch backend)."""
         mel_filters = self.mel_filters.to(device=features.device)
         if spectrogram_config.mel_scale_config.matmul_order == "features_first":
             mel_spec = torch.matmul(features.transpose(-2, -1), mel_filters)
         else:
-            # Use F.linear to match torchaudio's MelScale implementation exactly
+            # F.linear matches torchaudio's MelScale implementation exactly
             mel_spec = torch.nn.functional.linear(features.transpose(-2, -1), mel_filters.T).transpose(-2, -1)
         return torch.clamp(mel_spec, min=spectrogram_config.mel_floor)
 
-    def _normalize_magnitude(
-        self,
-        features: "torch.Tensor",
-        *,
-        spectrogram_config: SpectrogramConfig,
-        reference: float = 1.0,
-        min_value: float = 1e-10,
-        db_range: float | None = None,
-        dtype: "torch.dtype | None" = None,
-        **kwargs,
-    ) -> "torch.Tensor":
-        """Apply magnitude normalization (log, log10, or dB scaling) to batched spectrogram features (torch.Tensor only)."""
-        import torch
-
+    def _normalize_magnitude(self, features: "torch.Tensor", *, spectrogram_config: SpectrogramConfig,
+                             reference: float = 1.0, min_value: float = 1e-10, db_range: float | None = None,
+                             dtype: "torch.dtype | None" = None, **kwargs) -> "torch.Tensor":
+        """Apply magnitude normalization (log, log10, or dB scaling) to batched spectrogram features."""
         log_mel = spectrogram_config.log_mode
         mel_floor = spectrogram_config.mel_floor
         power = spectrogram_config.stft_config.power
-
         if dtype is None:
             dtype = torch.float32
         if log_mel is None:
             return features
-
-        # Clamp to mel_floor before taking log
         result = torch.clamp(features, min=mel_floor)
         if log_mel == "log":
@@ -656,7 +690,6 @@ def _normalize_magnitude(
         if db_range is not None:
             if db_range <= 0.0:
                 raise ValueError("db_range must be greater than zero")
-            # Clamp each sample so the minimum value is (max - db_range)
             max_vals = result.amax(dim=-2, keepdim=True) if result.ndim > 2 else result.max()
             result = torch.clamp(result, min=max_vals - db_range)
         result = result.to(dtype)
@@ -667,154 +700,3 @@ def _normalize_magnitude(
         result = result[..., :-1]
 
         return result
-
-    def _mel_filter_bank(self, spectrogram_config: SpectrogramConfig):
-        stft_cfg = spectrogram_config.stft_config
-        mel_cfg = spectrogram_config.mel_scale_config
-        computation_dtype = getattr(torch, mel_cfg.computation_dtype) if mel_cfg.computation_dtype else None
-        num_frequency_bins = 1 + stft_cfg.n_fft // 2
-        num_mel_filters = mel_cfg.n_mels
-        min_frequency = mel_cfg.f_min
-        max_frequency = mel_cfg.f_max if mel_cfg.f_max is not None else self.sample_rate / 2
-        sampling_rate = self.sample_rate
-
-        if mel_cfg.triangularize_in_mel_space and mel_cfg.bands_to_zero == 0:
-            # Kaldi-exact path: matches torchaudio.compliance.kaldi.get_mel_banks.
-            n_fft = (num_frequency_bins - 1) * 2
-            num_fft_bins = n_fft // 2
-            fft_bin_width = sampling_rate / n_fft
-
-            mel_low = 1127.0 * math.log(1.0 + min_frequency / 700.0)
-            mel_high = 1127.0 * math.log(1.0 + max_frequency / 700.0)
-            mel_delta = (mel_high - mel_low) / (num_mel_filters + 1)
-
-            bin_idx = torch.arange(num_mel_filters).unsqueeze(1)
-            left_mel = mel_low + bin_idx * mel_delta
-            center_mel = mel_low + (bin_idx + 1.0) * mel_delta
-            right_mel = mel_low + (bin_idx + 2.0) * mel_delta
-
-            mel = 1127.0 * (1.0 + fft_bin_width * torch.arange(num_fft_bins) / 700.0).log()
-            mel = mel.unsqueeze(0)
-
-            up_slope = (mel - left_mel) / (center_mel - left_mel)
-            down_slope = (right_mel - mel) / (right_mel - center_mel)
-            banks = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
-            banks = torch.nn.functional.pad(banks, (0, 1), mode="constant", value=0)
-
-            mel_filters = banks.T
-        elif mel_cfg.triangularize_in_mel_space:
-            # Kaldi-style with bands_to_zero > 0
-            n_fft = (num_frequency_bins - 1) * 2
-            mel_min = _torch_hertz_to_mel_scalar(min_frequency, mel_scale=mel_cfg.mel_scale)
-            mel_max = _torch_hertz_to_mel_scalar(max_frequency, mel_scale=mel_cfg.mel_scale)
-            mel_delta = (mel_max - mel_min) / (num_mel_filters + 1)
-            bin_idx = torch.arange(num_mel_filters, dtype=computation_dtype).unsqueeze(1)
-            left_mel = mel_min + bin_idx * mel_delta
-            center_mel = mel_min + (bin_idx + 1.0) * mel_delta
-            right_mel = mel_min + (bin_idx + 2.0) * mel_delta
-
-            fft_bin_width = sampling_rate / n_fft
-            hz_freqs = fft_bin_width * torch.arange(mel_cfg.bands_to_zero, num_frequency_bins, dtype=computation_dtype)
-            mel = _torch_hertz_to_mel(hz_freqs, mel_scale=mel_cfg.mel_scale).unsqueeze(0)
-
-            up_slope = (mel - left_mel) / (center_mel - left_mel)
-            down_slope = (right_mel - mel) / (right_mel - center_mel)
-            mel_filters = torch.max(torch.zeros(1, dtype=computation_dtype), torch.min(up_slope, down_slope))
-
-            mel_filters = mel_filters.T
-            if mel_cfg.bands_to_zero > 0:
-                mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, mel_cfg.bands_to_zero, 0))
-        else:
-            n_fft = (num_frequency_bins - 1) * 2
-            mel_min = _torch_hertz_to_mel_scalar(min_frequency, mel_scale=mel_cfg.mel_scale)
-            mel_max = _torch_hertz_to_mel_scalar(max_frequency, mel_scale=mel_cfg.mel_scale)
-            mel_freqs = torch.linspace(mel_min, mel_max, num_mel_filters + 2, dtype=computation_dtype)
-            filter_freqs = _torch_mel_to_hertz(mel_freqs, mel_scale=mel_cfg.mel_scale)
-
-            if mel_cfg.frequency_bin_mode == "rfft":
-                fft_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sampling_rate)
-            else:
-                fft_freqs = torch.linspace(0, sampling_rate // 2, num_frequency_bins)
-            if computation_dtype is not None:
-                fft_freqs = fft_freqs.to(computation_dtype)
-
-            # Triangular filter bank
-            filter_diff = filter_freqs[1:] - filter_freqs[:-1]
-            slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1)
-            down_slopes = -slopes[:, :-2] / filter_diff[:-1]
-            up_slopes = slopes[:, 2:] / filter_diff[1:]
-            mel_filters = torch.clamp(torch.minimum(down_slopes, up_slopes), min=0)
-
-            if mel_cfg.norm == "slaney":
-                enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
-                mel_filters = mel_filters * enorm.unsqueeze(0)
-
-            if mel_cfg.bands_to_zero > 0:
-                mel_filters = torch.nn.functional.pad(mel_filters, (0, 0, mel_cfg.bands_to_zero, 0))
-
-        # When computation_dtype is set only on the mel config (not on the
-        # spectrogram config), the filters were computed in high precision for
-        # accuracy but the spectrogram will be in the default dtype — cast back.
-        if computation_dtype is not None and not spectrogram_config.computation_dtype:
-            mel_filters = mel_filters.to(torch.get_default_dtype())
-        return mel_filters
-
-    def _pad_features(self, features, padding, max_length, truncation, pad_to_multiple_of):
-        padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
-
-        if truncation and max_length is not None:
-            features = [f[:max_length] for f in features]
-
-        actual_lengths = [f.shape[0] for f in features]
-
-        if padding_strategy == PaddingStrategy.LONGEST:
-            max_length = max(actual_lengths)
-            padding_strategy = PaddingStrategy.MAX_LENGTH
-
-        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
-            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
-
-        if padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None:
-            padded = []
-            for f in features:
-                if f.shape[0] < max_length:
-                    pad_amount = max_length - f.shape[0]
-                    # Pad last dim=0 (time axis): F.pad takes innermost dims first
-                    pad_args = [0, 0] * (f.ndim - 1) + [0, pad_amount]
-                    f = torch.nn.functional.pad(f, pad_args, "constant", self.padding_value)
-                padded.append(f)
-            features = padded
-
-        feature_ranges = [(0, length) for length in actual_lengths]
-        return features, feature_ranges
-
-    def _stack_features(self, features):
-        return torch.stack(features)
-
-    def _get_feature_mask(self, feature_ranges, padded_length):
-        mask = torch.zeros((len(feature_ranges), padded_length), dtype=torch.int32)
-        for i, (start, end) in enumerate(feature_ranges):
-            mask[i, start:end] = 1
-        return {"audio_features_mask": mask}
-
-    def _to_batch(self, audio):
-        batch = torch.stack(audio)
-        if self.add_channel_dim:
-            batch = batch.unsqueeze(1)
-        return batch
-
-    def _get_mask(self, audio_ranges, padded_length, do_extract_spectrogram, spectrogram_config):
-        use_audio_mask = self.mask_level == "audio"
-        if do_extract_spectrogram and not use_audio_mask:
-            spec_cfg = spectrogram_config or self.spectrogram_config
-            audio_lengths = torch.tensor([end - start for start, end in audio_ranges])
-            features_lengths = self._get_features_lengths(audio_lengths, spec_cfg)
-            n_features = self._get_features_lengths(padded_length, spec_cfg, include_center_frame=True)
-            mask = (torch.arange(n_features)[None, :] < features_lengths[:, None]).to(torch.int32)
-            return {"audio_features_mask": mask}
-        else:
-            mask = torch.zeros((len(audio_ranges), padded_length), dtype=torch.int32)
-            for i, (start, end) in enumerate(audio_ranges):
-                mask[i, start:end] = 1
-            key = "audio_features_mask" if do_extract_spectrogram else "audio_values_mask"
-            return {key: mask}
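
For reviewers who want to sanity-check the filter-bank math outside the library, below is a minimal standalone sketch of the standard (non-Kaldi) triangular construction used by `_standard_mel_banks`, assuming the Slaney mel scale, `rfft` bin spacing, and Slaney area normalization. The helper names (`hertz_to_mel`, `mel_to_hertz`, `build_mel_filters`) and the default parameters are illustrative only and are not part of this patch.

```python
import math

import torch

# Slaney mel scale: linear below 1 kHz, log-spaced above (illustrative sketch, not library code)
_LOGSTEP = math.log(6.4) / 27.0


def hertz_to_mel(freq: torch.Tensor) -> torch.Tensor:
    linear = 3.0 * freq / 200.0
    # torch.where evaluates both branches, so clamp avoids log(0) for the DC bin
    log_part = 15.0 + torch.log(torch.clamp(freq, min=1e-10) / 1000.0) / _LOGSTEP
    return torch.where(freq >= 1000.0, log_part, linear)


def mel_to_hertz(mels: torch.Tensor) -> torch.Tensor:
    linear = 200.0 * mels / 3.0
    log_part = 1000.0 * torch.exp(_LOGSTEP * (mels - 15.0))
    return torch.where(mels >= 15.0, log_part, linear)


def build_mel_filters(n_fft: int = 400, n_mels: int = 80, f_min: float = 0.0,
                      f_max: float = 8000.0, sample_rate: int = 16000) -> torch.Tensor:
    # Centre frequencies of the one-sided FFT bins (n_fft // 2 + 1 of them)
    fft_freqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sample_rate)

    # n_mels + 2 band edges, evenly spaced on the mel scale, mapped back to Hz
    mel_min = hertz_to_mel(torch.tensor(f_min)).item()
    mel_max = hertz_to_mel(torch.tensor(f_max)).item()
    filter_freqs = mel_to_hertz(torch.linspace(mel_min, mel_max, n_mels + 2))

    # Each filter is a triangle: rising slope up to its centre edge, falling slope after it
    filter_diff = filter_freqs[1:] - filter_freqs[:-1]
    slopes = filter_freqs.unsqueeze(0) - fft_freqs.unsqueeze(1)
    down_slopes = -slopes[:, :-2] / filter_diff[:-1]
    up_slopes = slopes[:, 2:] / filter_diff[1:]
    filters = torch.clamp(torch.minimum(down_slopes, up_slopes), min=0)

    # Slaney area normalization: scale each triangle by 2 / (bandwidth in Hz)
    enorm = 2.0 / (filter_freqs[2:] - filter_freqs[:-2])
    return filters * enorm.unsqueeze(0)


if __name__ == "__main__":
    filters = build_mel_filters()
    print(filters.shape)  # torch.Size([201, 80]) -> (num_frequency_bins, n_mels)
```

The two Kaldi paths above differ mainly in that the triangles are built directly in mel space (and, for the exact path, with the HTK-style constant 1127 * ln(1 + f / 700)) rather than being mapped back to Hz before the slopes are computed.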