Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions pkg/hfutil/modelconfig/testdata/whisper_large_v3_turbo.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"_name_or_path": "/raid/yoach/tmp_whisper_turbo",
"activation_dropout": 0.0,
"activation_function": "gelu",
"apply_spec_augment": false,
"architectures": [
"WhisperForConditionalGeneration"
],
"attention_dropout": 0.0,
"begin_suppress_tokens": [
220,
50256
],
"bos_token_id": 50257,
"classifier_proj_size": 256,
"d_model": 1280,
"decoder_attention_heads": 20,
"decoder_ffn_dim": 5120,
"decoder_layerdrop": 0.0,
"decoder_layers": 4,
"decoder_start_token_id": 50258,
"dropout": 0.0,
"encoder_attention_heads": 20,
"encoder_ffn_dim": 5120,
"encoder_layerdrop": 0.0,
"encoder_layers": 32,
"eos_token_id": 50257,
"init_std": 0.02,
"is_encoder_decoder": true,
"mask_feature_length": 10,
"mask_feature_min_masks": 0,
"mask_feature_prob": 0.0,
"mask_time_length": 10,
"mask_time_min_masks": 2,
"mask_time_prob": 0.05,
"max_source_positions": 1500,
"max_target_positions": 448,
"median_filter_width": 7,
"model_type": "whisper",
"num_hidden_layers": 32,
"num_mel_bins": 128,
"pad_token_id": 50257,
"scale_embedding": false,
"torch_dtype": "float16",
"transformers_version": "4.46.0.dev0",
"use_cache": true,
"use_weighted_layer_sum": false,
"vocab_size": 51866
}
203 changes: 203 additions & 0 deletions pkg/hfutil/modelconfig/whisper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package modelconfig

import (
"encoding/json"
"fmt"
"os"
)

// WhisperConfig defines the configuration for Whisper speech recognition models
// (e.g., openai/whisper-large-v3, openai/whisper-large-v3-turbo).
//
// Whisper is an encoder-decoder Transformer where the encoder consumes log-Mel
// spectrogram features and the decoder produces text tokens. As a result the
// config carries separate dimensions for the encoder and decoder stacks rather
// than the single num_hidden_layers / hidden_size pair used by causal LMs.
//
// Every field maps one-to-one onto the config.json key named in its struct
// tag. Keys that are absent from the file leave the field at its Go zero
// value; only the fields checked by Validate are required to be positive.
type WhisperConfig struct {
// BaseModelConfig is embedded; the fields this file reads without declaring
// (ConfigPath, Architectures, ModelType, TorchDtype, TransformerVersion) are
// promoted from it.
BaseModelConfig

// Shared model dimensions
DModel int `json:"d_model"`
VocabSize int `json:"vocab_size"`

// Encoder dimensions
EncoderLayers int `json:"encoder_layers"`
EncoderAttentionHeads int `json:"encoder_attention_heads"`
EncoderFfnDim int `json:"encoder_ffn_dim"`

// Decoder dimensions
DecoderLayers int `json:"decoder_layers"`
DecoderAttentionHeads int `json:"decoder_attention_heads"`
DecoderFfnDim int `json:"decoder_ffn_dim"`

// Audio / position limits. MaxSourcePositions bounds the encoder input and
// MaxTargetPositions bounds the decoder output; GetContextLength reports the
// latter.
NumMelBins int `json:"num_mel_bins"`
MaxSourcePositions int `json:"max_source_positions"`
MaxTargetPositions int `json:"max_target_positions"`

// Special tokens
BosTokenId int `json:"bos_token_id"`
EosTokenId int `json:"eos_token_id"`
PadTokenId int `json:"pad_token_id"`
DecoderStartTokenId int `json:"decoder_start_token_id"`
// ClassifierProjSize is not a token id even though it sits in this group;
// going by its key it is a classifier projection size.
// NOTE(review): consider moving it to the misc group — confirm nothing
// depends on field order first.
ClassifierProjSize int `json:"classifier_proj_size"`

// Activation / regularization
ActivationFunction string `json:"activation_function"`
ActivationDropout float64 `json:"activation_dropout"`
AttentionDropout float64 `json:"attention_dropout"`
Dropout float64 `json:"dropout"`
EncoderLayerdrop float64 `json:"encoder_layerdrop"`
DecoderLayerdrop float64 `json:"decoder_layerdrop"`
InitStd float64 `json:"init_std"`

// Misc options
IsEncoderDecoder bool `json:"is_encoder_decoder"`
ScaleEmbedding bool `json:"scale_embedding"`
UseCache bool `json:"use_cache"`
UseWeightedLayerSum bool `json:"use_weighted_layer_sum"`
// NumHiddenLayers mirrors the num_hidden_layers key, which is separate from
// encoder_layers / decoder_layers. In the bundled whisper-large-v3-turbo
// fixture it is 32, the same as encoder_layers. Nothing in this file reads it.
NumHiddenLayers int `json:"num_hidden_layers"`
}

