From d84a1c22cc11d9c3b6d33400aaca5d504478daeb Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 12:15:33 +0200 Subject: [PATCH 1/4] Add model key --- internal/modelkey/modelkey.go | 46 ++++++++++++ internal/modelkey/modelkey_test.go | 114 +++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 internal/modelkey/modelkey.go create mode 100644 internal/modelkey/modelkey_test.go diff --git a/internal/modelkey/modelkey.go b/internal/modelkey/modelkey.go new file mode 100644 index 00000000..9cec0eac --- /dev/null +++ b/internal/modelkey/modelkey.go @@ -0,0 +1,46 @@ +package modelkey + +import ( + "fmt" + "strings" +) + +type ModelKey struct { + Provider string + Publisher string + ModelName string +} + +func ParseModelKey(modelKey string) (*ModelKey, error) { + if modelKey == "" { + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } + + parts := strings.Split(modelKey, "/") + + // Check for empty parts + for _, part := range parts { + if part == "" { + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } + } + + switch len(parts) { + case 2: + // Format: publisher/model-name (provider defaults to "azureml") + return &ModelKey{ + Provider: "azureml", + Publisher: parts[0], + ModelName: parts[1], + }, nil + case 3: + // Format: provider/publisher/model-name + return &ModelKey{ + Provider: parts[0], + Publisher: parts[1], + ModelName: parts[2], + }, nil + default: + return nil, fmt.Errorf("invalid model key format: %s", modelKey) + } +} diff --git a/internal/modelkey/modelkey_test.go b/internal/modelkey/modelkey_test.go new file mode 100644 index 00000000..561447c7 --- /dev/null +++ b/internal/modelkey/modelkey_test.go @@ -0,0 +1,114 @@ +package modelkey + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseModelKey(t *testing.T) { + tests := []struct { + name string + input string + expected *ModelKey + expectError bool + }{ + { + name: "valid format with provider", + input: "custom/openai/gpt-4", + expected: &ModelKey{ + Provider: "custom", + Publisher: "openai", + ModelName: "gpt-4", + }, + expectError: false, + }, + { + name: "valid format without provider (defaults to azureml)", + input: "openai/gpt-4", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "openai", + ModelName: "gpt-4", + }, + expectError: false, + }, + { + name: "valid format with azureml provider explicitly", + input: "azureml/microsoft/phi-3", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "microsoft", + ModelName: "phi-3", + }, + expectError: false, + }, + { + name: "valid format with hyphens in model name", + input: "cohere/command-r-plus", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "cohere", + ModelName: "command-r-plus", + }, + expectError: false, + }, + { + name: "valid format with underscores in model name", + input: "ai21/jamba_instruct", + expected: &ModelKey{ + Provider: "azureml", + Publisher: "ai21", + ModelName: "jamba_instruct", + }, + expectError: false, + }, + { + name: "invalid format with only one part", + input: "gpt-4", + expected: nil, + expectError: true, + }, + { + name: "invalid format with four parts", + input: "provider/publisher/model/extra", + expected: nil, + expectError: true, + }, + { + name: "invalid format with empty string", + input: "", + expected: nil, + expectError: true, + }, + { + name: "invalid format with only slashes", + input: "//", + expected: nil, + expectError: true, + }, + { + name: "invalid format with empty parts", + input: "provider//model", + expected: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseModelKey(tt.input) + + if tt.expectError { + require.Error(t, err) + require.Nil(t, result) + } else { + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, tt.expected.Provider, result.Provider) + require.Equal(t, tt.expected.Publisher, result.Publisher) + require.Equal(t, tt.expected.ModelName, result.ModelName) + } + }) + } +} From 9a0e37bbf94ccf20770ef1726ce102cabbc26386 Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 12:18:40 +0200 Subject: [PATCH 2/4] Convert model key to string --- internal/modelkey/modelkey.go | 5 +++ internal/modelkey/modelkey_test.go | 61 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/internal/modelkey/modelkey.go b/internal/modelkey/modelkey.go index 9cec0eac..e58990a7 100644 --- a/internal/modelkey/modelkey.go +++ b/internal/modelkey/modelkey.go @@ -44,3 +44,8 @@ func ParseModelKey(modelKey string) (*ModelKey, error) { return nil, fmt.Errorf("invalid model key format: %s", modelKey) } } + +// String returns the string representation of the ModelKey in the format provider/publisher/model-name +func (mk *ModelKey) String() string { + return fmt.Sprintf("%s/%s/%s", mk.Provider, mk.Publisher, mk.ModelName) +} diff --git a/internal/modelkey/modelkey_test.go b/internal/modelkey/modelkey_test.go index 561447c7..ea4583fa 100644 --- a/internal/modelkey/modelkey_test.go +++ b/internal/modelkey/modelkey_test.go @@ -112,3 +112,64 @@ func TestParseModelKey(t *testing.T) { }) } } + +func TestModelKey_String(t *testing.T) { + tests := []struct { + name string + modelKey *ModelKey + expected string + }{ + { + name: "standard format with azureml provider", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "openai", + ModelName: "gpt-4", + }, + expected: "azureml/openai/gpt-4", + }, + { + name: "custom provider", + modelKey: &ModelKey{ + Provider: "custom", + Publisher: "microsoft", + ModelName: "phi-3", + }, + expected: "custom/microsoft/phi-3", + }, + { + name: "model name with hyphens", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "cohere", + ModelName: "command-r-plus", + }, + expected: "azureml/cohere/command-r-plus", + }, + { + name: "model name with underscores", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "ai21", + ModelName: "jamba_instruct", + }, + expected: "azureml/ai21/jamba_instruct", + }, + { + name: "long provider name", + modelKey: &ModelKey{ + Provider: "custom-provider", + Publisher: "test-publisher", + ModelName: "test-model", + }, + expected: "custom-provider/test-publisher/test-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.modelKey.String() + require.Equal(t, tt.expected, result) + }) + } +} From d9164f069b51283349db20299128cac32f0c2198 Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 12:43:16 +0200 Subject: [PATCH 3/4] Do not validate models for the custom provider --- cmd/run/run.go | 17 +++++++++++++-- cmd/run/run_test.go | 53 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/cmd/run/run.go b/cmd/run/run.go index e380de5b..1fe574b2 100644 --- a/cmd/run/run.go +++ b/cmd/run/run.go @@ -16,6 +16,7 @@ import ( "github.com/MakeNowJust/heredoc" "github.com/briandowns/spinner" "github.com/github/gh-models/internal/azuremodels" + "github.com/github/gh-models/internal/modelkey" "github.com/github/gh-models/internal/sse" "github.com/github/gh-models/pkg/command" "github.com/github/gh-models/pkg/prompt" @@ -513,9 +514,21 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st return "", errors.New(noMatchErrorMessage) } + parsedModel, err := modelkey.ParseModelKey(modelName) + if err != nil { + return "", fmt.Errorf("invalid model format: %w", err) + } + + if parsedModel.Provider == "custom" { + // Skip validation for custom provider + return parsedModel.String(), nil + } + + // For non-custom providers, validate the model exists + expectedModelID := azuremodels.FormatIdentifier(parsedModel.Publisher, parsedModel.ModelName) foundMatch := false for _, model := range models { - if model.HasName(modelName) { + if model.HasName(expectedModelID) { foundMatch = true break } @@ -525,7 +538,7 @@ func validateModelName(modelName string, models []*azuremodels.ModelSummary) (st return "", errors.New(noMatchErrorMessage) } - return modelName, nil + return expectedModelID, nil } func (h *runCommandHandler) getChatCompletionStreamReader(req azuremodels.ChatCompletionOptions, org string) (sse.Reader[azuremodels.ChatCompletion], error) { diff --git a/cmd/run/run_test.go b/cmd/run/run_test.go index 43ef6a1c..eb10649c 100644 --- a/cmd/run/run_test.go +++ b/cmd/run/run_test.go @@ -403,3 +403,56 @@ func TestParseTemplateVariables(t *testing.T) { }) } } + +func TestValidateModelName(t *testing.T) { + tests := []struct { + name string + modelName string + expectedModel string + expectError bool + }{ + { + name: "custom provider skips validation", + modelName: "custom/mycompany/custom-model", + expectedModel: "custom/mycompany/custom-model", + expectError: false, + }, + { + name: "azureml provider requires validation", + modelName: "openai/gpt-4", + expectedModel: "openai/gpt-4", + expectError: false, + }, + { + name: "invalid model format", + modelName: "invalid-format", + expectError: true, + }, + { + name: "nonexistent azureml model", + modelName: "nonexistent/model", + expectError: true, + }, + } + + // Create a mock model for testing + mockModel := &azuremodels.ModelSummary{ + Name: "gpt-4", + Publisher: "openai", + Task: "chat-completion", + } + models := []*azuremodels.ModelSummary{mockModel} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := validateModelName(tt.modelName, models) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expectedModel, result) + } + }) + } +} From bd103368cece54e331e91a08b3d9accb0b3d5a8a Mon Sep 17 00:00:00 2001 From: Christopher Schleiden Date: Tue, 17 Jun 2025 14:38:47 +0200 Subject: [PATCH 4/4] Refactor model key formatting to use centralized function and update tests for azureml provider behavior --- internal/azuremodels/model_details.go | 12 ++------ internal/modelkey/modelkey.go | 29 ++++++++++++++++-- internal/modelkey/modelkey_test.go | 43 ++++++++++++++++++++++----- 3 files changed, 65 insertions(+), 19 deletions(-) diff --git a/internal/azuremodels/model_details.go b/internal/azuremodels/model_details.go index ecd135ac..53289cf0 100644 --- a/internal/azuremodels/model_details.go +++ b/internal/azuremodels/model_details.go @@ -2,7 +2,8 @@ package azuremodels import ( "fmt" - "strings" + + "github.com/github/gh-models/internal/modelkey" ) // ModelDetails includes detailed information about a model. @@ -28,12 +29,5 @@ func (m *ModelDetails) ContextLimits() string { // FormatIdentifier formats the model identifier based on the publisher and model name. func FormatIdentifier(publisher, name string) string { - formatPart := func(s string) string { - // Replace spaces with dashes and convert to lowercase - result := strings.ToLower(s) - result = strings.ReplaceAll(result, " ", "-") - return result - } - - return fmt.Sprintf("%s/%s", formatPart(publisher), formatPart(name)) + return modelkey.FormatIdentifier("azureml", publisher, name) } diff --git a/internal/modelkey/modelkey.go b/internal/modelkey/modelkey.go index e58990a7..bd18562d 100644 --- a/internal/modelkey/modelkey.go +++ b/internal/modelkey/modelkey.go @@ -45,7 +45,32 @@ func ParseModelKey(modelKey string) (*ModelKey, error) { } } -// String returns the string representation of the ModelKey in the format provider/publisher/model-name +// String returns the string representation of the ModelKey. func (mk *ModelKey) String() string { - return fmt.Sprintf("%s/%s/%s", mk.Provider, mk.Publisher, mk.ModelName) + provider := formatPart(mk.Provider) + publisher := formatPart(mk.Publisher) + modelName := formatPart(mk.ModelName) + + if provider == "azureml" { + return fmt.Sprintf("%s/%s", publisher, modelName) + } + + return fmt.Sprintf("%s/%s/%s", provider, publisher, modelName) +} + +func formatPart(s string) string { + s = strings.ToLower(s) + s = strings.ReplaceAll(s, " ", "-") + + return s +} + +func FormatIdentifier(provider, publisher, name string) string { + mk := &ModelKey{ + Provider: provider, + Publisher: publisher, + ModelName: name, + } + + return mk.String() } diff --git a/internal/modelkey/modelkey_test.go b/internal/modelkey/modelkey_test.go index ea4583fa..f4d13410 100644 --- a/internal/modelkey/modelkey_test.go +++ b/internal/modelkey/modelkey_test.go @@ -120,16 +120,16 @@ func TestModelKey_String(t *testing.T) { expected string }{ { - name: "standard format with azureml provider", + name: "standard format with azureml provider - should omit provider", modelKey: &ModelKey{ Provider: "azureml", Publisher: "openai", ModelName: "gpt-4", }, - expected: "azureml/openai/gpt-4", + expected: "openai/gpt-4", }, { - name: "custom provider", + name: "custom provider - should include provider", modelKey: &ModelKey{ Provider: "custom", Publisher: "microsoft", @@ -138,25 +138,25 @@ func TestModelKey_String(t *testing.T) { expected: "custom/microsoft/phi-3", }, { - name: "model name with hyphens", + name: "azureml provider with hyphens - should omit provider", modelKey: &ModelKey{ Provider: "azureml", Publisher: "cohere", ModelName: "command-r-plus", }, - expected: "azureml/cohere/command-r-plus", + expected: "cohere/command-r-plus", }, { - name: "model name with underscores", + name: "azureml provider with underscores - should omit provider", modelKey: &ModelKey{ Provider: "azureml", Publisher: "ai21", ModelName: "jamba_instruct", }, - expected: "azureml/ai21/jamba_instruct", + expected: "ai21/jamba_instruct", }, { - name: "long provider name", + name: "non-azureml provider - should include provider", modelKey: &ModelKey{ Provider: "custom-provider", Publisher: "test-publisher", @@ -164,6 +164,33 @@ func TestModelKey_String(t *testing.T) { }, expected: "custom-provider/test-publisher/test-model", }, + { + name: "azureml provider with uppercase and spaces - should format and omit provider", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "Open AI", + ModelName: "GPT 4", + }, + expected: "open-ai/gpt-4", + }, + { + name: "non-azureml provider with uppercase and spaces - should format and include provider", + modelKey: &ModelKey{ + Provider: "Custom Provider", + Publisher: "Test Publisher", + ModelName: "Test Model Name", + }, + expected: "custom-provider/test-publisher/test-model-name", + }, + { + name: "mixed case with multiple spaces", + modelKey: &ModelKey{ + Provider: "azureml", + Publisher: "Microsoft Corporation", + ModelName: "Phi 3 Mini Instruct", + }, + expected: "microsoft-corporation/phi-3-mini-instruct", + }, } for _, tt := range tests {