diff --git a/cmd/gh-aw/format_list_test.go b/cmd/gh-aw/format_list_test.go new file mode 100644 index 00000000000..518ef897c99 --- /dev/null +++ b/cmd/gh-aw/format_list_test.go @@ -0,0 +1,55 @@ +//go:build !integration + +package main + +import ( + "testing" +) + +func TestFormatListWithOr(t *testing.T) { + tests := []struct { + name string + items []string + expected string + }{ + { + name: "empty list", + items: []string{}, + expected: "", + }, + { + name: "single item", + items: []string{"apple"}, + expected: "apple", + }, + { + name: "two items", + items: []string{"apple", "banana"}, + expected: "apple or banana", + }, + { + name: "three items", + items: []string{"apple", "banana", "cherry"}, + expected: "apple, banana, or cherry", + }, + { + name: "four items", + items: []string{"apple", "banana", "cherry", "date"}, + expected: "apple, banana, cherry, or date", + }, + { + name: "engine names with quotes", + items: []string{"'claude'", "'codex'", "'copilot'", "'custom'"}, + expected: "'claude', 'codex', 'copilot', or 'custom'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatListWithOr(tt.items) + if result != tt.expected { + t.Errorf("formatListWithOr(%v) = %q, want %q", tt.items, result, tt.expected) + } + }) + } +} diff --git a/cmd/gh-aw/main.go b/cmd/gh-aw/main.go index af2a16cbfcc..4ed02b5e3e9 100644 --- a/cmd/gh-aw/main.go +++ b/cmd/gh-aw/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "os" + "sort" "strings" "github.com/github/gh-aw/pkg/cli" @@ -23,6 +24,22 @@ var ( var verboseFlag bool var bannerFlag bool +// formatListWithOr formats a list of strings with commas and "or" before the last item +// Example: ["a", "b", "c"] -> "a, b, or c" +func formatListWithOr(items []string) string { + if len(items) == 0 { + return "" + } + if len(items) == 1 { + return items[0] + } + if len(items) == 2 { + return items[0] + " or " + items[1] + } + // For 3+ items: "a, b, or c" + return strings.Join(items[:len(items)-1], ", ") + ", or " + items[len(items)-1] +} + // validateEngine validates the engine flag value func validateEngine(engine string) error { // Get the global engine registry @@ -30,15 +47,26 @@ func validateEngine(engine string) error { validEngines := registry.GetSupportedEngines() if engine != "" && !registry.IsValidEngine(engine) { + // Sort engines for deterministic output + sortedEngines := make([]string, len(validEngines)) + copy(sortedEngines, validEngines) + sort.Strings(sortedEngines) + + // Format engines with quotes and "or" conjunction + quotedEngines := make([]string, len(sortedEngines)) + for i, e := range sortedEngines { + quotedEngines[i] = "'" + e + "'" + } + formattedList := formatListWithOr(quotedEngines) + // Try to find close matches for "did you mean" suggestion suggestions := parser.FindClosestMatches(engine, validEngines, 1) - errMsg := fmt.Sprintf("invalid engine value '%s'. Must be '%s'", - engine, strings.Join(validEngines, "', '")) + errMsg := fmt.Sprintf("invalid engine value '%s'. Must be %s", engine, formattedList) if len(suggestions) > 0 { - errMsg = fmt.Sprintf("invalid engine value '%s'. Must be '%s'.\n\nDid you mean: %s?", - engine, strings.Join(validEngines, "', '"), suggestions[0]) + errMsg = fmt.Sprintf("invalid engine value '%s'. Must be %s.\n\nDid you mean: %s?", + engine, formattedList, suggestions[0]) } return fmt.Errorf("%s", errMsg) diff --git a/cmd/gh-aw/main_entry_test.go b/cmd/gh-aw/main_entry_test.go index ee06cc8cdb5..7d99c5d649c 100644 --- a/cmd/gh-aw/main_entry_test.go +++ b/cmd/gh-aw/main_entry_test.go @@ -87,8 +87,11 @@ func TestValidateEngine(t *testing.T) { return } - if tt.errMessage != "" && err.Error() != fmt.Sprintf("invalid engine value '%s'. Must be 'claude', 'codex', 'copilot', or 'custom'", tt.engine) { - t.Errorf("validateEngine(%q) error message = %v, want to contain %v", tt.engine, err.Error(), tt.errMessage) + // Check that error message contains the expected format + // Error may include "Did you mean" suggestions, so we check if it starts with the base message + expectedMsg := fmt.Sprintf("invalid engine value '%s'. Must be 'claude', 'codex', 'copilot', or 'custom'", tt.engine) + if tt.errMessage != "" && !strings.HasPrefix(err.Error(), expectedMsg) { + t.Errorf("validateEngine(%q) error message = %v, want to start with %v", tt.engine, err.Error(), expectedMsg) } } else { if err != nil {