// LoadWhisperConfig reads the JSON file at configPath, decodes it into a
// WhisperConfig, records the path on the result and validates it.
//
// It returns a wrapped error (and a nil config) when the file cannot be read,
// when the JSON cannot be decoded, or when Validate rejects the decoded values.
func LoadWhisperConfig(configPath string) (*WhisperConfig, error) {
	raw, readErr := os.ReadFile(configPath)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read Whisper config file '%s': %w", configPath, readErr)
	}

	cfg := &WhisperConfig{}
	if parseErr := json.Unmarshal(raw, cfg); parseErr != nil {
		return nil, fmt.Errorf("failed to parse Whisper config JSON from '%s': %w", configPath, parseErr)
	}

	// Remember where the config came from: GetParameterCount hands this path
	// to FindAndParseSafetensors to locate the accompanying weight files.
	cfg.ConfigPath = configPath

	if validationErr := cfg.Validate(); validationErr != nil {
		return nil, fmt.Errorf("invalid Whisper configuration in '%s': %w", configPath, validationErr)
	}

	return cfg, nil
}

// Validate reports the first required dimension that is not strictly
// positive, or nil when all of them are.
//
// The fields are checked in a fixed order (d_model, encoder/decoder layer
// counts, encoder/decoder head counts, vocab_size, then the two position
// limits), so the returned error always names the earliest offending key.
func (c *WhisperConfig) Validate() error {
	required := []struct {
		key   string
		value int
	}{
		{"d_model", c.DModel},
		{"encoder_layers", c.EncoderLayers},
		{"decoder_layers", c.DecoderLayers},
		{"encoder_attention_heads", c.EncoderAttentionHeads},
		{"decoder_attention_heads", c.DecoderAttentionHeads},
		{"vocab_size", c.VocabSize},
		{"max_target_positions", c.MaxTargetPositions},
		{"max_source_positions", c.MaxSourcePositions},
	}

	for _, field := range required {
		if field.value <= 0 {
			return fmt.Errorf("%s must be positive, got %d", field.key, field.value)
		}
	}
	return nil
}

// Implementation of the HuggingFaceModel interface

// GetParameterCount returns the total number of parameters in the model.
// It first tries to read the precise count from accompanying safetensors
// files, and falls back to a hard-coded value for the well-known Whisper
// checkpoints.
func (c *WhisperConfig) GetParameterCount() int64 {
count, err := FindAndParseSafetensors(c.ConfigPath)
if err == nil {
return count
}

fmt.Printf("Warning: failed to get parameter count from safetensors: %v\n", err)

// Hard-coded counts for known OpenAI Whisper checkpoints. Whisper sizes
// are determined by (encoder_layers, decoder_layers, d_model).
switch {
case c.EncoderLayers == 32 && c.DecoderLayers == 4 && c.DModel == 1280:
return 809_000_000 // whisper-large-v3-turbo (~809M)
case c.EncoderLayers == 32 && c.DecoderLayers == 32 && c.DModel == 1280:
return 1_550_000_000 // whisper-large / large-v2 / large-v3 (~1.55B)
case c.EncoderLayers == 24 && c.DecoderLayers == 24 && c.DModel == 1024:
return 769_000_000 // whisper-medium (~769M)
case c.EncoderLayers == 12 && c.DecoderLayers == 12 && c.DModel == 768:
return 244_000_000 // whisper-small (~244M)
case c.EncoderLayers == 6 && c.DecoderLayers == 6 && c.DModel == 512:
return 74_000_000 // whisper-base (~74M)
case c.EncoderLayers == 4 && c.DecoderLayers == 4 && c.DModel == 384:
return 39_000_000 // whisper-tiny (~39M)
}

return 0
Copy link
Copy Markdown
Collaborator

@pallasathena92 pallasathena92 Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's use a fallback method rather than return 0 here. Detail is in other comment.

}

