Add qwen3 tts #44517

Open
ShahVandit wants to merge 28 commits into huggingface:main from ShahVandit:add-qwen3-tts

Conversation

@ShahVandit

What does this PR do?

Adds Qwen3-TTS, a series of text-to-speech models by the Qwen team (Alibaba Group), to Transformers.

Architecture:

  • Qwen3TTSForConditionalGeneration — text to multi-codebook speech codes (talker)
  • Qwen3TTSTokenizerV2Model (12Hz) and Qwen3TTSTokenizerV1Model (25Hz) — codes to audio waveform
  • Qwen3TTSProcessor — text preprocessing

Features: voice presets, voice design via natural language, batch inference, 10 languages

Paper: Qwen3-TTS Technical Report

Before submitting

  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

@eustlb @ebezzam @vasqu

Contributor

@ebezzam ebezzam left a comment


@ShahVandit thanks a bunch for your contribution!

It's great that you're using modular 👏 I've put a few comments to improve its usage because there are a lot of lines of code here which I think we'll be able to iteratively reduce.

A general principle when adding a new model to Transformers is that we want to be very careful about adding new modeling components. If a component exists elsewhere, we want to use modular to inherit from it. This will significantly reduce the number of lines in the modular file! Moreover, we don't want to keep unused code paths; I pointed out a few cases. When going through your modular file, ask yourself (a coding agent helps a lot here for navigating the quite large code base!): (1) is this being used in the final modeling code, and (2) does something similar exist in another model?

Moreover, one file that is missing is a script to convert the existing QwenTTS checkpoints to Transformers-compatible ones. Here are some recent examples of conversion scripts:

If it helps, below is my typical workflow when it comes to model integration:

  1. Write an integration test(s). This will be our sanity check to make sure that the modeling code (generated from modular) doesn't deviate from the original. For example, VibeVoice ASR.
  2. Write a reproducer script that generates expected outputs with the original checkpoint + code. For example: VibeVoice ASR, Qwen 3 ASR. We will add a link to this reproducer in the integration test like this, as we won't add this file to the repo. If you look at the VibeVoice and Qwen3ASR reproducers, they write the expected outputs directly in the repo as JSON files (rather than having to copy and paste EXPECTED_OUTPUT lists).
  3. Get to a functional modular and conversion script. Note that the modular file can be used to generate the modeling AND configuration files, see how Qwen3ASR is using existing configs in the library here.
  4. Iteratively prune, clean, and conform the modular file to Transformers conventions, while running the integration test to ensure that you aren't deviating from the original model's outputs.
RUN_SLOW=1 pytest tests/models/qwen3_tts/test_modeling_qwen3_tts.py::Qwen3TTSForConditionalGenerationIntegrationTest
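The JSON approach mentioned in step 2 can be sketched as follows — note the file name and key layout here are hypothetical, not the actual VibeVoice/Qwen3ASR structure:

```python
# Illustrative sketch (file layout and keys are hypothetical) of storing
# expected outputs as JSON instead of pasting EXPECTED_OUTPUT lists into
# the test file.
import json
import os
import tempfile

# Reproducer side: dump what the original checkpoint + code produced.
expected = {"text": "hello world", "codes": [[12, 7, 93], [4, 88, 1]]}
path = os.path.join(tempfile.mkdtemp(), "expected_outputs.json")
with open(path, "w") as f:
    json.dump(expected, f, indent=2)

# Integration-test side: load the JSON and compare against the port's outputs.
with open(path) as f:
    loaded = json.load(f)
assert loaded == expected
```

This keeps the test file short and makes regenerating expectations a one-command affair when the reproducer changes.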

Coding agents are very helpful for this process by giving targeted tasks and pointing to similar models/files in the repo.

I hope that helps! Let me know if you have any questions and thanks for your valuable contribution 🤗

class Qwen3TTSIntegrationTest(unittest.TestCase):
    """Integration tests for Qwen3-TTS (require real weights, run with --slow)."""

    model_id = "Qwen/Qwen3-TTS-12Hz-0.6B-Base"
Contributor


In the integration test, we will rather test your converted checkpoint instead of the original. See Qwen3ASR.

When the model is ready to merge, we may contact the original Qwen team to upload a Transformers-compatible version to their org. For example, with VibeVoice ASR:


@slow
@require_torch_accelerator
def test_small_model_integration_text_to_codes(self):
Contributor

@ebezzam ebezzam Mar 20, 2026


take a look at these examples for recent approaches to writing the integration test:

we can limit generation to around 50-100 tokens

Comment on lines +139 to +140
class Qwen3TTSTokenizerV2LayerScale(MimiLayerScale):
pass
Contributor


I don't see this module being used in the generated modeling code? If it's indeed unused, we can remove it.

Comment on lines +150 to +159
class Qwen3TTSTokenizerV1DiTCodecEmbedding(DiTCodecEmbedding):
pass


class Qwen3TTSTokenizerV1DiTMLP(DiTMLP):
pass


class Qwen3TTSTokenizerV1DiTAttention(DiTAttention):
pass
Contributor


Similarly, I don't see these modules being used?

Comment on lines +166 to +175
class Qwen3TTSTokenizerV1DiTTimestepEmbedding(DiTTimestepEmbedding):
pass


class Qwen3TTSTokenizerV1SinusoidsPositionEmbedding(SinusoidsPositionEmbedding):
pass


class Qwen3TTSTokenizerV1AdaLayerNormZero_Final(Qwen2_5_OmniAdaLayerNormZero_Final):
pass
Contributor


Similarly, are these being used?

decoder_past_key_values: Cache | None = None


class Qwen3TTSConv1dPaddingCache:
Contributor

@ebezzam ebezzam Mar 20, 2026


We'll want to update the modular in the relevant place to use this newer padding cache object:
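For intuition, here is a plain-NumPy sketch (not the Transformers cache object) of why a Conv1d padding cache exists: for streaming or chunked inference, carrying the last `kernel_size - 1` input samples into the next chunk makes chunked causal convolution exactly match convolving the full signal at once.

```python
# Illustrative sketch of a Conv1d padding cache for streaming inference.
# We keep the tail of each chunk as left context for the next chunk.
import numpy as np

def causal_conv_chunked(chunks, kernel):
    pad = len(kernel) - 1
    cache = np.zeros(pad)              # initial left padding (zeros)
    outputs = []
    for chunk in chunks:
        x = np.concatenate([cache, chunk])
        outputs.append(np.convolve(x, kernel, mode="valid"))
        cache = x[-pad:]               # tail becomes next chunk's padding
    return np.concatenate(outputs)

signal = np.arange(12, dtype=float)
kernel = np.array([0.5, 0.3, 0.2])
# Non-streaming reference: zero-pad once, convolve the whole signal.
full = np.convolve(np.concatenate([np.zeros(2), signal]), kernel, mode="valid")
chunked = causal_conv_chunked([signal[:5], signal[5:9], signal[9:]], kernel)
assert np.allclose(full, chunked)      # identical to non-streaming output
```

The cache object in the library plays the role of `cache` above, per convolution layer.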

Contributor

@ebezzam ebezzam left a comment


Sorry, some of my comments went into the modeling file while I was jumping between the modeling and modular files!

return quantized.transpose(1, 2)


class Qwen3TTSTokenizerV2ResidualVectorQuantization(nn.Module):
Contributor


Sorry, meant to put the comment here! -> Can we use RVQ from DAC or Mimi?
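For readers unfamiliar with RVQ, a plain-NumPy sketch of the idea — for intuition only; the actual DAC/Mimi modules should be reused via modular rather than reimplemented, and the codebook sizes here are arbitrary:

```python
# Greedy residual vector quantization (RVQ): each quantizer codes the
# residual left by the previous one, so stacked codebooks refine the
# approximation. Toy sizes; not the DAC/Mimi implementation.
import numpy as np

def rvq_encode(x, codebooks):
    """Pick one code per codebook; return codes and the final residual."""
    codes, residual = [], x.astype(float)
    for cb in codebooks:
        idx = int(np.argmin(np.linalg.norm(cb - residual, axis=1)))
        codes.append(idx)
        residual = residual - cb[idx]
    return codes, residual

def rvq_decode(codes, codebooks):
    # The reconstruction is the sum of the selected codebook entries.
    return sum(cb[i] for cb, i in zip(codebooks, codes))

rng = np.random.default_rng(0)
codebooks = [rng.normal(size=(16, 4)) for _ in range(3)]  # 3 quantizers
x = rng.normal(size=4)
codes, residual = rvq_encode(x, codebooks)
# Exact identity: input = decoded codes + leftover residual.
assert np.allclose(rvq_decode(codes, codebooks) + residual, x)
```

The multi-codebook speech codes the talker emits are exactly the `codes` list here, one index per codebook per frame.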

Author


Hey @ebezzam , I tried using MimiModel for the V2 encoder but hit a converter issue.

The modular converter renames all Mimi* references to Qwen3TTSTokenizerV2* based on prefix voting. So MimiEncoder inside MimiModel.__init__ becomes Qwen3TTSTokenizerV2Encoder, which is the same name as the class itself, causing infinite recursion.

I tried renaming the class to Qwen3TTSTokenizerV2AudioEncoder to avoid the collision. That fixed the recursion, but then MimiTransformerModel got renamed to Qwen3TTSTokenizerV2TransformerModel, which clashes with our Code2Wav decoder transformer that has the same name but expects a completely different config.

The RVQ classes (EuclideanCodebook, VectorQuantization, etc.) inherit from Mimi without issue, since there's no name collision there. The problem is specifically with MimiModel, because it creates internal components whose names clash with classes we already define.

Is there a recommended way to handle this, or should we keep the V2 encoder standalone for now?
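To make the collision concrete, here's a toy string-level reproduction (the real modular converter works on the AST; the rename rule here is a simplification purely for illustration):

```python
# Toy reproduction of the prefix-rename collision described above.
# The real converter is AST-based; this string rewrite is hypothetical.
import re

def rename_prefix(source, old, new):
    # Rewrite every identifier that starts with the old model prefix.
    return re.sub(rf"\b{old}(\w*)", rf"{new}\g<1>", source)

src = "class Qwen3TTSTokenizerV2Encoder(MimiEncoder):\n    pass"
out = rename_prefix(src, "Mimi", "Qwen3TTSTokenizerV2")
# MimiEncoder is renamed to Qwen3TTSTokenizerV2Encoder, so the class
# now appears to inherit from itself -> infinite recursion.
assert "Qwen3TTSTokenizerV2Encoder(Qwen3TTSTokenizerV2Encoder)" in out
```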

Contributor


Thanks @ShahVandit for the detailed explanation. A couple points:

  • To simplify / compartmentalize things, we can make the QwenTTS Tokenizer(s) their own model(s). Similar to how Mimi is its own model and is used as a subconfig/model for Kyutai's STT (see here and here). Here are other examples: (VibeVoice tokenizer, VibeVoice ASR) and (Higgs tokenizer, Higgs model). That may help with the clashing names in modular, and also make the modular more readable for each model!
  • Similarly, from what I understand in the paper, TokenizerV2 and TokenizerV1 are meant to be two types of tokenizers? A single-codebook one (Qwen3-TTS-Tokenizer-25Hz) and a multi-codebook one (Qwen3-TTS-Tokenizer-12Hz). So let's use more meaningful names for them, e.g. Qwen3TTSTokenizerSingleCodebook and Qwen3TTSTokenizerMultiCodebook, and make them two separate models. And from what I understand in the paper, Qwen3TTSTokenizerSingleCodebook will be able to inherit via modular from Qwen2Audio, and Qwen3TTSTokenizerMultiCodebook from Mimi.

So there will be three models in total, each with its own model folder (with configuration, modular, etc.): Qwen3TTSTokenizerSingleCodebook, Qwen3TTSTokenizerMultiCodebook, and Qwen3TTS. And the latter will be able to use the modeling from the tokenizers, via AutoModel (like this).

Hope that's clear and that it helps!

return torch.log(torch.clamp(x, min=clip_val) * C)


def mel_spectrogram(
Contributor

@ebezzam ebezzam Mar 20, 2026


There are mel_spectrogram methods within audio_utils. Although we may be able to use the feature extraction from Whisper, as is the case for Qwen ASR (but this should be double-checked).

Moreover, we typically bundle the feature extractor and the tokenizer within the processor (as can be seen in the Qwen ASR example).

If a new feature extractor is needed however, it should be in a separate feature_extraction_MODEL.py file. For example: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py
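For intuition, a minimal NumPy sketch of what such a feature extractor computes: STFT magnitude, a mel filter bank, then the log-clamp dynamic-range compression seen in the snippet above (with C = 1). The window, hop, filter-bank construction (HTK mel scale, unnormalized triangles), and clamp value are illustrative assumptions, not Whisper's exact parameters.

```python
# Minimal log-mel pipeline sketch -- not the Transformers implementation.
import numpy as np

def mel_filters(n_fft, n_mels, sr):
    # Triangular filters evenly spaced on the (HTK) mel scale.
    mel = lambda f: 2595.0 * np.log10(1.0 + f / 700.0)
    inv = lambda m: 700.0 * (10.0 ** (m / 2595.0) - 1.0)
    pts = inv(np.linspace(mel(0.0), mel(sr / 2), n_mels + 2))
    bins = np.floor((n_fft + 1) * pts / sr).astype(int)
    fb = np.zeros((n_mels, n_fft // 2 + 1))
    for i in range(n_mels):
        l, c, r = bins[i], bins[i + 1], bins[i + 2]
        fb[i, l:c] = (np.arange(l, c) - l) / max(c - l, 1)   # rising edge
        fb[i, c:r] = (r - np.arange(c, r)) / max(r - c, 1)   # falling edge
    return fb

def log_mel(wave, n_fft=400, hop=160, n_mels=80, sr=16000, clip_val=1e-5):
    frames = [wave[i:i + n_fft] * np.hanning(n_fft)
              for i in range(0, len(wave) - n_fft + 1, hop)]
    mag = np.abs(np.fft.rfft(np.stack(frames), axis=1))      # (time, freq)
    mel = mag @ mel_filters(n_fft, n_mels, sr).T             # (time, mels)
    return np.log(np.clip(mel, clip_val, None))              # compression

# One second of a 440 Hz tone at 16 kHz -> (98, 80) log-mel features.
feats = log_mel(np.sin(2 * np.pi * 440 * np.arange(16000) / 16000))
```

In practice this whole computation lives in a feature_extraction_MODEL.py file (or an existing one like Whisper's) and is bundled into the processor.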

return self.supported_languages

@classmethod
def from_pretrained(
Contributor


from_pretrained methods shouldn't need to be overridden.

return text_embed + codec_embed, tts_pad_embed

@torch.no_grad()
def generate(
Contributor


attributes = ["tokenizer"]
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

def __init__(self, tokenizer=None, chat_template=None):
Contributor

@ebezzam ebezzam Mar 20, 2026


We'll want to move the feature extraction (what you were doing with computing mel spectrograms) to the processor.

Note that it may even be interesting to generate the processor from the modular file. (See Qwen3ASR)

Comment on lines +84 to +94
def batch_decode(self, *args, **kwargs):
    """
    This method forwards all its arguments to the tokenizer's batch_decode method.
    """
    return self.tokenizer.batch_decode(*args, **kwargs)

def decode(self, *args, **kwargs):
    """
    This method forwards all its arguments to the tokenizer's decode method.
    """
    return self.tokenizer.decode(*args, **kwargs)
Contributor


As we are passing directly to the tokenizer, we don't need to define these methods.

@ebezzam ebezzam self-assigned this Apr 13, 2026
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, qwen3_tts


3 participants