diff --git a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py index 9eca494dbd79..311d94a9cc9c 100644 --- a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py @@ -63,6 +63,14 @@ class AudioFlamingo3EncoderConfig(PretrainedConfig): Scale embeddings by dividing by sqrt(hidden_size). max_source_positions (`int`, *optional*, defaults to 1500): The maximum sequence length of log-mel filter-bank features that this model might ever be used with. + use_rotary_embedding (`bool`, *optional*, defaults to `False`): + Whether to use rotary time embeddings (RoTE) in the encoder. + rotary_dim (`int`, *optional*, defaults to 256): + Dimension (number of feature channels) of the rotary embeddings. + rotary_freqs_for (`str`, *optional*, defaults to `"lang"`): + Frequency schedule for the rotary embeddings; one of `"lang"`, `"pixel"` or `"constant"`. + rotary_max_time (`float`, *optional*, defaults to 1200.0): + Maximum time (in seconds) for rotary embedding scaling; timestamps are mapped onto the `[0, 2*pi)` phase range over this span. Example: @@ -104,6 +112,10 @@ def __init__( initializer_range=0.02, scale_embedding=False, max_source_positions=1500, + use_rotary_embedding=False, + rotary_dim=256, + rotary_freqs_for="lang", + rotary_max_time=1200.0, **kwargs, ): super().__init__(**kwargs) @@ -122,6 +134,10 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.scale_embedding = scale_embedding self.max_source_positions = max_source_positions + self.use_rotary_embedding = use_rotary_embedding + self.rotary_dim = rotary_dim + self.rotary_freqs_for = rotary_freqs_for + self.rotary_max_time = rotary_max_time class AudioFlamingo3Config(PretrainedConfig): diff --git a/src/transformers/models/audioflamingo3/convert_musicflamingo_to_hf.py b/src/transformers/models/audioflamingo3/convert_musicflamingo_to_hf.py new file mode 100644 index 000000000000..68ca2910ada2 --- /dev/null +++ b/src/transformers/models/audioflamingo3/convert_musicflamingo_to_hf.py @@ -0,0 +1,307 @@ +# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
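For orientation, a minimal sketch of how the four configuration fields added above are set, using the same values the conversion script below passes for MusicFlamingo; nothing here is new API beyond those fields:

```python
from transformers import AudioFlamingo3EncoderConfig

# RoTE-enabled encoder config; 1200.0 s matches the processor's
# max_audio_len of 1200 seconds used elsewhere in this PR.
encoder_config = AudioFlamingo3EncoderConfig(
    use_rotary_embedding=True,
    rotary_dim=256,
    rotary_freqs_for="lang",
    rotary_max_time=1200.0,
)
print(encoder_config.use_rotary_embedding)  # True
```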
+ +"""Convert MusicFlamingo checkpoints into a Hugging Face repository layout.""" + +from __future__ import annotations + +import argparse +import json +import logging +from collections import defaultdict +from pathlib import Path +from typing import Any + +import torch +from safetensors.torch import safe_open + +from transformers import ( + AudioFlamingo3Config, + AudioFlamingo3EncoderConfig, + AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Processor, + AutoTokenizer, + GenerationConfig, + Qwen2Config, + WhisperFeatureExtractor, +) + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + +def _load_json(p: Path): + if not p.is_file(): + raise FileNotFoundError(f"Missing JSON: {p}") + with p.open("r", encoding="utf-8") as f: + return json.load(f) + + +def write_processor(src_root: Path, dst_root: Path): + llm_dir = src_root / "llm" + + system_prompt = ( + "You are Music Flamingo, a multimodal assistant for language and music. " + "On each turn you receive an audio clip which contains music and optional text, " + "you will receive at least one or both; use your world knowledge and reasoning " + "to help the user with any task. Interpret the entirety of the content any input music" + "--regardlenss of whether the user calls it audio, music, or sound." + ) + + # fmt: off + tokenizer_chat_template = ( + "{% if messages[0]['role'] != 'system' %}" + "{{ '<|im_start|>system\\n" + system_prompt + "<|im_end|>\\n' }}" + "{% endif %}" + "{% for message in messages if message['content'] is not none %}" + "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\\n' }}" + "{% endif %}" + ) + # fmt: on + + # fmt: off + processor_chat_template = ( + "{% if messages[0]['role'] != 'system' %}" + "<|im_start|>system\n" + system_prompt + "<|im_end|>\n" + "{% endif %}" + "{% for m in messages if m['content'] is not none %}" + "<|im_start|>{{ m['role'] }}\n" + "{% if m['content'] is string %}" + "{{ m['content'] }}" + "{% else %}" + "{% set audio = namespace(found=False) %}" + "{% set text_buf = namespace(v='') %}" + "{% for c in m['content'] %}" + "{% if c.get('type') == 'audio' or 'audio' in c %}" + "{% set audio.found = True %}" + "{% elif c.get('type') == 'text' or 'text' in c %}" + "{% set text_buf.v = text_buf.v + c['text'] %}" + "{% endif %}" + "{% endfor %}" + "{% if audio.found %}{{ '' }}{% endif %}{{ text_buf.v }}" + "{% endif %}" + "<|im_end|>\n" + "{% endfor %}" + "{% if add_generation_prompt %}" + "<|im_start|>assistant\n" + "{% endif %}" + ) + # fmt: on + + processor = AudioFlamingo3Processor( + feature_extractor=WhisperFeatureExtractor(feature_size=128, return_attention_mask=True), + tokenizer=AutoTokenizer.from_pretrained(str(llm_dir), chat_template=tokenizer_chat_template, use_fast=True), + chat_template=processor_chat_template, + max_audio_len=1200, + audio_bos_token="<|sound_bos|>", + audio_eos_token="<|sound_eos|>", + ) + processor.save_pretrained(str(dst_root)) + + logger.info("processor (tokenizer + preprocessor)") + return processor + + +PREFIX_MAP = { + "llm": "language_model", + "sound_tower": "audio_tower", + "sound_mm_projector": "multi_modal_projector", +} + + +def _resolve_component_dir(dirpath: Path): + if not dirpath.is_dir(): + return None + idx = dirpath / "model.safetensors.index.json" + mono = dirpath / "model.safetensors" + if idx.exists(): + wm = _load_json(idx).get("weight_map") or {} 
+ by_shard: dict[str, list[str]] = defaultdict(list) + for k, shard in wm.items(): + by_shard[shard].append(k) + return ("sharded", dirpath, {k: sorted(v) for k, v in sorted(by_shard.items())}) + if mono.exists(): + return ("file", mono) + cands = sorted([x for x in dirpath.iterdir() if x.suffix == ".safetensors"]) + return ("file", cands[0]) if len(cands) == 1 else None + + +def merge_and_shard_weights(src_root: Path, dst_root: Path, processor: AudioFlamingo3Processor): + state: dict[str, Any] = {} + for tag in PREFIX_MAP.keys(): + comp = _resolve_component_dir(src_root / tag) + if not comp: + continue + + out_prefix = PREFIX_MAP.get(tag, tag) + + if comp[0] == "file": + fp: Path = comp[1] + with safe_open(str(fp), framework="pt", device="cpu") as f: + for k in f.keys(): + if k == "__metadata__": + continue + state[f"{out_prefix}.{k}"] = f.get_tensor(k) + else: + base: Path = comp[1] + shard_map: dict[str, list[str]] = comp[2] + for shard, keys in shard_map.items(): + sp = base / shard + with safe_open(str(sp), framework="pt", device="cpu") as f: + for k in keys: + state[f"{out_prefix}.{k}"] = f.get_tensor(k) + + if not state: + raise FileNotFoundError("No tensors found in llm/, sound_tower/, or sound_mm_projector/.") + + tok = processor.tokenizer + + text_config = Qwen2Config( + bos_token_id=tok.bos_token_id, + eos_token_id=tok.eos_token_id, + pad_token_id=tok.pad_token_id, + vocab_size=len(tok), + hidden_size=3584, + intermediate_size=18944, + model_max_length=8192, + num_attention_heads=28, + num_hidden_layers=28, + num_key_value_heads=4, + rope_theta=1000000.0, + use_cache=False, + ) + + audio_encoder_config = AudioFlamingo3EncoderConfig( + use_rotary_embedding=True, + rotary_max_time=1200.0, + rotary_freqs_for="lang", + ) + config = AudioFlamingo3Config( + text_config=text_config, + audio_config=audio_encoder_config, + audio_token_id=tok.get_vocab()[""], + ) + model = AudioFlamingo3ForConditionalGeneration(config).to(dtype=torch.bfloat16) + + projector_key_mapping = { + "multi_modal_projector.layers.0.weight": "multi_modal_projector.linear_1.weight", + "multi_modal_projector.layers.0.bias": "multi_modal_projector.linear_1.bias", + "multi_modal_projector.layers.2.weight": "multi_modal_projector.linear_2.weight", + "multi_modal_projector.layers.2.bias": "multi_modal_projector.linear_2.bias", + "audio_tower.sound_tower.pos_emb.freqs": "audio_tower.pos_emb.freqs", + } + for old_key, new_key in projector_key_mapping.items(): + if old_key in state: + state[new_key] = state.pop(old_key) + + # Load weights into the instantiated model so we can push via `push_to_hub` later. + load_res = model.load_state_dict(state, strict=True) + # Enforce a clean load + if getattr(load_res, "missing_keys", None) and load_res.missing_keys: + mk = load_res.missing_keys + raise ValueError(f"Missing keys when loading: {mk[:10]}{' ...' if len(mk) > 10 else ''}") + if getattr(load_res, "unexpected_keys", None) and load_res.unexpected_keys: + uk = load_res.unexpected_keys + raise ValueError(f"Unexpected keys when loading: {uk[:10]}{' ...' 
if len(uk) > 10 else ''}") + + generation_config = GenerationConfig( + bos_token_id=tok.bos_token_id, + eos_token_id=tok.eos_token_id, + pad_token_id=tok.pad_token_id, + max_new_tokens=2048, + ) + model.generation_config = generation_config + + model.save_pretrained(save_directory=str(dst_root)) + logger.info("Saved model.safetensors index and shards") + return model + + +""" +Reproducible Usage +================== + +1) Download the original MusicFlamingo weights from NVIDIA (requires Git LFS): + +``` +git lfs install +git clone https://huggingface.co/nvidia/music-flamingo +``` + +This will create a folder `music-flamingo/` containing the original components: +`llm/`, `sound_tower/`, and `sound_mm_projector/`. + +2) Convert to the Hugging Face Transformers format (locally): + +``` +python src/transformers/models/audioflamingo3/convert_musicflamingo_to_hf.py \ + --src_dir music-flamingo \ + --dst_dir music-flamingo-hf +``` + +3) Convert and push directly to the Hub (requires `huggingface-cli login` or `HF_TOKEN`): + +``` +python src/transformers/models/audioflamingo3/convert_musicflamingo_to_hf.py \ + --src_dir music-flamingo \ + --dst_dir music-flamingo-hf \ + --push_to_hub username/music-flamingo-hf +``` + +This command uploads both the processor (tokenizer + feature extractor) and the converted +model (sharded safetensors + configs) to the specified Hub repository. +""" + + +def main() -> None: + ap = argparse.ArgumentParser(description="Convert MusicFlamingo to Hugging Face format.") + ap.add_argument("--src_dir", required=True, help="Source model root directory") + ap.add_argument("--dst_dir", required=True, help="Destination directory for converted model") + ap.add_argument( + "--push_to_hub", + default=None, + type=str, + help=( + "Optional repository ID to push the converted assets to the Hugging Face Hub, " + "e.g. 'username/music-flamingo-hf'." + ), + ) + args = ap.parse_args() + + src_root = Path(args.src_dir).resolve() + if not src_root.is_dir(): + raise FileNotFoundError(f"Source directory not found: {src_root}") + + dst_root = Path(args.dst_dir).resolve() + if dst_root.exists(): + raise FileExistsError(f"Destination already exists: {dst_root}") + + processor = write_processor(src_root, dst_root) + model = merge_and_shard_weights(src_root, dst_root, processor) + + # Optionally push converted assets using native push_to_hub only + if args.push_to_hub: + logger.info("Pushing processor to the Hub ...") + processor.push_to_hub(args.push_to_hub) + logger.info("Pushing model to the Hub ...") + model.push_to_hub(args.push_to_hub) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index f88a19796f34..00ea4641d5e7 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -19,11 +19,15 @@ # See the License for the specific language governing permissions and # limitations under the License.
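Before the modeling changes: once conversion has run, the destination directory is a regular Transformers checkpoint and can be loaded back through the standard `from_pretrained` path. A hedged sketch (`music-flamingo-hf` is the example `--dst_dir` from the usage notes above, not a published repo id):

```python
import torch

from transformers import AudioFlamingo3ForConditionalGeneration, AudioFlamingo3Processor

# Load the locally converted checkpoint produced by the script above.
processor = AudioFlamingo3Processor.from_pretrained("music-flamingo-hf")
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
    "music-flamingo-hf", torch_dtype=torch.bfloat16
)
print(model.generation_config.max_new_tokens)  # 2048, set during conversion
```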
+ import math from collections.abc import Callable +from math import pi import torch -from torch import nn +from torch import Tensor, broadcast_tensors, einsum, nn +from torch.amp import autocast +from torch.nn import Module from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache @@ -43,6 +47,246 @@ logger = logging.get_logger(__name__) +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def rotate_half(x): + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +@autocast("cuda", enabled=False) +def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): + ori_dtype = t.dtype + embed_dtype = torch.float64 + t = t.to(embed_dtype) + if t.ndim == 3: + seq_len = t.shape[seq_dim] + if freqs.ndim == 2: + freqs = freqs[-seq_len:].to(t) + else: + freqs = freqs.to(t) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], ( + f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + ) + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + return torch.cat((t_left, t, t_right), dim=-1).to(ori_dtype) + + +class RotaryEmbedding(Module): + def __init__( + self, + dim, + custom_freqs: Tensor | None = None, + freqs_for="lang", + theta=50000, + max_freq=10, + num_freqs=1, + learned_freq=False, + use_xpos=False, + xpos_scale_base=512, + interpolate_factor=1.0, + theta_rescale_factor=1.0, + seq_before_head_dim=False, + cache_if_possible=True, + max_time=7200, + ): + super().__init__() + + self.dim = dim + self.freqs_for = freqs_for + self.max_freq = max_freq + self.num_freqs = num_freqs + self.learned_freq = learned_freq + self.use_xpos = use_xpos + self.xpos_scale_base = xpos_scale_base + self.interpolate_factor = interpolate_factor + self.theta_rescale_factor = theta_rescale_factor + self.cache_if_possible = cache_if_possible + self.max_time = max_time + + self.tmp_store("cached_freqs", None) + self.tmp_store("cached_scales", None) + + if exists(max_time) and freqs_for == "lang": + theta = max_time / (2 * pi) + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.theta = theta + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + + self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) + + self.learned_freq = learned_freq + + self.tmp_store("dummy", torch.tensor(0)) + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + assert interpolate_factor >= 1.0 + self.interpolate_factor = interpolate_factor + + if not use_xpos: + self.tmp_store("scale", None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store("scale", scale) + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent=False) + + def get_seq_pos(self, seq_len, device, dtype, offset=0): + return (torch.arange(seq_len, 
device=device, dtype=dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim=None, offset=0): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos, ( + "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings" + ) + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + freqs = self.forward( + self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset), seq_len=seq_len, offset=offset + ) + + if seq_dim == -3: + freqs = freqs.unsqueeze(1) + + return apply_rotary_emb(freqs, t, seq_dim=seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0): + seq_dim = default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + + rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset) + rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim=None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype=dtype, device=device) + + freqs = self.forward(seq, seq_len=seq_len) + scale = self.get_scale(seq, seq_len=seq_len).to(dtype) + + if seq_dim == -3: + freqs = freqs.unsqueeze(1) + scale = scale.unsqueeze(1) + + rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0): + assert self.use_xpos + + should_cache = self.cache_if_possible and exists(seq_len) + + if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]: + return self.cached_scales[offset : (offset + seq_len)] + + scale = 1.0 + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** power.unsqueeze(-1) + scale = torch.cat((scale, scale), dim=-1) + + if should_cache: + self.tmp_store("cached_scales", scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == "pixel": + pos = torch.linspace(-1, 1, steps=dim, device=self.device) + else: + pos = torch.arange(dim, device=self.device) + + freqs = self.forward(pos, seq_len=dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim=-1) + + @autocast("cuda", enabled=False) + def forward(self, t: Tensor, seq_len=None, offset=0): + should_cache = ( + self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel" + ) + + if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]: + return self.cached_freqs[offset : (offset + seq_len)].detach() + + freqs = self.freqs + + if hasattr(self, "max_time") and self.max_time is not None: + t = t / self.max_time * (2 * pi) + + freqs = einsum("..., f -> ... 
f", t.type(freqs.dtype), freqs) + freqs = torch.repeat_interleave(freqs, 2, dim=-1) + + if should_cache: + self.tmp_store("cached_freqs", freqs.detach()) + + return freqs + + def eager_attention_forward( module: nn.Module, query: torch.Tensor, @@ -263,6 +507,48 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights for AudioFlamingo3-specific modules.""" + if isinstance(module, RotaryEmbedding): + # Reinitialize freqs parameter + dim = module.dim + freqs_for = module.freqs_for + max_time = module.max_time + theta_rescale_factor = module.theta_rescale_factor + custom_freqs = None + + # Adjust theta + if max_time is not None and freqs_for == "lang": + theta = max_time / (2 * pi) + else: + theta = 50000 # default value + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + # Generate freqs + if custom_freqs is not None: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, module.max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(module.num_freqs).float() + + module.freqs.data = freqs + + # Reinitialize dummy buffer + module.dummy.data = torch.tensor(0) + + # Reinitialize scale if using xpos + if module.use_xpos and module.scale is not None: + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + module.scale.data = scale + else: + # Delegate to parent class for other modules + super()._init_weights(module) + @auto_docstring( custom_intro=""" @@ -285,7 +571,7 @@ class AudioFlamingo3Encoder(AudioFlamingo3PreTrainedModel): "attentions": AudioFlamingo3Attention, } - def __init__(self, config: AudioFlamingo3EncoderConfig): + def __init__(self, config: AudioFlamingo3Config): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop @@ -307,6 +593,12 @@ def __init__(self, config: AudioFlamingo3EncoderConfig): self.avg_pooler = nn.AvgPool1d(2, stride=2) self.gradient_checkpointing = False + if getattr(config, "use_rotary_embedding", False): + self.pos_emb = RotaryEmbedding( + dim=config.rotary_dim, + freqs_for=config.rotary_freqs_for, + max_time=config.rotary_max_time, + ) # Initialize weights and apply final processing self.post_init() @@ -326,6 +618,7 @@ def forward( self, input_features: torch.Tensor, input_features_mask: torch.Tensor | None = None, + audio_times: torch.Tensor | None = None, **kwargs, ) -> tuple | BaseModelOutputWithPooling: r""" @@ -338,6 +631,8 @@ def forward( - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + audio_times (`torch.FloatTensor` of shape `(batch_size,)`, *optional*): + The start time of the audio segments in seconds. Only used if rotary embeddings are enabled. 
""" seq_len = (input_features.shape[-1] - 1) // 2 + 1 # After conv2 downsampling @@ -345,6 +640,9 @@ def forward( input_features_lengths = (input_features_lengths - 1) // 2 + 1 # conv2 downsampling input_features_mask = torch.arange(seq_len, device=input_features.device) < input_features_lengths[:, None] + # Cast to model dtype + input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) + # Conv front-end inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) @@ -371,6 +669,19 @@ def forward( hidden_states = self.avg_pooler(hidden_states).permute(0, 2, 1) hidden_states = self.layer_norm(hidden_states) + if ( + hasattr(self.config, "use_rotary_embedding") + and self.config.use_rotary_embedding + and audio_times is not None + ): + times = audio_times.to(hidden_states.device) + freqs = self.pos_emb.get_axial_freqs(times.shape[0], hidden_states.shape[-2]).to(self.conv1.weight.device) + angle = (-times * 2 * pi).to(self.conv1.weight.device) + angle_expanded = angle.unsqueeze(2).expand(times.shape[0], hidden_states.shape[-2], freqs.shape[-1]) + freqs = freqs * angle_expanded + + hidden_states = apply_rotary_emb(freqs, hidden_states) + return BaseModelOutputWithPooling( last_hidden_state=hidden_states, ) @@ -454,6 +765,7 @@ def get_audio_features( self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor, + audio_times: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -465,11 +777,17 @@ def get_audio_features( and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. + audio_times (`torch.Tensor` of shape `(batch_size,)`, *optional*): + The start time of the audio segments in seconds. """ # Encode audio audio_output = self.audio_tower( - input_features, input_features_mask=input_features_mask, return_dict=True, **kwargs + input_features, + input_features_mask=input_features_mask, + audio_times=audio_times, + return_dict=True, + **kwargs, ) audio_embeds = self.multi_modal_projector(audio_output.last_hidden_state) @@ -488,6 +806,7 @@ def forward( input_ids: torch.LongTensor | None = None, input_features: torch.FloatTensor | None = None, input_features_mask: torch.Tensor | None = None, + audio_times: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, @@ -504,6 +823,8 @@ def forward( - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + audio_times (`torch.Tensor` of shape `(batch_size,)`, *optional*): + The start time of the audio segments in seconds. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored @@ -565,7 +886,9 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + audio_embeds = self.get_audio_features( + input_features, input_features_mask, audio_times=audio_times, return_dict=True + ).pooler_output # replace text-audio token placeholders with audio embeddings audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) @@ -591,6 +914,7 @@ def prepare_inputs_for_generation(self, *args, **kwargs): input_features = kwargs.pop("input_features", None) input_features_mask = kwargs.pop("input_features_mask", None) + audio_times = kwargs.pop("audio_times", None) cache_position = kwargs.get("cache_position") model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) @@ -601,6 +925,8 @@ def prepare_inputs_for_generation(self, *args, **kwargs): model_inputs["input_features"] = input_features if input_features_mask is not None: model_inputs["input_features_mask"] = input_features_mask + if audio_times is not None: + model_inputs["audio_times"] = audio_times return model_inputs diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index b846957940cc..6241939c3019 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -13,8 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. + +from math import pi + import torch -from torch import nn +from torch import Tensor, broadcast_tensors, einsum, nn +from torch.amp import autocast +from torch.nn import Module from ...activations import ACT2FN from ...cache_utils import Cache @@ -35,6 +40,261 @@ logger = logging.get_logger(__name__) +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def broadcat(tensors, dim=-1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim=dim) + + +def rotate_half(x): + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +@autocast("cuda", enabled=False) +def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2): + ori_dtype = t.dtype + embed_dtype = torch.float64 + t = t.to(embed_dtype) + if t.ndim == 3: + seq_len = t.shape[seq_dim] + if freqs.ndim == 2: + freqs = freqs[-seq_len:].to(t) + else: + freqs = freqs.to(t) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[-1], ( + f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + ) + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + return torch.cat((t_left, t, t_right), dim=-1).to(ori_dtype) + + +# learned rotation helpers +def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None): + if exists(freq_ranges): + rotations = einsum("..., f -> ... 
f", rotations, freq_ranges) + rotations = rotations.flatten(-2) + + rotations = torch.repeat_interleave(rotations, 2, dim=-1) + return apply_rotary_emb(rotations, t, start_index=start_index) + + +class RotaryEmbedding(Module): + def __init__( + self, + dim, + custom_freqs: Tensor | None = None, + freqs_for="lang", + theta=50000, + max_freq=10, + num_freqs=1, + learned_freq=False, + use_xpos=False, + xpos_scale_base=512, + interpolate_factor=1.0, + theta_rescale_factor=1.0, + seq_before_head_dim=False, + cache_if_possible=True, + max_time=7200, + ): + super().__init__() + + self.dim = dim + self.freqs_for = freqs_for + self.max_freq = max_freq + self.num_freqs = num_freqs + self.learned_freq = learned_freq + self.use_xpos = use_xpos + self.xpos_scale_base = xpos_scale_base + self.interpolate_factor = interpolate_factor + self.theta_rescale_factor = theta_rescale_factor + self.cache_if_possible = cache_if_possible + self.max_time = max_time + + self.tmp_store("cached_freqs", None) + self.tmp_store("cached_scales", None) + + if exists(max_time) and freqs_for == "lang": + theta = max_time / (2 * pi) + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.theta = theta + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + + self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) + + self.learned_freq = learned_freq + + self.tmp_store("dummy", torch.tensor(0)) + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + assert interpolate_factor >= 1.0 + self.interpolate_factor = interpolate_factor + + if not use_xpos: + self.tmp_store("scale", None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store("scale", scale) + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent=False) + + def get_seq_pos(self, seq_len, device, dtype, offset=0): + return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim=None, offset=0): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos, ( + "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings" + ) + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + freqs = self.forward( + self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset), seq_len=seq_len, offset=offset + ) + + if seq_dim == -3: + freqs = freqs.unsqueeze(1) + + return apply_rotary_emb(freqs, t, seq_dim=seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0): + seq_dim = default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + + rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=k_len - q_len + offset) + rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, offset=offset) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, 
seq_dim=None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype=dtype, device=device) + + freqs = self.forward(seq, seq_len=seq_len) + scale = self.get_scale(seq, seq_len=seq_len).to(dtype) + + if seq_dim == -3: + freqs = freqs.unsqueeze(1) + scale = scale.unsqueeze(1) + + rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0): + assert self.use_xpos + + should_cache = self.cache_if_possible and exists(seq_len) + + if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]: + return self.cached_scales[offset : (offset + seq_len)] + + scale = 1.0 + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale ** power.unsqueeze(-1) + scale = torch.cat((scale, scale), dim=-1) + + if should_cache: + self.tmp_store("cached_scales", scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == "pixel": + pos = torch.linspace(-1, 1, steps=dim, device=self.device) + else: + pos = torch.arange(dim, device=self.device) + + freqs = self.forward(pos, seq_len=dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim=-1) + + @autocast("cuda", enabled=False) + def forward(self, t: Tensor, seq_len=None, offset=0): + should_cache = ( + self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel" + ) + + if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs.shape[0]: + return self.cached_freqs[offset : (offset + seq_len)].detach() + + freqs = self.freqs + + if hasattr(self, "max_time") and self.max_time is not None: + t = t / self.max_time * (2 * pi) + + freqs = einsum("..., f -> ... 
f", t.type(freqs.dtype), freqs) + freqs = torch.repeat_interleave(freqs, 2, dim=-1) + + if should_cache: + self.tmp_store("cached_freqs", freqs.detach()) + + return freqs + + class AudioFlamingo3Attention(WhisperAttention): pass @@ -44,7 +304,47 @@ class AudioFlamingo3EncoderLayer(WhisperEncoderLayer): class AudioFlamingo3PreTrainedModel(Qwen2AudioPreTrainedModel): - pass + @torch.no_grad() + def _init_weights(self, module): + """Initialize the weights for AudioFlamingo3-specific modules.""" + if isinstance(module, RotaryEmbedding): + # Reinitialize freqs parameter + dim = module.dim + freqs_for = module.freqs_for + max_time = module.max_time + theta_rescale_factor = module.theta_rescale_factor + custom_freqs = None + + # Adjust theta + if max_time is not None and freqs_for == "lang": + theta = max_time / (2 * pi) + else: + theta = 50000 # default value + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + # Generate freqs + if custom_freqs is not None: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, module.max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(module.num_freqs).float() + + module.freqs.data = freqs + + # Reinitialize dummy buffer + module.dummy.data = torch.tensor(0) + + # Reinitialize scale if using xpos + if module.use_xpos and module.scale is not None: + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + module.scale.data = scale + else: + # Delegate to parent class for other modules + super()._init_weights(module) @auto_docstring( @@ -62,11 +362,21 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder): "attentions": AudioFlamingo3Attention, } + def __init__(self, config: AudioFlamingo3Config): + super().__init__(config) + if getattr(config, "use_rotary_embedding", False): + self.pos_emb = RotaryEmbedding( + dim=config.rotary_dim, + freqs_for=config.rotary_freqs_for, + max_time=config.rotary_max_time, + ) + @check_model_inputs def forward( self, input_features: torch.Tensor, input_features_mask: torch.Tensor | None = None, + audio_times: torch.Tensor | None = None, **kwargs, ) -> tuple | BaseModelOutputWithPooling: r""" @@ -79,6 +389,8 @@ def forward( - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + audio_times (`torch.FloatTensor` of shape `(batch_size,)`, *optional*): + The start time of the audio segments in seconds. Only used if rotary embeddings are enabled. 
""" seq_len = (input_features.shape[-1] - 1) // 2 + 1 # After conv2 downsampling @@ -86,6 +398,9 @@ def forward( input_features_lengths = (input_features_lengths - 1) // 2 + 1 # conv2 downsampling input_features_mask = torch.arange(seq_len, device=input_features.device) < input_features_lengths[:, None] + # Cast to model dtype + input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) + # Conv front-end inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) @@ -112,6 +427,19 @@ def forward( hidden_states = self.avg_pooler(hidden_states).permute(0, 2, 1) hidden_states = self.layer_norm(hidden_states) + if ( + hasattr(self.config, "use_rotary_embedding") + and self.config.use_rotary_embedding + and audio_times is not None + ): + times = audio_times.to(hidden_states.device) + freqs = self.pos_emb.get_axial_freqs(times.shape[0], hidden_states.shape[-2]).to(self.conv1.weight.device) + angle = (-times * 2 * pi).to(self.conv1.weight.device) + angle_expanded = angle.unsqueeze(2).expand(times.shape[0], hidden_states.shape[-2], freqs.shape[-1]) + freqs = freqs * angle_expanded + + hidden_states = apply_rotary_emb(freqs, hidden_states) + return BaseModelOutputWithPooling( last_hidden_state=hidden_states, ) @@ -155,6 +483,7 @@ def get_audio_features( self, input_features: torch.FloatTensor, input_features_mask: torch.Tensor, + audio_times: torch.Tensor | None = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPooling: r""" @@ -166,11 +495,17 @@ def get_audio_features( and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. + audio_times (`torch.Tensor` of shape `(batch_size,)`, *optional*): + The start time of the audio segments in seconds. """ # Encode audio audio_output = self.audio_tower( - input_features, input_features_mask=input_features_mask, return_dict=True, **kwargs + input_features, + input_features_mask=input_features_mask, + audio_times=audio_times, + return_dict=True, + **kwargs, ) audio_embeds = self.multi_modal_projector(audio_output.last_hidden_state) @@ -189,6 +524,7 @@ def forward( input_ids: torch.LongTensor | None = None, input_features: torch.FloatTensor | None = None, input_features_mask: torch.Tensor | None = None, + audio_times: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, @@ -205,6 +541,8 @@ def forward( - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. + audio_times (`torch.Tensor` of shape `(batch_size,)`, *optional*): + The start time of the audio segments in seconds. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored @@ -266,7 +604,9 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + audio_embeds = self.get_audio_features( + input_features, input_features_mask, audio_times=audio_times, return_dict=True + ).pooler_output # replace text-audio token placeholders with audio embeddings audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) @@ -292,6 +632,7 @@ def prepare_inputs_for_generation(self, *args, **kwargs): input_features = kwargs.pop("input_features", None) input_features_mask = kwargs.pop("input_features_mask", None) + audio_times = kwargs.pop("audio_times", None) cache_position = kwargs.get("cache_position") model_inputs = super().prepare_inputs_for_generation(*args, **kwargs) @@ -302,6 +643,8 @@ def prepare_inputs_for_generation(self, *args, **kwargs): model_inputs["input_features"] = input_features if input_features_mask is not None: model_inputs["input_features_mask"] = input_features_mask + if audio_times is not None: + model_inputs["audio_times"] = audio_times return model_inputs diff --git a/src/transformers/models/audioflamingo3/processing_audioflamingo3.py b/src/transformers/models/audioflamingo3/processing_audioflamingo3.py index 0fbac0791726..63a6e3fc838a 100644 --- a/src/transformers/models/audioflamingo3/processing_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/processing_audioflamingo3.py @@ -81,9 +81,15 @@ def __init__( audio_token="", default_transcription_prompt="Transcribe the input speech.", max_audio_len=600, + audio_bos_token=None, + audio_eos_token=None, ): self.audio_token = audio_token + self.audio_bos_token = audio_bos_token + self.audio_eos_token = audio_eos_token self.audio_token_id = tokenizer.convert_tokens_to_ids(audio_token) + self.audio_bos_token_id = tokenizer.convert_tokens_to_ids(audio_bos_token) if audio_bos_token else None + self.audio_eos_token_id = tokenizer.convert_tokens_to_ids(audio_eos_token) if audio_eos_token else None self.default_transcription_prompt = default_transcription_prompt self.max_audio_len = max_audio_len super().__init__(feature_extractor, tokenizer, chat_template=chat_template) @@ -151,6 +157,7 @@ def __call__( per_sample_windows: list[int] = [] flat_chunks: list[np.ndarray] = [] + audio_times_list: list[torch.Tensor] = [] for audio_el in audio: n_samples = int(audio_el.shape[0]) @@ -167,19 +174,33 @@ def __call__( start = i * window_size end = min((i + 1) * window_size, time_cap) flat_chunks.append(audio_el[start:end]) + # Calculate the start time of this audio chunk in seconds + start_sec = start / audio_kwargs["sampling_rate"] + + # Generate 750 timestamps at 40ms intervals (30s / 750 = 0.04s) + if is_torch_available(): + chunk_times = torch.arange(750).float() * 0.04 + start_sec + audio_times_list.append(chunk_times) # Feature extraction audio_inputs = self.feature_extractor(flat_chunks, **audio_kwargs) padding_mask = audio_inputs.pop("attention_mask") audio_inputs["input_features_mask"] = padding_mask + # Add audio times as tensor + if return_tensors == "pt" and audio_times_list: + audio_inputs["audio_times"] = torch.stack(audio_times_list).to(dtype=torch.float32) + # Compute sequence lengths token counting audio_lengths = torch.stack([s.sum() for s in torch.split(padding_mask.sum(-1), per_sample_windows)]) audio_tokens_lengths = 
self._get_audio_token_length(audio_lengths) # expand audio tokens in text for i, audio_length in enumerate(audio_tokens_lengths): - expanded = re.sub(re.escape(self.audio_token), self.audio_token * audio_length, text[i]) + replacement = self.audio_token * audio_length + if self.audio_bos_token is not None and self.audio_eos_token is not None: + replacement = self.audio_bos_token + replacement + self.audio_eos_token + expanded = re.sub(re.escape(self.audio_token), replacement, text[i]) text[i] = expanded # Tokenize @@ -189,6 +210,10 @@ def __call__( if output_labels: labels = data["input_ids"].clone() labels[labels == self.audio_token_id] = -100 + if self.audio_bos_token_id is not None: + labels[labels == self.audio_bos_token_id] = -100 + if self.audio_eos_token_id is not None: + labels[labels == self.audio_eos_token_id] = -100 labels[labels == self.tokenizer.pad_token_id] = -100 data["labels"] = labels @@ -198,7 +223,7 @@ def __call__( def model_input_names(self) -> list[str]: tok_names = self.tokenizer.model_input_names fea_names = self.feature_extractor.model_input_names - return list(dict.fromkeys(tok_names + fea_names + ["input_features_mask"])) + return list(dict.fromkeys(tok_names + fea_names + ["input_features_mask", "audio_times"])) def apply_transcription_request( self, diff --git a/tests/fixtures/audioflamingo3/expected_music_results_batched.json b/tests/fixtures/audioflamingo3/expected_music_results_batched.json new file mode 100644 index 000000000000..cace70c6e04b --- /dev/null +++ b/tests/fixtures/audioflamingo3/expected_music_results_batched.json @@ -0,0 +1 @@ +{"token_ids": [[1986, 3754, 374, 458, 44855, 19461, 98875, 378, 107, 14, 378, 107, 35, 681, 55964, 11598, 55564, 429, 57843, 279, 9906, 11, 10581, 52760, 6097, 13450, 315, 20729, 2420, 448, 279, 9842, 11, 6335, 55964, 2307, 27235, 315, 11416, 19461, 98875, 13, 220, 576, 8090, 315, 279, 6573, 374, 220, 16, 21, 18, 13, 20, 24, 6486, 624, 12151, 78, 609, 5309, 1365, 576, 5492, 10797, 518, 264, 74391, 378, 107, 16, 20, 15, 378, 107, 33, 8795, 323, 374, 40876, 304, 378, 107, 36, 378, 107, 36505, 624, 56324, 367, 609, 24039, 1365, 362, 43361, 11, 1550, 55964, 69, 45101, 5670, 14087, 279, 26112, 13, 576, 36290, 16266, 374, 5798, 389, 264, 3040, 55964, 263, 55964, 1782, 55964, 30449, 14346, 23196, 5383, 448, 41854, 10323, 11, 4131, 11144, 4131, 546, 11, 323, 10296, 15588, 55964, 71, 1862, 11, 678, 3108, 55964, 331, 2627, 311, 264, 20380, 88, 42898, 21529, 429, 31676, 38969, 448, 279, 46289, 13, 26410, 11, 796, 10311, 8212, 657, 42898, 11508, 323, 63141, 43221, 278, 35995, 6777, 279, 1887, 10581, 52760, 29677, 11, 1393, 5107, 42898, 357, 3435, 323, 9824, 388, 29100, 6292, 279, 1936, 55964, 8602, 323, 21025, 13, 576, 6514, 374, 6884, 55964, 267, 64853, 11, 448, 279, 42898, 5424, 281, 7295, 311, 1855, 458, 60738, 5112, 20743, 11, 323, 279, 8084, 87761, 65059, 17361, 2090, 323, 31273, 11, 14260, 315, 18706, 15254, 55964, 30449, 13918, 624, 53, 3683, 83984, 1365, 576, 2990, 374, 264, 8778, 752, 45648, 55964, 82, 46288, 5652, 6693, 6792, 20512, 374, 2797, 11, 9906, 11, 323, 10078, 15233, 13, 2932, 27321, 279, 61584, 304, 264, 4240, 11, 10581, 52760, 1707, 11, 23922, 448, 26447, 3233, 55964, 83, 2886, 11, 312, 22328, 11, 323, 7626, 429, 2968, 279, 7743, 264, 73056, 11, 8887, 55964, 2307, 1340, 268, 13, 576, 25407, 1555, 23011, 4065, 55964, 437, 55964, 3057, 304, 279, 6514, 11, 14376, 1526, 279, 27950, 42898, 26112, 624, 43, 10920, 938, 62087, 1365, 576, 23261, 5772, 3948, 2163, 2948, 11, 78322, 11, 323, 279, 86335, 2355, 315, 264, 
27430, 9362, 13, 576, 8622, 11, 58077, 9704, 2293, 9, 2073, 3838, 1035, 847, 1879, 387, 2041, 498, 12390, 9, 2293, 70499, 279, 55810, 323, 71790, 279, 5492, 748, 14269, 6200, 13, 6944, 5128, 1741, 438, 353, 2073, 2610, 13020, 279, 9759, 3403, 847, 6084, 1717, 21461, 854, 9, 323, 353, 2073, 19389, 21059, 88148, 1119, 25678, 24908, 854, 9, 40368, 279, 1042, 1229, 323, 37550, 16232, 624, 33039, 28596, 609, 52611, 1365, 576, 26112, 11017, 264, 11416, 19461, 98875, 52929, 25, 458, 40945, 19706, 429, 63564, 279, 9842, 9382, 11, 8110, 553, 49299, 429, 19131, 279, 25407, 19221, 11, 264, 855, 55964, 6150, 355, 429, 22111, 23504, 448, 16062, 42898, 13617, 11, 323, 264, 68897, 55810, 1380, 279, 9704, 43594, 916, 2480, 55964, 339, 27535, 6782, 16896, 323, 264, 59387, 23196, 5383, 13, 42305, 278, 18303, 3410, 3550, 369, 279, 42898, 11508, 311, 32405, 11, 6388, 1119, 264, 14164, 429, 29922, 1182, 311, 264, 803, 31387, 10434, 1573, 279, 1590, 11, 82498, 55810, 323, 264, 2805, 60658, 429, 86409, 279, 1887, 9704, 13, 21886, 20612, 37168, 71759, 11, 448, 1817, 55810, 7842, 4960, 13617, 320, 35499, 35995, 11, 5080, 55964, 41692, 523, 6782, 16896, 8, 311, 96068, 279, 63943, 19530, 292, 2666, 624, 785, 90767, 70971, 1365, 39659, 292, 7203, 26558, 13771, 2878, 279, 468, 55964, 36505, 1853, 266, 14011, 2070, 13, 576, 49299, 3545, 10775, 1526, 378, 107, 36, 378, 107, 4142, 378, 107, 34, 145346, 76, 378, 107, 4142, 378, 107, 32, 378, 107, 4142, 378, 107, 33, 11, 264, 358, 55964, 9971, 55964, 3090, 55964, 53, 32724, 429, 11450, 264, 9906, 11, 94509, 6337, 13, 576, 855, 55964, 6150, 355, 38919, 378, 107, 37, 145346, 76, 378, 107, 437, 378, 107, 33, 21, 11, 7842, 264, 26447, 8922, 55964, 1966, 5600, 85296, 17172, 1573, 52483, 1182, 311, 279, 98205, 55964, 3057, 291, 55810, 13, 576, 14164, 92836, 389, 378, 107, 6091, 1630, 22, 378, 107, 437, 378, 107, 33, 22, 11, 8241, 264, 9814, 13228, 6407, 429, 2608, 724, 14269, 23504, 1573, 279, 1590, 470, 311, 279, 2114, 1376, 624, 27489, 93567, 609, 9608, 1365, 576, 3754, 505, 28146, 264, 63943, 19530, 292, 11, 37550, 16566, 11, 77749, 279, 82274, 95070, 315, 3309, 55964, 16, 24, 24, 15, 82, 19461, 98875, 448, 279, 47394, 5670, 315, 220, 17, 15, 17, 15, 82, 15254, 55964, 8374, 13, 11445, 22268, 8111, 55810, 323, 43361, 42898, 975, 1281, 432, 1632, 55964, 27051, 1608, 369, 2176, 8887, 1486, 323, 6335, 5003, 11, 80558, 287, 279, 18706, 90490, 315, 2666, 55964, 18536, 11, 61584, 55964, 3612, 2071, 15254, 4627, 13, 151645], [334, 68043, 220, 16, 1019, 33648, 9287, 88828, 304, 51454, 11, 12711, 28347, 261, 304, 279, 3054, 11, 24353, 20783, 18707, 30789, 11, 22502, 4614, 389, 279, 49293, 271, 334, 68043, 220, 17, 1019, 26843, 2367, 98091, 389, 279, 39612, 11, 304, 17172, 582, 6950, 11, 14697, 41315, 311, 279, 17788, 11, 34254, 8048, 1616, 2238, 5135, 271, 334, 1143, 29869, 1019, 61457, 3729, 22502, 8266, 1290, 11, 1449, 22721, 264, 10526, 17970, 11, 304, 279, 40363, 315, 46652, 35966, 11, 1077, 279, 89671, 16484, 271, 334, 68043, 220, 18, 1019, 43930, 415, 60217, 389, 279, 3108, 11, 62371, 49411, 1690, 582, 646, 944, 10265, 11, 89115, 5059, 69051, 11, 4325, 11253, 279, 1618, 271, 334, 68043, 220, 19, 1019, 17814, 264, 46615, 11, 38862, 86979, 11, 5538, 12000, 11, 264, 26725, 57945, 11, 297, 1580, 5652, 50698, 642, 11174, 11, 22502, 2948, 1007, 279, 8781, 271, 334, 1143, 29869, 1019, 61457, 3729, 22502, 8266, 1290, 11, 1449, 22721, 264, 10526, 17970, 11, 304, 279, 40363, 315, 46652, 35966, 11, 1077, 279, 89671, 16484, 271, 334, 68043, 220, 18, 1019, 43930, 415, 60217, 389, 279, 3108, 
11, 62371, 49411, 1690, 582, 646, 944, 10265, 11, 89115, 5059, 69051, 11, 4325, 11253, 279, 1618, 271, 334, 32848, 1019, 641, 279, 40363, 315, 46652, 35966, 11, 1077, 279, 89671, 16484, 271, 334, 68043, 220, 18, 1019, 43930, 415, 60217, 389, 279, 3108, 11, 62371, 49411, 1690, 582, 646, 944, 10265, 11, 89115, 5059, 69051, 11, 4325, 11253, 279, 1618, 1406, 334, 68043, 220, 19, 1019, 17814, 264, 46615, 11, 38862, 86979, 11, 5538, 12000, 11, 264, 26725, 57945, 11, 297, 1580, 5652, 50698, 642, 11174, 11, 22502, 2948, 1007, 279, 8781, 271, 334, 1143, 29869, 1019, 61457, 3729, 22502, 8266, 1290, 11, 1449, 22721, 264, 10526, 17970, 11, 304, 279, 40363, 315, 46652, 35966, 11, 1077, 279, 89671, 16484, 271, 334, 68043, 220, 18, 1019, 43930, 415, 60217, 389, 279, 3108, 11, 62371, 49411, 1690, 582, 646, 944, 10265, 11, 89115, 5059, 69051, 11, 4325, 11253, 279, 1618, 271, 334, 68043, 220, 19, 1019, 17814, 264, 46615, 11, 38862, 86979, 11, 5538, 12000, 11, 264, 26725, 57945, 11, 297, 1580, 5652, 50698, 642, 11174, 11, 22502, 2948, 1007, 279, 8781, 271, 334, 2662, 299, 1019, 61457, 3729, 22502, 8266, 1290, 11, 1896, 264, 26725, 57945, 11, 22502, 2948, 1007, 279, 8781, 11, 14019, 11, 1896, 264, 46615, 11, 38862, 86979, 11, 5538, 12000, 11, 264, 26725, 57945, 11, 297, 1580, 5652, 50698, 642, 11174, 11, 22502, 2948, 1007, 279, 8781, 151645, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 
151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669, 151669]], "transcriptions": ["This track is an energetic Eurodance\u202f/\u202fDance\u2011Pop anthem that blends the bright, melodic sensibilities of mainstream pop with the driving, club\u2011ready pulse of classic Eurodance. The duration of the piece is 163.59 seconds.\nTempo & Key \u2013 The song moves at a brisk\u202f150\u202fBPM and is rooted in\u202fE\u202fmajor.\nInstrumentation & Production \u2013 A polished, high\u2011fidelity production frames the arrangement. The rhythm foundation is built on a four\u2011on\u2011the\u2011floor electronic drum pattern with crisp kick, snappy snare, and tight hi\u2011hats, all side\u2011chained to a punchy synth bass that locks tightly with the drums. Bright, arpeggiated synth leads and layered chordal pads carry the main melodic hooks, while additional synth stabs and risers accentuate the build\u2011ups and drops. The mix is wide\u2011stereo, with the synth elements panned to create an expansive soundstage, and the overall mastering emphasizes loudness and clarity, typical of contemporary dance\u2011floor tracks.\nVocal Characteristics \u2013 The lead is a female mezzo\u2011soprano whose timbre is clear, bright, and slightly processed. She delivers the melody in a clean, melodic style, enhanced with subtle auto\u2011tune, reverb, and delay that give the voice a glossy, radio\u2011ready sheen. The vocal line sits front\u2011and\u2011center in the mix, cutting through the dense synth arrangement.\nLyrical Themes \u2013 The lyrics revolve around love, longing, and the transformative power of a beloved presence. The central, repetitive hook\u2014*\u201cWhat would my world be without you?\u201d*\u2014anchors the chorus and underscores the song\u2019s emotional core. Other lines such as *\u201cYou lit the stars above my sleepless nights\u201d* and *\u201cTurn silent whispers into endless flights\u201d* illustrate the yearning and hopeful tone.\nSong Structure & Dynamics \u2013 The arrangement follows a classic Eurodance blueprint: an instrumental intro that establishes the driving beat, followed by verses that introduce the vocal narrative, a pre\u2011chorus that builds tension with rising synth layers, and a soaring chorus where the hook repeats over full\u2011throttle synths and a heightened drum pattern. Instrumental breaks provide space for the synth leads to shine, leading into a bridge that strips back to a more intimate texture before the final, amplified chorus and a short outro that fades the main hook. Dynamic intensity rises progressively, with each chorus adding extra layers (additional pads, higher\u2011octave synths) to amplify the euphoric feel.\nTheoretical Insight \u2013 Harmonic movement stays largely within the E\u2011major diatonic field. The verses often cycle through\u202fE\u202f\u2013\u202fC\u266fm\u202f\u2013\u202fA\u202f\u2013\u202fB, a I\u2011vi\u2011IV\u2011V progression that creates a bright, uplifting loop. 
The pre\u2011chorus introduces\u202fF\u266fm\u202fand\u202fB6, adding a subtle minor\u2011subdominant flavor before resolving back to the tonic\u2011centered chorus. The bridge leans on\u202fAmaj7\u202fand\u202fB7, providing a brief modal shift that heightens emotional tension before the final return to the home key.\nOverall Mood & Context \u2013 The track exudes a euphoric, hopeful atmosphere, marrying the nostalgic sparkle of late\u20111990s Eurodance with the sleek production of 2020s dance\u2011pop. Its anthemic chorus and polished synth work make it well\u2011suited for both radio play and club settings, embodying the contemporary resurgence of feel\u2011good, melody\u2011driven dance music.", "**Verse 1**\nMidnight cravings in bloom, lights flicker in the room, pepperoni dreams arise, pizza party on the skies\n\n**Verse 2**\nCheese melts on the crust, in flavor we trust, boxes stacked to the moon, slices gone way too soon\n\n**Chorus**\nLate night pizza feeling right, every bite a pure delight, in the warmth of neon glow, let the toppings overflow\n\n**Verse 3**\nGarlic knots on the side, grease drips we can't hide, marinara waterfall, someone answers the call\n\n**Verse 4**\nTake a sip, soda fizz, deep dish, a holy bliss, oregano sprinkles rain, pizza love off the chain\n\n**Chorus**\nLate night pizza feeling right, every bite a pure delight, in the warmth of neon glow, let the toppings overflow\n\n**Verse 3**\nGarlic knots on the side, grease drips we can't hide, marinara waterfall, someone answers the call\n\n**Bridge**\nIn the warmth of neon glow, let the toppings overflow\n\n**Verse 3**\nGarlic knots on the side, grease drips we can't hide, marinara waterfall, someone answers the call\n\n\n**Verse 4**\nTake a sip, soda fizz, deep dish, a holy bliss, oregano sprinkles rain, pizza love off the chain\n\n**Chorus**\nLate night pizza feeling right, every bite a pure delight, in the warmth of neon glow, let the toppings overflow\n\n**Verse 3**\nGarlic knots on the side, grease drips we can't hide, marinara waterfall, someone answers the call\n\n**Verse 4**\nTake a sip, soda fizz, deep dish, a holy bliss, oregano sprinkles rain, pizza love off the chain\n\n**Outro**\nLate night pizza feeling right, take a holy bliss, pizza love off the chain, oh, take a sip, soda fizz, deep dish, a holy bliss, oregano sprinkles rain, pizza love off the chain"]} \ No newline at end of file diff --git a/tests/fixtures/audioflamingo3/expected_music_results_single.json b/tests/fixtures/audioflamingo3/expected_music_results_single.json new file mode 100644 index 000000000000..df58f305becc --- /dev/null +++ b/tests/fixtures/audioflamingo3/expected_music_results_single.json @@ -0,0 +1 @@ +{"token_ids": [[1986, 3754, 374, 458, 44855, 19461, 98875, 55964, 3528, 29604, 55964, 11598, 55564, 429, 57843, 279, 9906, 11, 10581, 52760, 6097, 13450, 315, 20729, 2420, 448, 279, 9842, 11, 6335, 55964, 2307, 27235, 315, 11416, 19461, 98875, 13, 220, 576, 5670, 374, 43361, 323, 1550, 55964, 69, 45101, 11, 16445, 264, 6884, 55964, 267, 64853, 6514, 429, 7482, 75361, 287, 42898, 11508, 11, 57267, 35995, 11, 323, 264, 20380, 88, 42898, 21529, 4065, 323, 12261, 11, 1393, 264, 41854, 14346, 23196, 16138, 2293, 74182, 10323, 11, 4131, 11144, 4131, 546, 11, 323, 11048, 15588, 55964, 9198, 12624, 2293, 2674, 2010, 279, 36290, 4637, 13, 576, 26112, 374, 5798, 2163, 264, 2797, 32387, 55964, 6150, 355, 12626, 11, 31355, 12852, 553, 40945, 18303, 429, 2608, 268, 279, 15254, 30449, 23270, 13, 576, 8090, 315, 279, 6573, 374, 220, 16, 
21, 18, 13, 20, 24, 6486, 624, 53, 509, 1127, 525, 12600, 553, 264, 8778, 752, 45648, 55964, 82, 46288, 5652, 448, 264, 2797, 11, 9906, 6792, 20512, 13, 6252, 10581, 52760, 11, 77123, 9691, 374, 23922, 448, 57072, 51121, 312, 22328, 323, 7626, 11, 7086, 279, 5068, 264, 32136, 11, 94509, 2666, 13, 576, 85237, 938, 2213, 18652, 389, 2948, 323, 43293, 11, 18822, 10161, 1036, 3838, 1035, 847, 1879, 387, 2041, 498, 12390, 323, 31589, 279, 86335, 2355, 315, 264, 8263, 26087, 2610, 13020, 279, 9759, 3403, 847, 6084, 1717, 21461, 2419, 1036, 19389, 21059, 88148, 1119, 25678, 24908, 64212, 4220, 45250, 5128, 54314, 279, 5492, 748, 37550, 11, 23467, 19221, 624, 9422, 41924, 11, 279, 5492, 15885, 448, 458, 6529, 55964, 58212, 7132, 42898, 19706, 429, 7289, 279, 8766, 278, 19671, 11, 10797, 1119, 264, 32387, 1380, 279, 25407, 61584, 582, 4693, 916, 264, 24020, 3040, 55964, 263, 55964, 1782, 55964, 30449, 9382, 11, 1221, 37075, 1119, 264, 68897, 55810, 429, 13617, 279, 1887, 9704, 448, 5107, 42898, 796, 10311, 70, 3530, 323, 264, 86918, 23196, 5383, 13, 1527, 40945, 1438, 11017, 11, 16445, 264, 18293, 42898, 2990, 323, 264, 9814, 5943, 429, 22111, 23504, 1573, 13451, 311, 279, 1590, 55810, 11, 892, 374, 11504, 448, 3694, 25407, 993, 55964, 35719, 323, 264, 26447, 11893, 304, 279, 26112, 11, 11695, 64283, 304, 458, 44855, 60658, 429, 86409, 279, 1887, 42898, 59512, 624, 27489, 11, 279, 3754, 505, 28146, 458, 94509, 11, 37550, 16566, 11, 39780, 279, 39657, 48482, 19461, 98875, 8913, 315, 279, 3309, 55964, 16, 24, 24, 15, 82, 311, 4124, 55964, 17, 15, 15, 15, 82, 1393, 9664, 31520, 40876, 304, 18706, 29604, 55964, 11598, 66223, 13, 11445, 9906, 5670, 11, 85505, 10581, 52760, 29677, 11, 323, 37583, 28180, 517, 23261, 1281, 432, 264, 39657, 48482, 2666, 55964, 18536, 6335, 55564, 13, 151645]], "transcriptions": ["This track is an energetic Eurodance\u2011style Dance\u2011Pop anthem that blends the bright, melodic sensibilities of mainstream pop with the driving, club\u2011ready pulse of classic Eurodance. The production is polished and high\u2011fidelity, featuring a wide\u2011stereo mix that places shimmering synth leads, lush pads, and a punchy synth bass front and centre, while a crisp electronic drum kit\u2014tight kick, snappy snare, and rapid hi\u2011hat patterns\u2014propels the rhythm forward. The arrangement is built around a clear verse\u2011chorus framework, punctuated by instrumental breaks that heighten the dancefloor momentum. The duration of the piece is 163.59 seconds.\nVocals are delivered by a female mezzo\u2011soprano with a clear, bright timbre. Her melodic, expressive delivery is enhanced with tasteful reverb and delay, giving the performance a spacious, uplifting feel. The lyrical content centers on love and dependence, repeatedly asking \u201cWhat would my world be without you?\u201d and celebrating the transformative power of a partner (\u201cYou lit the stars above my sleepless nights,\u201d \u201cTurn silent whispers into endless flights\u201d). These recurring lines reinforce the song\u2019s hopeful, romantic narrative.\nStructurally, the song opens with an attention\u2011grabbing synth intro that sets the tonal mood, moves into a verse where the vocal melody weaves over a steady four\u2011on\u2011the\u2011floor beat, then launches into a soaring chorus that layers the main hook with additional synth arpeggios and a fuller drum pattern. 
An instrumental break follows, featuring a filtered synth lead and a brief drop that builds tension before returning to the final chorus, which is repeated with added vocal ad\u2011libs and a subtle lift in the arrangement, culminating in an energetic outro that fades the main synth motif.\nOverall, the track exudes an uplifting, hopeful atmosphere, capturing the quintessential Eurodance spirit of the late\u20111990s to early\u20112000s while remaining firmly rooted in contemporary Dance\u2011Pop aesthetics. Its bright production, catchy melodic hooks, and emotionally resonant lyrics make it a quintessential feel\u2011good club anthem."]}
\ No newline at end of file
diff --git a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py
index 3346cee5c058..845ad6c8bbdc 100644
--- a/tests/models/audioflamingo3/test_modeling_audioflamingo3.py
+++ b/tests/models/audioflamingo3/test_modeling_audioflamingo3.py
@@ -352,3 +352,158 @@ def test_fixture_batched_matches(self):
         torch.testing.assert_close(gen_ids.cpu(), exp_ids)
         txt = self.processor.batch_decode(gen_ids, skip_special_tokens=True)
         self.assertListEqual(txt, exp_txt)
+
+
+class AudioFlamingo3MusicModelTester(AudioFlamingo3ModelTester):
+    def __init__(
+        self,
+        parent,
+        audio_token_id=0,
+        seq_length=25,
+        feat_seq_length=60,
+        text_config=None,
+        audio_config=None,
+        is_training=True,
+    ):
+        if audio_config is None:
+            audio_config = {
+                "model_type": "audioflamingo3_encoder",
+                "hidden_size": 16,
+                "num_attention_heads": 4,
+                "intermediate_size": 16,
+                "num_hidden_layers": 2,
+                "num_mel_bins": 80,
+                "max_source_positions": 30,
+                "initializer_range": 0.02,
+                "use_rotary_embedding": True,
+            }
+        super().__init__(parent, audio_token_id, seq_length, feat_seq_length, text_config, audio_config, is_training)
+
+
+@require_torch
+class AudioFlamingo3MusicForConditionalGenerationModelTest(AudioFlamingo3ForConditionalGenerationModelTest):
+    """
+    Model tester for `AudioFlamingo3ForConditionalGeneration` configured as Music Flamingo (with rotary embeddings).
+    """
+
+    def setUp(self):
+        self.model_tester = AudioFlamingo3MusicModelTester(self)
+        self.config_tester = ConfigTester(self, config_class=AudioFlamingo3Config, has_text_modality=False)
+
+
+@require_torch
+class AudioFlamingo3MusicForConditionalGenerationIntegrationTest(unittest.TestCase):
+    """
+    Slow tests against the public checkpoint to validate processor-model alignment and in-place fusion
+    for the Music Flamingo configuration.
+ """ + + @classmethod + def setUp(cls): + cleanup(torch_device, gc_collect=True) + cls.checkpoint = "nvidia/music-flamingo-2601-hf" + cls.processor = AutoProcessor.from_pretrained(cls.checkpoint) + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + def test_fixture_single_matches(self): + """ + reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/c979f0f1a2b9223fa137faf1c02022d4#file-reproducer-py + """ + path = Path(__file__).parent.parent.parent / "fixtures/audioflamingo3/expected_music_results_single.json" + with open(path, "r", encoding="utf-8") as f: + raw = json.load(f) + exp_ids = torch.tensor(raw["token_ids"]) + exp_txt = raw["transcriptions"] + + conversation = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", + }, + { + "type": "audio", + "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3", + }, + ], + } + ] + + model = AudioFlamingo3ForConditionalGeneration.from_pretrained( + self.checkpoint, device_map=torch_device, dtype=torch.bfloat16 + ).eval() + + batch = self.processor.apply_chat_template( + conversation, tokenize=True, add_generation_prompt=True, return_dict=True + ).to(model.device, dtype=model.dtype) + seq = model.generate(**batch) + inp_len = batch["input_ids"].shape[1] + gen_ids = seq[:, inp_len:] if seq.shape[1] >= inp_len else seq + + torch.testing.assert_close(gen_ids.cpu(), exp_ids) + txt = self.processor.batch_decode(gen_ids, skip_special_tokens=True) + self.assertListEqual(txt, exp_txt) + + @slow + def test_fixture_batched_matches(self): + """ + reproducer (creates JSON directly in repo): https://gist.github.com/ebezzam/c979f0f1a2b9223fa137faf1c02022d4#file-reproducer-py + """ + path = Path(__file__).parent.parent.parent / "fixtures/audioflamingo3/expected_music_results_batched.json" + with open(path, "r", encoding="utf-8") as f: + raw = json.load(f) + exp_ids = torch.tensor(raw["token_ids"]) + exp_txt = raw["transcriptions"] + + conversations = [ + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this track in full detail - tell me the genre, tempo, and key, then dive into the instruments, production style, and overall mood it creates.", + }, + { + "type": "audio", + "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3", + }, + ], + } + ], + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Generate a structured lyric sheet from the input music.", + }, + { + "type": "audio", + "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_2.mp3", + }, + ], + } + ], + ] + + model = AudioFlamingo3ForConditionalGeneration.from_pretrained( + self.checkpoint, device_map=torch_device, dtype=torch.bfloat16 + ).eval() + + batch = self.processor.apply_chat_template( + conversations, tokenize=True, add_generation_prompt=True, return_dict=True + ).to(model.device, dtype=model.dtype) + seq = model.generate(**batch) + inp_len = batch["input_ids"].shape[1] + gen_ids = seq[:, inp_len:] if seq.shape[1] >= inp_len else seq + + torch.testing.assert_close(gen_ids.cpu(), exp_ids) + txt = self.processor.batch_decode(gen_ids, skip_special_tokens=True) + self.assertListEqual(txt, exp_txt) diff --git a/tests/models/audioflamingo3/test_processing_audioflamingo3.py 
diff --git a/tests/models/audioflamingo3/test_processing_audioflamingo3.py b/tests/models/audioflamingo3/test_processing_audioflamingo3.py
index bbe01cede854..dff4f2e751c5 100644
--- a/tests/models/audioflamingo3/test_processing_audioflamingo3.py
+++ b/tests/models/audioflamingo3/test_processing_audioflamingo3.py
@@ -189,3 +189,96 @@ def test_apply_chat_template_audio(self, batch_size: int, return_tensors: str):
         self._test_apply_chat_template(
             "audio", batch_size, return_tensors, "audio_input_name", "feature_extractor", MODALITY_INPUT_DATA["audio"]
         )
+
+
+class AudioFlamingo3MusicProcessingTest(ProcessorTesterMixin, unittest.TestCase):
+    processor_class = AudioFlamingo3Processor
+
+    @classmethod
+    @require_torch
+    @require_torchaudio
+    def setUpClass(cls):
+        cls.checkpoint = "nvidia/music-flamingo-2601-hf"
+        cls.tmpdirname = tempfile.mkdtemp()
+
+        processor = AudioFlamingo3Processor.from_pretrained(cls.checkpoint)
+        processor.save_pretrained(cls.tmpdirname)
+
+    @require_torch
+    @require_torchaudio
+    def get_tokenizer(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
+
+    @require_torch
+    @require_torchaudio
+    def get_audio_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).feature_extractor
+
+    @require_torch
+    @require_torchaudio
+    def get_processor(self, **kwargs):
+        return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tmpdirname, ignore_errors=True)
+
+    @require_torch
+    @require_torchaudio
+    def test_music_chat_template_and_boundaries(self):
+        processor = AutoProcessor.from_pretrained(self.checkpoint)
+        # The raw chat template stores newlines as the literal two-character
+        # sequence "\n", so the expected substring escapes them the same way.
+        expected_system_prompt = (
+            "<|im_start|>system\\nYou are Music Flamingo, a multimodal assistant for language and music. "
+            "On each turn you receive an audio clip which contains music and optional text, "
+            "you will receive at least one or both; use your world knowledge and reasoning "
+            "to help the user with any task. Interpret the entirety of the content any input music"
+            "--regardlenss of whether the user calls it audio, music, or sound.<|im_end|>\\n"
+        )
+
+        # Verify that the music-specific system prompt is preserved
+        self.assertIn(expected_system_prompt, processor.tokenizer.chat_template)
+
+        # Basic end-to-end check with a sample audio clip
+        conversations = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Analyze this track."},
+                    {
+                        "type": "audio",
+                        "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav",
+                    },
+                ],
+            }
+        ]
+
+        inputs = processor.apply_chat_template(
+            conversations, tokenize=True, return_dict=True, add_generation_prompt=True
+        )
+
+        decoded = processor.decode(inputs["input_ids"][0])
+
+        if processor.audio_bos_token is not None:
+            self.assertIn(processor.audio_bos_token, decoded)
+        if processor.audio_eos_token is not None:
+            self.assertIn(processor.audio_eos_token, decoded)
+
+        self.assertIn("<|im_start|>user", decoded)
+        self.assertIn("Analyze this track", decoded)
+        self.assertIn("<|im_start|>assistant", decoded)
+
+    @require_librosa
+    @parameterized.expand([(1, "np"), (1, "pt"), (2, "np"), (2, "pt")])
+    def test_apply_chat_template_audio(self, batch_size: int, return_tensors: str):
+        if return_tensors == "np":
+            self.skipTest("AudioFlamingo3 only supports PyTorch tensors")
+        self._test_apply_chat_template(
+            "audio", batch_size, return_tensors, "audio_input_name", "feature_extractor", MODALITY_INPUT_DATA["audio"]
+        )
+
+    def prepare_processor_dict(self):
+        return {
+            "audio_bos_token": "<|sound_bos|>",
+            "audio_eos_token": "<|sound_eos|>",
+            "max_audio_len": 1200,
+        }
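For reference: an end-to-end usage sketch distilled from the slow tests above (not part of this diff). It assumes the converted checkpoint is published under the name the tests reference; every call mirrors what the integration tests exercise.

import torch

from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

checkpoint = "nvidia/music-flamingo-2601-hf"  # same name the tests use
processor = AutoProcessor.from_pretrained(checkpoint)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(
    checkpoint, device_map="auto", dtype=torch.bfloat16
).eval()

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this track in full detail."},
            {
                "type": "audio",
                "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/song_1.mp3",
            },
        ],
    }
]

# apply_chat_template tokenizes the turn, extracts log-mel features, and
# returns a batch ready for generate(); only floating-point tensors are cast.
batch = processor.apply_chat_template(
    conversation, tokenize=True, add_generation_prompt=True, return_dict=True
).to(model.device, dtype=model.dtype)

seq = model.generate(**batch)
new_tokens = seq[:, batch["input_ids"].shape[1]:]  # strip the prompt tokens
print(processor.batch_decode(new_tokens, skip_special_tokens=True)[0])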