// GetTransformerVersion returns the transformers library version recorded in
// the config, i.e. the TransformerVersion field promoted from the embedded
// BaseModelConfig. The value is returned unmodified.
func (c *WhisperConfig) GetTransformerVersion() string {
return c.TransformerVersion
}

// GetQuantizationType returns the quantization method used (if any).
// WhisperConfig declares no quantization fields, so this always returns the
// empty string.
func (c *WhisperConfig) GetQuantizationType() string {
return ""
}

// GetArchitecture returns the first entry of the config's architectures list.
// When the list is empty it falls back to "WhisperForConditionalGeneration".
func (c *WhisperConfig) GetArchitecture() string {
	if len(c.Architectures) == 0 {
		return "WhisperForConditionalGeneration"
	}
	return c.Architectures[0]
}

// GetModelType returns the model type recorded in the config, i.e. the
// ModelType field promoted from the embedded BaseModelConfig ("whisper" in the
// bundled whisper-large-v3-turbo fixture).
func (c *WhisperConfig) GetModelType() string {
return c.ModelType
}

// GetContextLength returns the maximum context length.
//
// For Whisper this is the decoder token budget (max_target_positions, 448
// for every published OpenAI checkpoint), which is what callers use when
// sizing requests against the OpenAI-compatible API.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whisper models Input and output live in completely different spaces:

  • Input is audio, represented as mel spectrogram frames. 1 frame ≠ 1 token. 1500 frames ≈ 30 seconds of audio.
  • Output is text tokens from a BPE vocabulary (51,866 entries). 448 tokens ≈ a few paragraphs of transcribed text.

The OpenAI transcription API (/v1/audio/transcriptions) doesn't have a max_tokens parameter, the caller won't use this size to set response size.

might let's use audio input size here?
// GetContextLength returns max_source_positions— the encoder's input token
// limit (1500).
//
// Note: Whisper is an encoder-decoder model where input (audio) and output
// (text) are in different modalities. The encoder's input capacity is governed
// by max_source_positions (1500 frames ≈ 30s of audio), which is independent
// of this value. We return the encoder limit as it represents the audio generation capacity.

func (c *WhisperConfig) GetContextLength() int {
return c.MaxTargetPositions
}

// GetModelSizeBytes returns the estimated in-memory size of the model in
// bytes, derived by EstimateModelSizeBytes from the parameter count and the
// torch dtype.
//
// The parameter count comes from GetParameterCount, so this call triggers the
// same safetensors lookup (and warning output on failure) as that method.
func (c *WhisperConfig) GetModelSizeBytes() int64 {
	paramCount := c.GetParameterCount()
	dtype := c.GetTorchDtype()
	return EstimateModelSizeBytes(paramCount, dtype)
}

// GetTorchDtype returns the torch data type recorded in the config, i.e. the
// TorchDtype field promoted from the embedded BaseModelConfig ("float16" in
// the bundled whisper-large-v3-turbo fixture). The value is returned
// unmodified.
func (c *WhisperConfig) GetTorchDtype() string {
return c.TorchDtype
}

// HasVision always returns false: Whisper is an audio model, not a vision
// model. The result does not depend on any config field.
func (c *WhisperConfig) HasVision() bool {
return false
}

// IsEmbedding always returns false since Whisper is a generative ASR model
// rather than an embedding model. The result does not depend on any config
// field.
func (c *WhisperConfig) IsEmbedding() bool {
return false
}

// init registers the Whisper loader under the "whisper" key so that the
// generic loader can dispatch to LoadWhisperConfig (presumably keyed by the
// model_type value in config.json — confirm against RegisterModelLoader).
func init() {
	RegisterModelLoader("whisper", func(configPath string) (HuggingFaceModel, error) {
		config, err := LoadWhisperConfig(configPath)
		if err != nil {
			// Return a literal nil. Passing the nil *WhisperConfig through
			// would wrap it in a non-nil HuggingFaceModel interface, so a
			// caller testing `model != nil` would be misled.
			return nil, err
		}
		return config, nil
	})
}
115 changes: 115 additions & 0 deletions pkg/hfutil/modelconfig/whisper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package modelconfig

