
Add xcodec2 model #44178

Open

ebezzam wants to merge 93 commits into huggingface:main from ebezzam:add-xcodec2

Conversation

@ebezzam (Contributor) commented Feb 20, 2026

What does this PR do?

Re-opening #37868

TODO

  • recompute expected outputs
  • pass through code given new conventions
  • check for unused code paths / configuration parameters

Original checkpoint: https://huggingface.co/HKUSTAudio/xcodec2
Original modeling code: https://huggingface.co/HKUSTAudio/xcodec2/blob/main/modeling_xcodec2.py

@ebezzam (Contributor, Author) commented Mar 18, 2026

run-slow: xcodec2

@github-actions (Contributor)

Workflow Run ⚙️

This comment contains run-slow, running the specified jobs:

models: ["models/xcodec2"]
quantizations: []

@github-actions (Contributor)

CI Results

Workflow Run ⚙️

Commit Info

| Context | Commit   | Description                    |
|---------|----------|--------------------------------|
| RUN     | 6fd5f248 | workflow commit (merge commit) |
| PR      | 1fbe78dc | branch commit (from PR)        |
| main    | 24a4dc22 | base commit (on main)          |

✅ No failing test specific to this PR 🎉 👏 !

@ebezzam (Contributor, Author) left a comment

@eustlb a self-review for X-Codec2!

Main things:

  • Unique feature extraction for DAC-like and SeamlessM4T-like input processing, as the model needs both padded audio and spectrogram inputs.
  • New type of components in modular: Xcodec2FiniteScalarQuantization and Xcodec2ISTFTHead (similar to what we saw in the Vocos PR)
  • Small tweaks/fixes to models that Xcodec2 depends on for modular

Draft model page: https://huggingface.co/bezzam/xcodec2

main_input_name = "input_features"
input_modalities = "audio"
supports_gradient_checkpointing = True
_no_split_modules = ["Wav2Vec2BertEncoderLayer"]
@ebezzam (Contributor, Author)

To allow loading with device_map="auto"
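
For illustration, the kind of loading this enables (checkpoint name taken from the draft model page above, so treat it as tentative):

from transformers import AutoModel

# Shard the model across available devices; thanks to _no_split_modules,
# accelerate will never split a Wav2Vec2BertEncoderLayer across devices.
model = AutoModel.from_pretrained("bezzam/xcodec2", device_map="auto")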

@torch.no_grad()
def _init_weights(self, module):
    """Initialize the weights"""
    super()._init_weights(module)
@ebezzam (Contributor, Author) commented Mar 19, 2026

XCodec2 uses a pretrained checkpoint of Wav2Vec2-BERT, but Xcodec2's test_can_init_all_missing_weights test was failing because the Embedding wasn't initialized. We can rely on the base _init_weights and also remove some initialization from the code below.

Comment thread src/transformers/models/xcodec2/modular_xcodec2.py
Comment on lines +134 to +139
class SnakeBeta(SnakeBeta):
    pass


class AntiAliasedActivation1d(AntiAliasedActivation1d):
    pass
@ebezzam (Contributor, Author)

I thought just importing above would have been enough, but it wasn't generating the classes without this 🤔

Comment on lines +258 to +268
# Back to audio (ISTFT with "same" padding)
time_frames = torch.fft.irfft(spectrogram_complex, self.n_fft, dim=1, norm="backward")
time_frames = time_frames * self.window[None, :, None]
num_frames = spectrogram_complex.shape[-1]
output_size = (num_frames - 1) * self.hop_length + self.win_length
audio = F.fold(
    time_frames,
    output_size=(1, output_size),
    kernel_size=(1, self.win_length),
    stride=(1, self.hop_length),
)[:, 0, 0, self.padding : -self.padding]
@ebezzam (Contributor, Author)

torch.istft doesn't support the custom padding needed here for the integration tests to match the expected outputs.
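
For intuition, a tiny self-contained check (toy values, not from the model) that F.fold performs the overlap-add at the heart of an ISTFT:

import torch
import torch.nn.functional as F

# Three all-ones frames of length 4 with hop 2: overlapping regions should sum.
win_length, hop_length, num_frames = 4, 2, 3
frames = torch.ones(1, win_length, num_frames)  # (batch, win_length, num_frames)
output_size = (num_frames - 1) * hop_length + win_length
audio = F.fold(
    frames,
    output_size=(1, output_size),
    kernel_size=(1, win_length),
    stride=(1, hop_length),
)[:, 0, 0]
print(audio)  # tensor([[1., 1., 2., 2., 2., 2., 1., 1.]])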

Comment on lines +296 to +299
hidden_states = self.finite_scalar_quantization.bound(
    hidden_states
)  # For consistency with original checkpoint
quantized_out, indices = self.finite_scalar_quantization(hidden_states)
@ebezzam (Contributor, Author)

Calling self.finite_scalar_quantization.bound here is a bit redundant, as it's also called within self.finite_scalar_quantization(hidden_states). But the original modeling did it, and it's needed to match the expected outputs.

return hidden_states + residual


class Xcodec2FiniteScalarQuantization(nn.Module):
@ebezzam (Contributor, Author)

new component
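
For context, a minimal sketch of the FSQ idea (following Mentzer et al.'s finite scalar quantization; the names and the levels handling are illustrative, not the exact Xcodec2 implementation):

import torch

def fsq(hidden_states: torch.Tensor, levels: torch.Tensor) -> torch.Tensor:
    # Bound each dimension to (-(L - 1) / 2, (L - 1) / 2) with tanh
    # (the `bound` step discussed above)...
    half = (levels - 1) / 2
    bounded = half * torch.tanh(hidden_states)
    # ...then round to the nearest grid point, with a straight-through
    # estimator so gradients flow through the rounding.
    return bounded + (bounded.round() - bounded).detach()

hidden_states = torch.randn(2, 8, 50)   # (batch, codebook_dim, time)
levels = torch.full((1, 8, 1), 5.0)     # 5 quantization levels per dimension
quantized = fsq(hidden_states, levels)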

return codes, indices


class Xcodec2ISTFTHead(nn.Module):
@ebezzam (Contributor, Author)

Similar to what we saw in the Vocos PR
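
For context, a rough sketch of such a head in the spirit of Vocos (illustrative names, not the Xcodec2 code): project hidden states to magnitude + phase and build the complex spectrogram that the custom ISTFT above inverts.

import torch
import torch.nn as nn

class ISTFTHeadSketch(nn.Module):
    def __init__(self, hidden_size: int, n_fft: int):
        super().__init__()
        self.out = nn.Linear(hidden_size, n_fft + 2)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, time, hidden_size)
        x = self.out(hidden_states).transpose(1, 2)      # (batch, n_fft + 2, time)
        magnitude, phase = x.chunk(2, dim=1)             # each (batch, n_fft // 2 + 1, time)
        magnitude = torch.exp(magnitude).clamp(max=1e2)  # keep the exp from overflowing
        return torch.polar(magnitude, phase)             # complex spectrogram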

@ebezzam ebezzam requested a review from eustlb March 19, 2026 12:08
@ebezzam ebezzam self-assigned this Apr 13, 2026
Comment on lines +409 to +412
if is_torchdynamo_compiling():
    synced_gpus = False
else:
    synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
@ebezzam (Contributor, Author)

for torch.compile support

@eustlb (Contributor) left a comment

let's iterate

@slow
@require_torch
class Xcodec2IntegrationTest(unittest.TestCase):
"""NOTE (ebezzam): PyPI model does not support batch inference."""
@eustlb (Contributor)

WDYM?

@ebezzam (Contributor, Author)

As noted on their model card, their HF checkpoint and the corresponding modeling code in their PyPI package don't support batch inference: https://huggingface.co/HKUSTAudio/xcodec2

They claim it's possible to use their GitHub code for batch inference as noted here, but I think most people are using the checkpoint/code from the model card instead (it's unclear which checkpoint works with the batch inference code on their GitHub page: apparently this is the script for batch inference, but there is no HF checkpoint mentioned...).

In any case, I agree we should test batched inference. So I've created a reproducer that computes outputs by looping over samples, and a new test_batch_integration compares batched outputs from the Transformers implementation with each output of the reproducer (roughly as sketched below).

This indeed required the feature extractor to also return a padding mask for the spectrogram (still need to clean that up!).

PS: more context on the PyPI package and some unconventional things done in their modeling that needed to be moved to the feature extractor: #37868 (comment)
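
Roughly, the comparison looks like this (a sketch: the attribute and variable names are illustrative, not the actual test code):

import torch

# Batched pass through the Transformers implementation.
inputs = feature_extractor(raw_audios, sampling_rate=16000, padding=True, return_tensors="pt")
batched = model(**inputs)

# Reproducer: loop over samples one by one and compare per-sample outputs.
for i, audio in enumerate(raw_audios):
    single = model(**feature_extractor(audio, sampling_rate=16000, return_tensors="pt"))
    length = single.audio_values.shape[-1]
    torch.testing.assert_close(batched.audio_values[i, ..., :length], single.audio_values[0])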

Comment thread src/transformers/models/xcodec2/modular_xcodec2.py
Comment on lines +494 to +496
semantic_output = self.semantic_model(audio_spectrogram, output_hidden_states=True)
semantic_hidden_16 = semantic_output.hidden_states[16]
semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)
@eustlb (Contributor)

I've checked the training code there: both inference and training use layer 16 of Wav2Vec2Bert. Why don't we just set num_layers = 16 in the semantic_model_config and take the last hidden state?
We'd save memory and compute twice over: we wouldn't need every hidden state stored via output_hidden_states, and we wouldn't load and run layers 16..25 uselessly.

@ebezzam (Contributor, Author)

very good point and idea!
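
Something along these lines (a sketch, assuming hidden_states[16] corresponds to the output of the 16th encoder layer; names are illustrative):

# Keep only the encoder layers we actually use...
config.semantic_model_config.num_hidden_layers = 16
semantic_encoder = AutoModel.from_config(config.semantic_model_config)

# ...so the layer-16 features are simply the last hidden state, and
# output_hidden_states=True (which stores every layer) is no longer needed.
semantic_output = semantic_encoder(audio_spectrogram)
semantic_hidden_states = semantic_output.last_hidden_state.transpose(1, 2)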

Comment on lines +503 to +506
if acoustic_hidden_states.shape[-1] != semantic_hidden_states.shape[-1]:
    min_len = min(acoustic_hidden_states.shape[-1], semantic_hidden_states.shape[-1])
    acoustic_hidden_states = acoustic_hidden_states[:, :, :min_len]
    semantic_hidden_states = semantic_hidden_states[:, :, :min_len]
@eustlb (Contributor)

to be removed

Comment on lines +443 to +445
def apply_weight_norm(self, legacy=True):
    weight_norm = nn.utils.weight_norm
    if hasattr(nn.utils.parametrizations, "weight_norm") and not legacy:
@eustlb (Contributor)

not sure I see why we have this legacy kwarg

@ebezzam (Contributor, Author)

ah brings me back to my first weeks at HF with DAC 😆

I agree that the flag could be removed and we can simply use nn.utils.weight_norm, as the conversion script uses legacy=True:

model.apply_weight_norm()
model = convert_state_dict(original_checkpoint, model)
model.remove_weight_norm()

This legacy flag was something I came up with during this PR because the apply_weight_norm method of DAC wrongly assumed that nn.utils.parametrizations (the newer weight-norm API) should be used just because it exists.

This causes an issue when converting a checkpoint that used the legacy method, as the two APIs produce different weight-norm tensors in the state dict; hence the legacy flag.
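
To make the mismatch concrete, a standalone PyTorch illustration (not DAC/Xcodec2 code):

import torch.nn as nn

# Legacy API: stores the norm/direction as `weight_g` / `weight_v`.
legacy = nn.utils.weight_norm(nn.Conv1d(4, 4, 3))
print(sorted(legacy.state_dict()))
# ['bias', 'weight_g', 'weight_v']

# Parametrized API: stores them under `parametrizations.weight.original{0,1}`.
modern = nn.utils.parametrizations.weight_norm(nn.Conv1d(4, 4, 3))
print(sorted(modern.state_dict()))
# ['bias', 'parametrizations.weight.original0', 'parametrizations.weight.original1']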

ALTERNATIVELY, 9 months later haha, I think it's better to move apply_weight_norm and remove_weight_norm directly into the conversion script rather than having them in the modeling, since they are only really used when converting. What do you think?

@eustlb (Contributor)

I see. Since it could be used for training I would leave it, but here we're not enabling training for now anyway so let's indeed remove it 👍

Comment on lines +431 to +439
self.semantic_model = AutoModel.from_config(config.semantic_model_config).eval()
self.semantic_adapter = Xcodec2SemanticAdapter(config)
self.acoustic_encoder = Xcodec2Encoder(config)
self.fc_encoder = nn.Linear(
    config.hidden_size + config.semantic_model_config.hidden_size,
    config.hidden_size + config.semantic_model_config.hidden_size,
)
self.quantizer = Xcodec2Quantizer(config)
self.decoder = Xcodec2Decoder(config)
@eustlb (Contributor)

semantic_model → semantic_encoder
Also, even though we don't have the semantic_decoder here since we're not enabling training, I would rename the decoder to semantic_decoder

@ebezzam (Contributor, Author)

Renamed to semantic_encoder. Did you mean renaming self.decoder to self.acoustic_decoder?

@eustlb (Contributor)

ah yes! semantic decoder my bad

@ebezzam (Contributor, Author)

acoustic decoder 😉

super().__init__(config)

self.hop_length = config.hop_length
self.semantic_model = AutoModel.from_config(config.semantic_model_config).eval()
@eustlb (Contributor)

what is this .eval()?

@ebezzam (Contributor, Author) commented Apr 28, 2026

The original had it, as the semantic encoder was meant to be frozen. But you're right that it's not necessary, and if someone wants to train (which isn't supported now), they could always freeze it themselves.

Removing it

@eustlb (Contributor)

eval() does not freeze weights but just disables dropout, batchnorm, etc.
Note that AutoModel.from_config does not call .eval() by default, but .from_pretrained does.

If the semantic encoder is always frozen, which is the case here, what we should do (see the sketch below):

  1. let's not put eval()/train() in the init
  2. use torch.no_grad() in the forward when running inference with it. This will save memory by not storing activations anymore.

When we want to freeze parameters, we'd do: self.semantic_model.requires_grad_(False)
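
A minimal sketch of that pattern (names illustrative):

import torch

def forward(self, audio_spectrogram):
    # Run the always-frozen semantic encoder without storing activations.
    with torch.no_grad():
        semantic_output = self.semantic_encoder(audio_spectrogram)
    ...

# If/when training is enabled, freeze the parameters explicitly instead of
# calling .eval() in the init:
model.semantic_encoder.requires_grad_(False)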

@ebezzam (Contributor, Author)

Yes sorry!

For reference (since things are split between HF Hub and GitHub), this is where they freeze the semantic encoder during training, and their encode uses torch.no_grad() for inference.

@github-actions (Contributor)

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, dac, higgs_audio_v2_tokenizer, pe_audio, qwen2_5_omni, seamless_m4t, wav2vec2_bert, xcodec, xcodec2

@github-actions (Contributor)

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44178&sha=4c6f00
