-
Notifications
You must be signed in to change notification settings - Fork 79
[Misc] add Whisper config parser #574
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| { | ||
| "_name_or_path": "/raid/yoach/tmp_whisper_turbo", | ||
| "activation_dropout": 0.0, | ||
| "activation_function": "gelu", | ||
| "apply_spec_augment": false, | ||
| "architectures": [ | ||
| "WhisperForConditionalGeneration" | ||
| ], | ||
| "attention_dropout": 0.0, | ||
| "begin_suppress_tokens": [ | ||
| 220, | ||
| 50256 | ||
| ], | ||
| "bos_token_id": 50257, | ||
| "classifier_proj_size": 256, | ||
| "d_model": 1280, | ||
| "decoder_attention_heads": 20, | ||
| "decoder_ffn_dim": 5120, | ||
| "decoder_layerdrop": 0.0, | ||
| "decoder_layers": 4, | ||
| "decoder_start_token_id": 50258, | ||
| "dropout": 0.0, | ||
| "encoder_attention_heads": 20, | ||
| "encoder_ffn_dim": 5120, | ||
| "encoder_layerdrop": 0.0, | ||
| "encoder_layers": 32, | ||
| "eos_token_id": 50257, | ||
| "init_std": 0.02, | ||
| "is_encoder_decoder": true, | ||
| "mask_feature_length": 10, | ||
| "mask_feature_min_masks": 0, | ||
| "mask_feature_prob": 0.0, | ||
| "mask_time_length": 10, | ||
| "mask_time_min_masks": 2, | ||
| "mask_time_prob": 0.05, | ||
| "max_source_positions": 1500, | ||
| "max_target_positions": 448, | ||
| "median_filter_width": 7, | ||
| "model_type": "whisper", | ||
| "num_hidden_layers": 32, | ||
| "num_mel_bins": 128, | ||
| "pad_token_id": 50257, | ||
| "scale_embedding": false, | ||
| "torch_dtype": "float16", | ||
| "transformers_version": "4.46.0.dev0", | ||
| "use_cache": true, | ||
| "use_weighted_layer_sum": false, | ||
| "vocab_size": 51866 | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| package modelconfig | ||
|
|
||
| import ( | ||
| "encoding/json" | ||
| "fmt" | ||
| "os" | ||
| ) | ||
|
|
||
| // WhisperConfig defines the configuration for Whisper speech recognition models | ||
| // (e.g., openai/whisper-large-v3, openai/whisper-large-v3-turbo). | ||
| // | ||
| // Whisper is an encoder-decoder Transformer where the encoder consumes log-Mel | ||
| // spectrogram features and the decoder produces text tokens. As a result the | ||
| // config carries separate dimensions for the encoder and decoder stacks rather | ||
| // than the single num_hidden_layers / hidden_size pair used by causal LMs. | ||
| type WhisperConfig struct { | ||
| BaseModelConfig | ||
|
|
||
| // Shared model dimensions | ||
| DModel int `json:"d_model"` | ||
| VocabSize int `json:"vocab_size"` | ||
|
|
||
| // Encoder dimensions | ||
| EncoderLayers int `json:"encoder_layers"` | ||
| EncoderAttentionHeads int `json:"encoder_attention_heads"` | ||
| EncoderFfnDim int `json:"encoder_ffn_dim"` | ||
|
|
||
| // Decoder dimensions | ||
| DecoderLayers int `json:"decoder_layers"` | ||
| DecoderAttentionHeads int `json:"decoder_attention_heads"` | ||
| DecoderFfnDim int `json:"decoder_ffn_dim"` | ||
|
|
||
| // Audio / position limits | ||
| NumMelBins int `json:"num_mel_bins"` | ||
| MaxSourcePositions int `json:"max_source_positions"` | ||
| MaxTargetPositions int `json:"max_target_positions"` | ||
|
|
||
| // Special tokens | ||
| BosTokenId int `json:"bos_token_id"` | ||
| EosTokenId int `json:"eos_token_id"` | ||
| PadTokenId int `json:"pad_token_id"` | ||
| DecoderStartTokenId int `json:"decoder_start_token_id"` | ||
| ClassifierProjSize int `json:"classifier_proj_size"` | ||
|
|
||
| // Activation / regularization | ||
| ActivationFunction string `json:"activation_function"` | ||
| ActivationDropout float64 `json:"activation_dropout"` | ||
| AttentionDropout float64 `json:"attention_dropout"` | ||
| Dropout float64 `json:"dropout"` | ||
| EncoderLayerdrop float64 `json:"encoder_layerdrop"` | ||
| DecoderLayerdrop float64 `json:"decoder_layerdrop"` | ||
| InitStd float64 `json:"init_std"` | ||
|
|
||
| // Misc options | ||
| IsEncoderDecoder bool `json:"is_encoder_decoder"` | ||
| ScaleEmbedding bool `json:"scale_embedding"` | ||
| UseCache bool `json:"use_cache"` | ||
| UseWeightedLayerSum bool `json:"use_weighted_layer_sum"` | ||
| NumHiddenLayers int `json:"num_hidden_layers"` | ||
| } | ||
|
|
||
| // LoadWhisperConfig loads a Whisper model configuration from a JSON file. | ||
| func LoadWhisperConfig(configPath string) (*WhisperConfig, error) { | ||
| data, err := os.ReadFile(configPath) | ||
| if err != nil { | ||
| return nil, fmt.Errorf("failed to read Whisper config file '%s': %w", configPath, err) | ||
| } | ||
|
|
||
| var config WhisperConfig | ||
| if err := json.Unmarshal(data, &config); err != nil { | ||
| return nil, fmt.Errorf("failed to parse Whisper config JSON from '%s': %w", configPath, err) | ||
| } | ||
|
|
||
| config.ConfigPath = configPath | ||
|
|
||
| if err := config.Validate(); err != nil { | ||
| return nil, fmt.Errorf("invalid Whisper configuration in '%s': %w", configPath, err) | ||
| } | ||
|
|
||
| return &config, nil | ||
| } | ||
|
|
||
| // Validate checks if the Whisper configuration is internally consistent. | ||
| func (c *WhisperConfig) Validate() error { | ||
| if c.DModel <= 0 { | ||
| return fmt.Errorf("d_model must be positive, got %d", c.DModel) | ||
| } | ||
| if c.EncoderLayers <= 0 { | ||
| return fmt.Errorf("encoder_layers must be positive, got %d", c.EncoderLayers) | ||
| } | ||
| if c.DecoderLayers <= 0 { | ||
| return fmt.Errorf("decoder_layers must be positive, got %d", c.DecoderLayers) | ||
| } | ||
| if c.EncoderAttentionHeads <= 0 { | ||
| return fmt.Errorf("encoder_attention_heads must be positive, got %d", c.EncoderAttentionHeads) | ||
| } | ||
| if c.DecoderAttentionHeads <= 0 { | ||
| return fmt.Errorf("decoder_attention_heads must be positive, got %d", c.DecoderAttentionHeads) | ||
| } | ||
| if c.VocabSize <= 0 { | ||
| return fmt.Errorf("vocab_size must be positive, got %d", c.VocabSize) | ||
| } | ||
| if c.MaxTargetPositions <= 0 { | ||
| return fmt.Errorf("max_target_positions must be positive, got %d", c.MaxTargetPositions) | ||
| } | ||
| if c.MaxSourcePositions <= 0 { | ||
| return fmt.Errorf("max_source_positions must be positive, got %d", c.MaxSourcePositions) | ||
| } | ||
| return nil | ||
| } | ||
|
|
||
| // Implementation of the HuggingFaceModel interface | ||
|
|
||
| // GetParameterCount returns the total number of parameters in the model. | ||
| // It first tries to read the precise count from accompanying safetensors | ||
| // files, and falls back to a hard-coded value for the well-known Whisper | ||
| // checkpoints. | ||
| func (c *WhisperConfig) GetParameterCount() int64 { | ||
| count, err := FindAndParseSafetensors(c.ConfigPath) | ||
| if err == nil { | ||
| return count | ||
| } | ||
|
|
||
| fmt.Printf("Warning: failed to get parameter count from safetensors: %v\n", err) | ||
|
|
||
| // Hard-coded counts for known OpenAI Whisper checkpoints. Whisper sizes | ||
| // are determined by (encoder_layers, decoder_layers, d_model). | ||
| switch { | ||
| case c.EncoderLayers == 32 && c.DecoderLayers == 4 && c.DModel == 1280: | ||
| return 809_000_000 // whisper-large-v3-turbo (~809M) | ||
| case c.EncoderLayers == 32 && c.DecoderLayers == 32 && c.DModel == 1280: | ||
| return 1_550_000_000 // whisper-large / large-v2 / large-v3 (~1.55B) | ||
| case c.EncoderLayers == 24 && c.DecoderLayers == 24 && c.DModel == 1024: | ||
| return 769_000_000 // whisper-medium (~769M) | ||
| case c.EncoderLayers == 12 && c.DecoderLayers == 12 && c.DModel == 768: | ||
| return 244_000_000 // whisper-small (~244M) | ||
| case c.EncoderLayers == 6 && c.DecoderLayers == 6 && c.DModel == 512: | ||
| return 74_000_000 // whisper-base (~74M) | ||
| case c.EncoderLayers == 4 && c.DecoderLayers == 4 && c.DModel == 384: | ||
| return 39_000_000 // whisper-tiny (~39M) | ||
| } | ||
|
|
||
| return 0 | ||
| } | ||
|
|
||
| // GetTransformerVersion returns the transformers library version. | ||
| func (c *WhisperConfig) GetTransformerVersion() string { | ||
| return c.TransformerVersion | ||
| } | ||
|
|
||
| // GetQuantizationType returns the quantization method used (if any). | ||
| func (c *WhisperConfig) GetQuantizationType() string { | ||
| return "" | ||
| } | ||
|
|
||
| // GetArchitecture returns the model architecture. | ||
| func (c *WhisperConfig) GetArchitecture() string { | ||
| if len(c.Architectures) > 0 { | ||
| return c.Architectures[0] | ||
| } | ||
| return "WhisperForConditionalGeneration" | ||
| } | ||
|
|
||
| // GetModelType returns the model type. | ||
| func (c *WhisperConfig) GetModelType() string { | ||
| return c.ModelType | ||
| } | ||
|
|
||
| // GetContextLength returns the maximum context length. | ||
| // | ||
| // For Whisper this is the decoder token budget (max_target_positions, 448 | ||
| // for every published OpenAI checkpoint), which is what callers use when | ||
| // sizing requests against the OpenAI-compatible API. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. whisper models Input and output live in completely different spaces:
The OpenAI transcription API (/v1/audio/transcriptions) doesn't have a might let's use audio input size here? |
||
| func (c *WhisperConfig) GetContextLength() int { | ||
| return c.MaxTargetPositions | ||
| } | ||
|
|
||
| // GetModelSizeBytes returns the estimated size of the model in bytes. | ||
| func (c *WhisperConfig) GetModelSizeBytes() int64 { | ||
| return EstimateModelSizeBytes(c.GetParameterCount(), c.GetTorchDtype()) | ||
| } | ||
|
|
||
| // GetTorchDtype returns the torch data type used by the model. | ||
| func (c *WhisperConfig) GetTorchDtype() string { | ||
| return c.TorchDtype | ||
| } | ||
|
|
||
| // HasVision returns false. Whisper is an audio model, not a vision model. | ||
| func (c *WhisperConfig) HasVision() bool { | ||
| return false | ||
| } | ||
|
|
||
| // IsEmbedding returns false since Whisper is a generative ASR model. | ||
| func (c *WhisperConfig) IsEmbedding() bool { | ||
| return false | ||
| } | ||
|
|
||
| // Register the Whisper model handler. | ||
| func init() { | ||
| RegisterModelLoader("whisper", func(configPath string) (HuggingFaceModel, error) { | ||
| return LoadWhisperConfig(configPath) | ||
| }) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| package modelconfig | ||
|
|
||
| import ( | ||
| "path/filepath" | ||
| "testing" | ||
| ) | ||
|
|
||
| func TestWhisperConfig(t *testing.T) { | ||
| configPath := filepath.Join("testdata", "whisper_large_v3_turbo.json") | ||
|
|
||
| config, err := LoadModelConfig(configPath) | ||
| if err != nil { | ||
| t.Fatalf("Failed to load Whisper config: %v", err) | ||
| } | ||
|
|
||
| if config.GetModelType() != "whisper" { | ||
| t.Errorf("Incorrect model type, expected 'whisper', got '%s'", config.GetModelType()) | ||
| } | ||
|
|
||
| if config.GetArchitecture() != "WhisperForConditionalGeneration" { | ||
| t.Errorf("Incorrect architecture, expected 'WhisperForConditionalGeneration', got '%s'", config.GetArchitecture()) | ||
| } | ||
|
|
||
| whisperConfig, ok := config.(*WhisperConfig) | ||
| if !ok { | ||
| t.Fatalf("Failed to convert to WhisperConfig") | ||
| } | ||
|
|
||
| if whisperConfig.DModel != 1280 { | ||
| t.Errorf("Incorrect d_model, expected 1280, got %d", whisperConfig.DModel) | ||
| } | ||
|
|
||
| if whisperConfig.EncoderLayers != 32 { | ||
| t.Errorf("Incorrect encoder_layers, expected 32, got %d", whisperConfig.EncoderLayers) | ||
| } | ||
|
|
||
| if whisperConfig.DecoderLayers != 4 { | ||
| t.Errorf("Incorrect decoder_layers, expected 4, got %d", whisperConfig.DecoderLayers) | ||
| } | ||
|
|
||
| if whisperConfig.NumMelBins != 128 { | ||
| t.Errorf("Incorrect num_mel_bins, expected 128, got %d", whisperConfig.NumMelBins) | ||
| } | ||
|
|
||
| if whisperConfig.MaxSourcePositions != 1500 { | ||
| t.Errorf("Incorrect max_source_positions, expected 1500, got %d", whisperConfig.MaxSourcePositions) | ||
| } | ||
|
|
||
| if whisperConfig.MaxTargetPositions != 448 { | ||
| t.Errorf("Incorrect max_target_positions, expected 448, got %d", whisperConfig.MaxTargetPositions) | ||
| } | ||
|
|
||
| if whisperConfig.VocabSize != 51866 { | ||
| t.Errorf("Incorrect vocab_size, expected 51866, got %d", whisperConfig.VocabSize) | ||
| } | ||
|
|
||
| if !whisperConfig.IsEncoderDecoder { | ||
| t.Errorf("Expected is_encoder_decoder to be true") | ||
| } | ||
|
|
||
| // Context length should be the decoder token budget. | ||
| if config.GetContextLength() != 448 { | ||
| t.Errorf("Incorrect context length, expected 448, got %d", config.GetContextLength()) | ||
| } | ||
|
|
||
| if config.GetTorchDtype() != "float16" { | ||
| t.Errorf("Incorrect torch_dtype, expected 'float16', got '%s'", config.GetTorchDtype()) | ||
| } | ||
|
|
||
| // whisper-large-v3-turbo has ~809M parameters. | ||
| paramCount := config.GetParameterCount() | ||
| expectedCount := int64(809_000_000) | ||
| if paramCount != expectedCount { | ||
| t.Errorf("Incorrect parameter count, expected %s, got %s", | ||
| FormatParamCount(expectedCount), FormatParamCount(paramCount)) | ||
| } | ||
|
|
||
| // float16 → 2 bytes per parameter. | ||
| modelSize := config.GetModelSizeBytes() | ||
| expectedSize := int64(809_000_000 * 2) | ||
| if modelSize != expectedSize { | ||
| t.Errorf("Incorrect model size, expected %s, got %s", | ||
| FormatSize(expectedSize), FormatSize(modelSize)) | ||
| } | ||
|
|
||
| if config.HasVision() { | ||
| t.Errorf("Whisper should not report HasVision() == true") | ||
| } | ||
| } | ||
|
|
||
| func TestLoadModelWithWhisper(t *testing.T) { | ||
| configPath := filepath.Join("testdata", "whisper_large_v3_turbo.json") | ||
|
|
||
| model, err := LoadModelConfig(configPath) | ||
| if err != nil { | ||
| t.Fatalf("Failed to load Whisper model through generic loader: %v", err) | ||
| } | ||
|
|
||
| if model.GetModelType() != "whisper" { | ||
| t.Errorf("Expected model type 'whisper', got '%s'", model.GetModelType()) | ||
| } | ||
|
|
||
| if model.GetContextLength() != 448 { | ||
| t.Errorf("Expected context length 448, got %d", model.GetContextLength()) | ||
| } | ||
|
|
||
| paramCount := model.GetParameterCount() | ||
| expectedCount := int64(809_000_000) | ||
| if paramCount != expectedCount { | ||
| t.Errorf("Expected parameter count %s, got %s", | ||
| FormatParamCount(expectedCount), FormatParamCount(paramCount)) | ||
| } | ||
|
|
||
| t.Logf("Whisper model parameter count via generic loader: %s", FormatParamCount(paramCount)) | ||
| } |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's use a fallback method rather than return 0 here. Detail is in other comment.