import (
"path/filepath"
"testing"
)

// TestWhisperConfig loads the whisper-large-v3-turbo fixture through the
// generic loader and checks the decoded fields together with every value the
// HuggingFaceModel accessors derive from them.
func TestWhisperConfig(t *testing.T) {
	fixture := filepath.Join("testdata", "whisper_large_v3_turbo.json")

	config, err := LoadModelConfig(fixture)
	if err != nil {
		t.Fatalf("Failed to load Whisper config: %v", err)
	}

	if got := config.GetModelType(); got != "whisper" {
		t.Errorf("Incorrect model type, expected 'whisper', got '%s'", got)
	}

	if got := config.GetArchitecture(); got != "WhisperForConditionalGeneration" {
		t.Errorf("Incorrect architecture, expected 'WhisperForConditionalGeneration', got '%s'", got)
	}

	whisperConfig, ok := config.(*WhisperConfig)
	if !ok {
		t.Fatalf("Failed to convert to WhisperConfig")
	}

	// Plain integer fields, checked in the order they are reported.
	intFields := []struct {
		key  string
		want int
		got  int
	}{
		{"d_model", 1280, whisperConfig.DModel},
		{"encoder_layers", 32, whisperConfig.EncoderLayers},
		{"decoder_layers", 4, whisperConfig.DecoderLayers},
		{"num_mel_bins", 128, whisperConfig.NumMelBins},
		{"max_source_positions", 1500, whisperConfig.MaxSourcePositions},
		{"max_target_positions", 448, whisperConfig.MaxTargetPositions},
		{"vocab_size", 51866, whisperConfig.VocabSize},
	}
	for _, field := range intFields {
		if field.got != field.want {
			t.Errorf("Incorrect %s, expected %d, got %d", field.key, field.want, field.got)
		}
	}

	if !whisperConfig.IsEncoderDecoder {
		t.Errorf("Expected is_encoder_decoder to be true")
	}

	// The reported context length is the decoder token budget.
	if got := config.GetContextLength(); got != 448 {
		t.Errorf("Incorrect context length, expected 448, got %d", got)
	}

	if got := config.GetTorchDtype(); got != "float16" {
		t.Errorf("Incorrect torch_dtype, expected 'float16', got '%s'", got)
	}

	// No safetensors files accompany the fixture, so the count comes from the
	// known-checkpoint table: whisper-large-v3-turbo has ~809M parameters.
	expectedCount := int64(809_000_000)
	if paramCount := config.GetParameterCount(); paramCount != expectedCount {
		t.Errorf("Incorrect parameter count, expected %s, got %s",
			FormatParamCount(expectedCount), FormatParamCount(paramCount))
	}

	// float16 means 2 bytes per parameter.
	expectedSize := expectedCount * 2
	if modelSize := config.GetModelSizeBytes(); modelSize != expectedSize {
		t.Errorf("Incorrect model size, expected %s, got %s",
			FormatSize(expectedSize), FormatSize(modelSize))
	}

	if config.HasVision() {
		t.Errorf("Whisper should not report HasVision() == true")
	}
}

// TestLoadModelWithWhisper checks that the generic loader dispatches the
// whisper fixture to the Whisper handler and that the interface-level
// accessors report the expected values.
func TestLoadModelWithWhisper(t *testing.T) {
	fixture := filepath.Join("testdata", "whisper_large_v3_turbo.json")

	model, err := LoadModelConfig(fixture)
	if err != nil {
		t.Fatalf("Failed to load Whisper model through generic loader: %v", err)
	}

	if got := model.GetModelType(); got != "whisper" {
		t.Errorf("Expected model type 'whisper', got '%s'", got)
	}

	if got := model.GetContextLength(); got != 448 {
		t.Errorf("Expected context length 448, got %d", got)
	}

	const wantParams = int64(809_000_000)
	gotParams := model.GetParameterCount()
	if gotParams != wantParams {
		t.Errorf("Expected parameter count %s, got %s",
			FormatParamCount(wantParams), FormatParamCount(gotParams))
	}

	t.Logf("Whisper model parameter count via generic loader: %s", FormatParamCount(gotParams))
}
Loading