diff --git a/pkg/cli/deps_report.go b/pkg/cli/deps_report.go index d3644c607b..f3cacbe46a 100644 --- a/pkg/cli/deps_report.go +++ b/pkg/cli/deps_report.go @@ -35,10 +35,12 @@ func GenerateDependencyReport(verbose bool) (*DependencyReport, error) { } // Parse go.mod to get all dependencies - allDeps, err := parseGoModWithIndirect(goModPath) + depsReportLog.Printf("Parsing go.mod file: %s", goModPath) + allDeps, err := parseGoModFile(goModPath) if err != nil { return nil, fmt.Errorf("failed to parse go.mod: %w", err) } + depsReportLog.Printf("Parsed go.mod: %d total dependencies", len(allDeps)) // Count direct vs indirect dependencies directCount := 0 @@ -270,18 +272,6 @@ type DependencyInfoWithIndirect struct { Indirect bool } -// parseGoModWithIndirect parses go.mod including indirect dependencies. -// This is a thin wrapper around parseGoModFile for backward compatibility. -func parseGoModWithIndirect(path string) ([]DependencyInfoWithIndirect, error) { - depsReportLog.Printf("Parsing go.mod file: %s", path) - deps, err := parseGoModFile(path) - if err != nil { - return nil, err - } - depsReportLog.Printf("Parsed go.mod: %d total dependencies", len(deps)) - return deps, nil -} - // pluralize returns the singular or plural form of a word based on count func pluralize(word string, count int) string { if count == 1 { diff --git a/pkg/cli/deps_test.go b/pkg/cli/deps_test.go index 580c461660..388b338c36 100644 --- a/pkg/cli/deps_test.go +++ b/pkg/cli/deps_test.go @@ -133,7 +133,7 @@ require ( } } -func TestParseGoModWithIndirect(t *testing.T) { +func TestParseGoModFile_WithIndirect(t *testing.T) { goModContent := `module github.com/example/test go 1.25.0 @@ -148,14 +148,14 @@ require ( tmpFile := createTempFile(t, goModContent) defer removeTempFile(t, tmpFile) - deps, err := parseGoModWithIndirect(tmpFile) + deps, err := parseGoModFile(tmpFile) if err != nil { - t.Fatalf("parseGoModWithIndirect() error = %v", err) + t.Fatalf("parseGoModFile() error = %v", err) } // Should parse all dependencies including indirect if len(deps) != 3 { - t.Errorf("parseGoModWithIndirect() found %d dependencies, want 3", len(deps)) + t.Errorf("parseGoModFile() found %d dependencies, want 3", len(deps)) } // Check indirect flag @@ -166,7 +166,7 @@ require ( } } if indirectCount != 1 { - t.Errorf("parseGoModWithIndirect() found %d indirect dependencies, want 1", indirectCount) + t.Errorf("parseGoModFile() found %d indirect dependencies, want 1", indirectCount) } } diff --git a/pkg/cli/mcp_server_helpers.go b/pkg/cli/mcp_server_helpers.go index dcf1f39c76..2addb68969 100644 --- a/pkg/cli/mcp_server_helpers.go +++ b/pkg/cli/mcp_server_helpers.go @@ -126,9 +126,13 @@ func hasWriteAccess(permission string) bool { } } -// validateWorkflowName validates that a workflow name exists. +// validateWorkflowName validates that a workflow name exists in the repository. // Returns nil if the workflow exists, or an error with suggestions if not. // Empty workflow names are considered valid (means all workflows). +// +// Note: This function checks whether a workflow exists/is accessible, not its format. +// For format-only validation (alphanumeric characters, hyphens, underscores), +// use validators.go:ValidateWorkflowName instead. func validateWorkflowName(workflowName string) error { // Empty workflow name means "all workflows" - this is valid if workflowName == "" { diff --git a/pkg/cli/pr_command.go b/pkg/cli/pr_command.go index 0d1695a67f..df2e33fb83 100644 --- a/pkg/cli/pr_command.go +++ b/pkg/cli/pr_command.go @@ -103,11 +103,6 @@ The command will: return cmd } -// parsePRURL extracts owner, repo, and PR number from a GitHub PR URL -func parsePRURL(prURL string) (owner, repo string, prNumber int, err error) { - return parser.ParsePRURL(prURL) -} - // checkRepositoryAccess checks if the current user has write access to the target repository func checkRepositoryAccess(owner, repo string) (bool, error) { prLog.Printf("Checking repository access: %s/%s", owner, repo) @@ -548,7 +543,7 @@ func transferPR(prURL, targetRepo string, verbose bool) error { } // Parse PR URL - sourceOwner, sourceRepoName, prNumber, err := parsePRURL(prURL) + sourceOwner, sourceRepoName, prNumber, err := parser.ParsePRURL(prURL) if err != nil { prLog.Printf("Failed to parse PR URL: %s", err) return err diff --git a/pkg/cli/pr_command_test.go b/pkg/cli/pr_command_test.go index 203772ddcf..a2b150753b 100644 --- a/pkg/cli/pr_command_test.go +++ b/pkg/cli/pr_command_test.go @@ -4,6 +4,8 @@ package cli import ( "testing" + + "github.com/github/gh-aw/pkg/parser" ) func TestParsePRURL(t *testing.T) { @@ -81,30 +83,30 @@ func TestParsePRURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - owner, repo, prNumber, err := parsePRURL(tt.url) + owner, repo, prNumber, err := parser.ParsePRURL(tt.url) if tt.wantErr { if err == nil { - t.Errorf("parsePRURL() expected error but got none") + t.Errorf("ParsePRURL() expected error but got none") } return } if err != nil { - t.Errorf("parsePRURL() unexpected error: %v", err) + t.Errorf("ParsePRURL() unexpected error: %v", err) return } if owner != tt.wantOwner { - t.Errorf("parsePRURL() owner = %v, want %v", owner, tt.wantOwner) + t.Errorf("ParsePRURL() owner = %v, want %v", owner, tt.wantOwner) } if repo != tt.wantRepo { - t.Errorf("parsePRURL() repo = %v, want %v", repo, tt.wantRepo) + t.Errorf("ParsePRURL() repo = %v, want %v", repo, tt.wantRepo) } if prNumber != tt.wantPR { - t.Errorf("parsePRURL() prNumber = %v, want %v", prNumber, tt.wantPR) + t.Errorf("ParsePRURL() prNumber = %v, want %v", prNumber, tt.wantPR) } }) } diff --git a/pkg/workflow/cache.go b/pkg/workflow/cache.go index 0950d7445d..3f6f8dd9b4 100644 --- a/pkg/workflow/cache.go +++ b/pkg/workflow/cache.go @@ -864,23 +864,3 @@ func (c *Compiler) buildUpdateCacheMemoryJob(data *WorkflowData, threatDetection return job, nil } - -// validateNoDuplicateCacheIDs checks for duplicate cache IDs and returns an error if found -func validateNoDuplicateCacheIDs(caches []CacheMemoryEntry) error { - cacheLog.Printf("Validating cache IDs: checking %d caches for duplicates", len(caches)) - seen := make(map[string]bool) - for _, cache := range caches { - if seen[cache.ID] { - cacheLog.Printf("Duplicate cache ID found: %s", cache.ID) - return NewValidationError( - "sandbox.cache-memory", - cache.ID, - "duplicate cache-memory ID found - each cache must have a unique ID", - "Change the cache ID to a unique value. Example:\n\nsandbox:\n cache-memory:\n - id: cache-1\n size: 100MB\n - id: cache-2 # Use unique IDs\n size: 50MB", - ) - } - seen[cache.ID] = true - } - cacheLog.Print("Cache ID validation passed: no duplicates found") - return nil -} diff --git a/pkg/workflow/cache_validation.go b/pkg/workflow/cache_validation.go new file mode 100644 index 0000000000..82a5efdd5d --- /dev/null +++ b/pkg/workflow/cache_validation.go @@ -0,0 +1,39 @@ +// This file provides validation for sandbox cache-memory configuration. +// +// # Cache Memory Validation +// +// This file validates that cache-memory entries in a workflow's sandbox +// configuration have unique IDs, preventing runtime conflicts when multiple +// cache entries are defined. +// +// # Validation Functions +// +// - validateNoDuplicateCacheIDs() - Ensures each cache entry has a unique ID +// +// # When to Add Validation Here +// +// Add validation to this file when: +// - Adding new cache-memory configuration constraints +// - Adding cross-cache validation rules (e.g., total size limits) + +package workflow + +// validateNoDuplicateCacheIDs checks for duplicate cache IDs and returns an error if found. +// Uses the generic validateNoDuplicateIDs helper for consistent duplicate detection. +func validateNoDuplicateCacheIDs(caches []CacheMemoryEntry) error { + cacheLog.Printf("Validating cache IDs: checking %d caches for duplicates", len(caches)) + err := validateNoDuplicateIDs(caches, func(c CacheMemoryEntry) string { return c.ID }, func(id string) error { + cacheLog.Printf("Duplicate cache ID found: %s", id) + return NewValidationError( + "sandbox.cache-memory", + id, + "duplicate cache-memory ID found - each cache must have a unique ID", + "Change the cache ID to a unique value. Example:\n\nsandbox:\n cache-memory:\n - id: cache-1\n size: 100MB\n - id: cache-2 # Use unique IDs\n size: 50MB", + ) + }) + if err != nil { + return err + } + cacheLog.Print("Cache ID validation passed: no duplicates found") + return nil +} diff --git a/pkg/workflow/concurrency_validation.go b/pkg/workflow/concurrency_validation.go index f5a68fcca2..f8cfe0b144 100644 --- a/pkg/workflow/concurrency_validation.go +++ b/pkg/workflow/concurrency_validation.go @@ -28,7 +28,6 @@ package workflow import ( - "fmt" "regexp" "strings" ) @@ -36,8 +35,7 @@ import ( var concurrencyValidationLog = newValidationLogger("concurrency") var ( - concurrencyExpressionPattern = regexp.MustCompile(`\$\{\{([^}]*)\}\}`) - concurrencyGroupPattern = regexp.MustCompile(`(?m)^\s*group:\s*["']?([^"'\n]+?)["']?\s*$`) + concurrencyGroupPattern = regexp.MustCompile(`(?m)^\s*group:\s*["']?([^"'\n]+?)["']?\s*$`) ) // validateConcurrencyGroupExpression validates the syntax of a custom concurrency group expression. @@ -74,216 +72,6 @@ func validateConcurrencyGroupExpression(group string) error { return nil } -// validateBalancedBraces checks that all ${{ }} braces are balanced and properly closed -func validateBalancedBraces(group string) error { - concurrencyValidationLog.Print("Checking balanced braces in expression") - openCount := 0 - i := 0 - positions := []int{} // Track positions of opening braces for error reporting - - for i < len(group) { - // Check for opening ${{ - if i+2 < len(group) && group[i:i+3] == "${{" { - openCount++ - positions = append(positions, i) - i += 3 - continue - } - - // Check for closing }} - if i+1 < len(group) && group[i:i+2] == "}}" { - if openCount == 0 { - return NewValidationError( - "concurrency", - "unbalanced closing braces", - fmt.Sprintf("found '}}' at position %d without matching opening '${{' in expression: %s", i, group), - "Ensure all '}}' have a corresponding opening '${{'. Check for typos or missing opening braces.", - ) - } - openCount-- - if len(positions) > 0 { - positions = positions[:len(positions)-1] - } - i += 2 - continue - } - - i++ - } - - if openCount > 0 { - // Find the position of the first unclosed opening brace - pos := positions[0] - concurrencyValidationLog.Printf("Found %d unclosed brace(s) starting at position %d", openCount, pos) - return NewValidationError( - "concurrency", - "unclosed expression braces", - fmt.Sprintf("found opening '${{' at position %d without matching closing '}}' in expression: %s", pos, group), - "Ensure all '${{' have a corresponding closing '}}'. Add the missing closing braces.", - ) - } - - concurrencyValidationLog.Print("Brace balance check passed") - return nil -} - -// validateExpressionSyntax validates the syntax of expressions within ${{ }} -func validateExpressionSyntax(group string) error { - // Pattern to extract content between ${{ and }} - matches := concurrencyExpressionPattern.FindAllStringSubmatch(group, -1) - - concurrencyValidationLog.Printf("Found %d expression(s) to validate", len(matches)) - - for _, match := range matches { - if len(match) < 2 { - continue - } - - exprContent := strings.TrimSpace(match[1]) - if exprContent == "" { - return NewValidationError( - "concurrency", - "empty expression content", - "found empty expression '${{ }}' in concurrency group: "+group, - "Provide a valid GitHub Actions expression inside '${{ }}'. Example: '${{ github.ref }}'", - ) - } - - // Check for common syntax errors - if err := validateExpressionContent(exprContent, group); err != nil { - return err - } - } - - return nil -} - -// validateExpressionContent validates the content inside ${{ }} -func validateExpressionContent(expr string, fullGroup string) error { - // Check for unbalanced parentheses - parenCount := 0 - for i, ch := range expr { - switch ch { - case '(': - parenCount++ - case ')': - parenCount-- - if parenCount < 0 { - return NewValidationError( - "concurrency", - "unbalanced parentheses in expression", - fmt.Sprintf("found closing ')' without matching opening '(' at position %d in expression: %s", i, expr), - "Ensure all parentheses are properly balanced in your concurrency group expression.", - ) - } - } - } - - if parenCount > 0 { - return NewValidationError( - "concurrency", - "unclosed parentheses in expression", - fmt.Sprintf("found %d unclosed opening '(' in expression: %s", parenCount, expr), - "Add the missing closing ')' to balance parentheses in your expression.", - ) - } - - // Check for unbalanced quotes (single, double, backtick) - if err := validateBalancedQuotes(expr); err != nil { - return err - } - - // Try to parse complex expressions with logical operators - if containsLogicalOperators(expr) { - concurrencyValidationLog.Print("Expression contains logical operators, performing deep validation") - if _, err := ParseExpression(expr); err != nil { - concurrencyValidationLog.Printf("Expression parsing failed: %v", err) - return NewValidationError( - "concurrency", - "invalid expression syntax", - "failed to parse expression in concurrency group: "+err.Error(), - "Fix the syntax error in your concurrency group expression. Full expression: "+fullGroup, - ) - } - } - - return nil -} - -// validateBalancedQuotes checks for balanced quotes in an expression -func validateBalancedQuotes(expr string) error { - inSingleQuote := false - inDoubleQuote := false - inBacktick := false - escaped := false - - for i, ch := range expr { - if escaped { - escaped = false - continue - } - - if ch == '\\' { - escaped = true - continue - } - - switch ch { - case '\'': - if !inDoubleQuote && !inBacktick { - inSingleQuote = !inSingleQuote - } - case '"': - if !inSingleQuote && !inBacktick { - inDoubleQuote = !inDoubleQuote - } - case '`': - if !inSingleQuote && !inDoubleQuote { - inBacktick = !inBacktick - } - } - - // Check if we reached end of string with unclosed quote - if i == len(expr)-1 { - if inSingleQuote { - return NewValidationError( - "concurrency", - "unclosed single quote", - "found unclosed single quote in expression: "+expr, - "Add the missing closing single quote (') to your expression.", - ) - } - if inDoubleQuote { - return NewValidationError( - "concurrency", - "unclosed double quote", - "found unclosed double quote in expression: "+expr, - "Add the missing closing double quote (\") to your expression.", - ) - } - if inBacktick { - return NewValidationError( - "concurrency", - "unclosed backtick", - "found unclosed backtick in expression: "+expr, - "Add the missing closing backtick (`) to your expression.", - ) - } - } - } - - return nil -} - -// containsLogicalOperators checks if an expression contains logical operators (&&, ||, !) -// Note: This is a simple string-based check that may return true for expressions containing -// '!=' (not equals) since it includes the '!' character. This is acceptable because the -// function is used to decide whether to parse the expression with the expression parser, -// and expressions with '!=' will be successfully parsed by the parser. -func containsLogicalOperators(expr string) bool { - return strings.Contains(expr, "&&") || strings.Contains(expr, "||") || strings.Contains(expr, "!") -} - // extractConcurrencyGroupFromYAML extracts the group value from a YAML-formatted concurrency string. // The input is expected to be in the format generated by the compiler: // diff --git a/pkg/workflow/expression_validation.go b/pkg/workflow/expression_validation.go index d8bdabd39b..9e1f586f9d 100644 --- a/pkg/workflow/expression_validation.go +++ b/pkg/workflow/expression_validation.go @@ -499,3 +499,217 @@ func validateRuntimeImportFiles(markdownContent string, workspaceDir string) err expressionValidationLog.Print("All runtime-import files validated successfully") return nil } + +// expressionBracesPattern matches GitHub Actions ${{ }} expressions for syntax validation. +// Uses [^}]* to match non-closing-brace characters within the expression. +var expressionBracesPattern = regexp.MustCompile(`\$\{\{([^}]*)\}\}`) + +// validateBalancedBraces checks that all ${{ }} braces are balanced and properly closed +func validateBalancedBraces(group string) error { + expressionValidationLog.Print("Checking balanced braces in expression") + openCount := 0 + i := 0 + positions := []int{} // Track positions of opening braces for error reporting + + for i < len(group) { + // Check for opening ${{ + if i+2 < len(group) && group[i:i+3] == "${{" { + openCount++ + positions = append(positions, i) + i += 3 + continue + } + + // Check for closing }} + if i+1 < len(group) && group[i:i+2] == "}}" { + if openCount == 0 { + return NewValidationError( + "expression", + "unbalanced closing braces", + fmt.Sprintf("found '}}' at position %d without matching opening '${{' in expression: %s", i, group), + "Ensure all '}}' have a corresponding opening '${{'. Check for typos or missing opening braces.", + ) + } + openCount-- + if len(positions) > 0 { + positions = positions[:len(positions)-1] + } + i += 2 + continue + } + + i++ + } + + if openCount > 0 { + // Find the position of the first unclosed opening brace + pos := positions[0] + expressionValidationLog.Printf("Found %d unclosed brace(s) starting at position %d", openCount, pos) + return NewValidationError( + "expression", + "unclosed expression braces", + fmt.Sprintf("found opening '${{' at position %d without matching closing '}}' in expression: %s", pos, group), + "Ensure all '${{' have a corresponding closing '}}'. Add the missing closing braces.", + ) + } + + expressionValidationLog.Print("Brace balance check passed") + return nil +} + +// validateExpressionSyntax validates the syntax of expressions within ${{ }} +func validateExpressionSyntax(group string) error { + // Pattern to extract content between ${{ and }} + matches := expressionBracesPattern.FindAllStringSubmatch(group, -1) + + expressionValidationLog.Printf("Found %d expression(s) to validate", len(matches)) + + for _, match := range matches { + if len(match) < 2 { + continue + } + + exprContent := strings.TrimSpace(match[1]) + if exprContent == "" { + return NewValidationError( + "expression", + "empty expression content", + "found empty expression '${{ }}' in: "+group, + "Provide a valid GitHub Actions expression inside '${{ }}'. Example: '${{ github.ref }}'", + ) + } + + // Check for common syntax errors + if err := validateExpressionContent(exprContent, group); err != nil { + return err + } + } + + return nil +} + +// validateExpressionContent validates the content inside ${{ }} +func validateExpressionContent(expr string, fullGroup string) error { + // Check for unbalanced parentheses + parenCount := 0 + for i, ch := range expr { + switch ch { + case '(': + parenCount++ + case ')': + parenCount-- + if parenCount < 0 { + return NewValidationError( + "expression", + "unbalanced parentheses in expression", + fmt.Sprintf("found closing ')' without matching opening '(' at position %d in expression: %s", i, expr), + "Ensure all parentheses are properly balanced in your expression.", + ) + } + } + } + + if parenCount > 0 { + return NewValidationError( + "expression", + "unclosed parentheses in expression", + fmt.Sprintf("found %d unclosed opening '(' in expression: %s", parenCount, expr), + "Add the missing closing ')' to balance parentheses in your expression.", + ) + } + + // Check for unbalanced quotes (single, double, backtick) + if err := validateBalancedQuotes(expr); err != nil { + return err + } + + // Try to parse complex expressions with logical operators + if containsLogicalOperators(expr) { + expressionValidationLog.Print("Expression contains logical operators, performing deep validation") + if _, err := ParseExpression(expr); err != nil { + expressionValidationLog.Printf("Expression parsing failed: %v", err) + return NewValidationError( + "expression", + "invalid expression syntax", + "failed to parse expression: "+err.Error(), + "Fix the syntax error in your expression. Full expression: "+fullGroup, + ) + } + } + + return nil +} + +// validateBalancedQuotes checks for balanced quotes in an expression +func validateBalancedQuotes(expr string) error { + inSingleQuote := false + inDoubleQuote := false + inBacktick := false + escaped := false + + for i, ch := range expr { + if escaped { + escaped = false + continue + } + + if ch == '\\' { + escaped = true + continue + } + + switch ch { + case '\'': + if !inDoubleQuote && !inBacktick { + inSingleQuote = !inSingleQuote + } + case '"': + if !inSingleQuote && !inBacktick { + inDoubleQuote = !inDoubleQuote + } + case '`': + if !inSingleQuote && !inDoubleQuote { + inBacktick = !inBacktick + } + } + + // Check if we reached end of string with unclosed quote + if i == len(expr)-1 { + if inSingleQuote { + return NewValidationError( + "expression", + "unclosed single quote", + "found unclosed single quote in expression: "+expr, + "Add the missing closing single quote (') to your expression.", + ) + } + if inDoubleQuote { + return NewValidationError( + "expression", + "unclosed double quote", + "found unclosed double quote in expression: "+expr, + "Add the missing closing double quote (\") to your expression.", + ) + } + if inBacktick { + return NewValidationError( + "expression", + "unclosed backtick", + "found unclosed backtick in expression: "+expr, + "Add the missing closing backtick (`) to your expression.", + ) + } + } + } + + return nil +} + +// containsLogicalOperators checks if an expression contains logical operators (&&, ||, !) +// Note: This is a simple string-based check that may return true for expressions containing +// '!=' (not equals) since it includes the '!' character. This is acceptable because the +// function is used to decide whether to parse the expression with the expression parser, +// and expressions with '!=' will be successfully parsed by the parser. +func containsLogicalOperators(expr string) bool { + return strings.Contains(expr, "&&") || strings.Contains(expr, "||") || strings.Contains(expr, "!") +} diff --git a/pkg/workflow/github_tool_to_toolset.go b/pkg/workflow/github_tool_to_toolset.go index 04a6fcfc87..0ae9371a7e 100644 --- a/pkg/workflow/github_tool_to_toolset.go +++ b/pkg/workflow/github_tool_to_toolset.go @@ -3,12 +3,8 @@ package workflow import ( _ "embed" "encoding/json" - "fmt" - "sort" - "strings" "github.com/github/gh-aw/pkg/logger" - "github.com/github/gh-aw/pkg/parser" ) var githubToolToToolsetLog = logger.New("workflow:github_tool_to_toolset") @@ -28,102 +24,5 @@ func init() { } } -// ValidateGitHubToolsAgainstToolsets validates that all allowed GitHub tools have their -// corresponding toolsets enabled in the configuration -func ValidateGitHubToolsAgainstToolsets(allowedTools []string, enabledToolsets []string) error { - githubToolToToolsetLog.Printf("Validating GitHub tools against toolsets: allowed_tools=%d, enabled_toolsets=%d", len(allowedTools), len(enabledToolsets)) - - if len(allowedTools) == 0 { - githubToolToToolsetLog.Print("No tools to validate, skipping") - // No specific tools restricted, validation not needed - return nil - } - - // Create a set of enabled toolsets for fast lookup - enabledSet := make(map[string]bool) - for _, toolset := range enabledToolsets { - enabledSet[toolset] = true - } - githubToolToToolsetLog.Printf("Enabled toolsets: %v", enabledToolsets) - - // Track missing toolsets and which tools need them - missingToolsets := make(map[string][]string) // toolset -> list of tools that need it - - // Track unknown tools for suggestions - var unknownTools []string - var suggestions []string - - for _, tool := range allowedTools { - // Skip wildcard - it means "allow all tools" - if tool == "*" { - continue - } - - requiredToolset, exists := GitHubToolToToolsetMap[tool] - if !exists { - githubToolToToolsetLog.Printf("Tool %s not found in mapping, checking for typo", tool) - - // Get all valid tool names for suggestion - validTools := make([]string, 0, len(GitHubToolToToolsetMap)) - for validTool := range GitHubToolToToolsetMap { - validTools = append(validTools, validTool) - } - sort.Strings(validTools) - - // Try to find close matches - matches := parser.FindClosestMatches(tool, validTools, 1) - if len(matches) > 0 { - githubToolToToolsetLog.Printf("Found suggestion for unknown tool %s: %s", tool, matches[0]) - unknownTools = append(unknownTools, tool) - suggestions = append(suggestions, fmt.Sprintf("%s → %s", tool, matches[0])) - } else { - githubToolToToolsetLog.Printf("No suggestion found for unknown tool: %s", tool) - unknownTools = append(unknownTools, tool) - } - // Tool not in our mapping - this could be a new tool or a typo - // We'll skip validation for unknown tools to avoid false positives - continue - } - - if !enabledSet[requiredToolset] { - githubToolToToolsetLog.Printf("Tool %s requires missing toolset: %s", tool, requiredToolset) - missingToolsets[requiredToolset] = append(missingToolsets[requiredToolset], tool) - } - } - - // Report unknown tools with suggestions if any were found - if len(unknownTools) > 0 { - githubToolToToolsetLog.Printf("Found %d unknown tools", len(unknownTools)) - var errMsg strings.Builder - errMsg.WriteString(fmt.Sprintf("Unknown GitHub tool(s): %s\n\n", formatList(unknownTools))) - - if len(suggestions) > 0 { - errMsg.WriteString("Did you mean:\n") - for _, s := range suggestions { - errMsg.WriteString(fmt.Sprintf(" %s\n", s)) - } - errMsg.WriteString("\n") - } - - // Show a few examples of valid tools - validTools := make([]string, 0, len(GitHubToolToToolsetMap)) - for tool := range GitHubToolToToolsetMap { - validTools = append(validTools, tool) - } - sort.Strings(validTools) - - exampleCount := min(10, len(validTools)) - errMsg.WriteString(fmt.Sprintf("Valid GitHub tools include: %s\n\n", formatList(validTools[:exampleCount]))) - errMsg.WriteString("See all tools: https://github.com/github/gh-aw/blob/main/pkg/workflow/data/github_tool_to_toolset.json") - - return fmt.Errorf("%s", errMsg.String()) - } - - if len(missingToolsets) > 0 { - githubToolToToolsetLog.Printf("Validation failed: missing %d toolsets", len(missingToolsets)) - return NewGitHubToolsetValidationError(missingToolsets) - } - - githubToolToToolsetLog.Print("Validation successful: all tools have required toolsets") - return nil -} +// GitHubToolToToolsetMap is the last declaration in this file; ValidateGitHubToolsAgainstToolsets +// has been moved to tools_validation.go. diff --git a/pkg/workflow/imports.go b/pkg/workflow/imports.go index e95245a792..e3c785898e 100644 --- a/pkg/workflow/imports.go +++ b/pkg/workflow/imports.go @@ -155,163 +155,6 @@ func (c *Compiler) MergeNetworkPermissions(topNetwork *NetworkPermissions, impor return result, nil } -// ValidateIncludedPermissions validates that the main workflow permissions satisfy the imported workflow requirements -// This function is specifically used when merging included/imported workflow files to ensure the main workflow -// has sufficient permissions to support the requirements from all imported files. -// Takes the top-level permissions YAML string and imported permissions JSON string -// Returns an error if the main workflow permissions are insufficient -// -// Use ValidatePermissions (in permissions_validator.go) for general permission validation against GitHub MCP toolsets. -// Use ValidateIncludedPermissions (this function) when validating permissions from included/imported workflow files. -func (c *Compiler) ValidateIncludedPermissions(topPermissionsYAML string, importedPermissionsJSON string) error { - importsLog.Print("Validating permissions from imports") - - // If no imported permissions, no validation needed - if importedPermissionsJSON == "" || importedPermissionsJSON == "{}" { - importsLog.Print("No imported permissions to validate") - return nil - } - - // Parse top-level permissions - var topPerms *Permissions - if topPermissionsYAML != "" { - topPerms = NewPermissionsParser(topPermissionsYAML).ToPermissions() - } else { - topPerms = NewPermissions() - } - - // Track missing permissions - missingPermissions := make(map[PermissionScope]PermissionLevel) - insufficientPermissions := make(map[PermissionScope]struct { - required PermissionLevel - current PermissionLevel - }) - - // Split by newlines to handle multiple JSON objects from different imports - lines := strings.Split(importedPermissionsJSON, "\n") - importsLog.Printf("Processing %d permission definition lines", len(lines)) - - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" || line == "{}" { - continue - } - - // Parse JSON line to permissions map - var importedPermsMap map[string]any - if err := json.Unmarshal([]byte(line), &importedPermsMap); err != nil { - importsLog.Printf("Skipping malformed permission entry: %q (error: %v)", line, err) - continue - } - - // Check each permission from the imported map - for scopeStr, levelValue := range importedPermsMap { - scope := PermissionScope(scopeStr) - - // Parse the level - it might be a string or already unmarshaled - var requiredLevel PermissionLevel - if levelStr, ok := levelValue.(string); ok { - requiredLevel = PermissionLevel(levelStr) - } else { - // Skip invalid level values - continue - } - - // Get current level for this scope - currentLevel, exists := topPerms.Get(scope) - - // Validate that the main workflow has sufficient permissions - if !exists || currentLevel == PermissionNone { - // Permission is missing entirely - missingPermissions[scope] = requiredLevel - importsLog.Printf("Missing permission: %s: %s", scope, requiredLevel) - } else if !isPermissionSufficient(currentLevel, requiredLevel) { - // Permission exists but is insufficient - insufficientPermissions[scope] = struct { - required PermissionLevel - current PermissionLevel - }{requiredLevel, currentLevel} - importsLog.Printf("Insufficient permission: %s: has %s, needs %s", scope, currentLevel, requiredLevel) - } - } - } - - // If there are missing or insufficient permissions, return an error - if len(missingPermissions) > 0 || len(insufficientPermissions) > 0 { - var errorMsg strings.Builder - errorMsg.WriteString("ERROR: Imported workflows require permissions that are not granted in the main workflow.\n\n") - errorMsg.WriteString("The permission set must be explicitly declared in the main workflow.\n\n") - - if len(missingPermissions) > 0 { - errorMsg.WriteString("Missing permissions:\n") - // Sort for consistent output - var scopes []PermissionScope - for scope := range missingPermissions { - scopes = append(scopes, scope) - } - SortPermissionScopes(scopes) - for _, scope := range scopes { - level := missingPermissions[scope] - fmt.Fprintf(&errorMsg, " - %s: %s\n", scope, level) - } - errorMsg.WriteString("\n") - } - - if len(insufficientPermissions) > 0 { - errorMsg.WriteString("Insufficient permissions:\n") - // Sort for consistent output - var scopes []PermissionScope - for scope := range insufficientPermissions { - scopes = append(scopes, scope) - } - SortPermissionScopes(scopes) - for _, scope := range scopes { - info := insufficientPermissions[scope] - fmt.Fprintf(&errorMsg, " - %s: has %s, requires %s\n", scope, info.current, info.required) - } - errorMsg.WriteString("\n") - } - - errorMsg.WriteString("Suggested fix: Add the required permissions to your main workflow frontmatter:\n") - errorMsg.WriteString("permissions:\n") - - // Combine all required permissions for the suggestion - allRequired := make(map[PermissionScope]PermissionLevel) - maps.Copy(allRequired, missingPermissions) - for scope, info := range insufficientPermissions { - allRequired[scope] = info.required - } - - var scopes []PermissionScope - for scope := range allRequired { - scopes = append(scopes, scope) - } - SortPermissionScopes(scopes) - for _, scope := range scopes { - level := allRequired[scope] - fmt.Fprintf(&errorMsg, " %s: %s\n", scope, level) - } - - return fmt.Errorf("%s", errorMsg.String()) - } - - importsLog.Print("All imported permissions are satisfied by main workflow") - return nil -} - -// isPermissionSufficient checks if the current permission level is sufficient for the required level -// write > read > none -func isPermissionSufficient(current, required PermissionLevel) bool { - if current == required { - return true - } - // write satisfies read requirement - if current == PermissionWrite && required == PermissionRead { - return true - } - return false -} - // getSafeOutputTypeKeys returns the list of safe output type keys from the embedded schema. // This is a cached wrapper around parser.GetSafeOutputTypeKeys() to avoid parsing on every call. var ( diff --git a/pkg/workflow/jobs.go b/pkg/workflow/jobs.go index 7a49b2bc16..3f9cd7c611 100644 --- a/pkg/workflow/jobs.go +++ b/pkg/workflow/jobs.go @@ -85,66 +85,6 @@ func (jm *JobManager) GetAllJobs() map[string]*Job { return result } -// ValidateDependencies checks that all job dependencies exist and there are no cycles -func (jm *JobManager) ValidateDependencies() error { - jobLog.Printf("Validating dependencies for %d jobs", len(jm.jobs)) - // First check that all dependencies reference existing jobs - for jobName, job := range jm.jobs { - for _, dep := range job.Needs { - if _, exists := jm.jobs[dep]; !exists { - jobLog.Printf("Validation failed: job %s depends on non-existent job %s", jobName, dep) - return fmt.Errorf("job '%s' depends on non-existent job '%s'", jobName, dep) - } - } - } - - // Check for cycles using DFS - return jm.detectCycles() -} - -// ValidateDuplicateSteps checks that no job has duplicate steps -// This detects compiler bugs where the same step is added multiple times -func (jm *JobManager) ValidateDuplicateSteps() error { - jobLog.Printf("Validating for duplicate steps in %d jobs", len(jm.jobs)) - - for jobName, job := range jm.jobs { - if len(job.Steps) == 0 { - continue - } - - // Track seen steps to detect duplicates - seen := make(map[string]int) - - for i, step := range job.Steps { - // job.Steps entries may be either complete step blocks (multi-line) or - // individual YAML line fragments. Only elements that begin with the step - // leader "- " represent a new step definition; property lines (e.g., - // "continue-on-error:", "name:" inside a "with:" block) start with - // plain indentation and should not be treated as step definitions. - if !strings.HasPrefix(strings.TrimSpace(step), "-") { - continue - } - - // Extract step name from YAML for comparison - stepName := extractStepName(step) - if stepName == "" { - // Steps without names can't be checked for duplicates - continue - } - - if firstIndex, exists := seen[stepName]; exists { - jobLog.Printf("Duplicate step detected in job '%s': step '%s' at positions %d and %d", jobName, stepName, firstIndex, i) - return fmt.Errorf("compiler bug: duplicate step '%s' found in job '%s' (positions %d and %d)", stepName, jobName, firstIndex, i) - } - - seen[stepName] = i - } - } - - jobLog.Print("No duplicate steps detected in any job") - return nil -} - // extractStepName extracts the step name from a YAML step string // Returns empty string if no name is found func extractStepName(stepYAML string) string { diff --git a/pkg/workflow/jobs_validation.go b/pkg/workflow/jobs_validation.go new file mode 100644 index 0000000000..ced9f74e3e --- /dev/null +++ b/pkg/workflow/jobs_validation.go @@ -0,0 +1,84 @@ +// This file provides validation for GitHub Actions job configurations. +// +// # Job Validation +// +// This file validates that job definitions are correct before workflow compilation, +// catching issues that would cause silent failures or confusing runtime errors. +// +// # Validation Functions +// +// - ValidateDependencies() - Checks job dependencies exist and contain no cycles +// - ValidateDuplicateSteps() - Detects duplicate step definitions (compiler bugs) +// +// # When to Add Validation Here +// +// Add validation to this file when: +// - Adding new job-level structural constraints +// - Adding new dependency graph validation rules + +package workflow + +import ( + "fmt" + "strings" +) + +// ValidateDependencies checks that all job dependencies exist and there are no cycles +func (jm *JobManager) ValidateDependencies() error { + jobLog.Printf("Validating dependencies for %d jobs", len(jm.jobs)) + // First check that all dependencies reference existing jobs + for jobName, job := range jm.jobs { + for _, dep := range job.Needs { + if _, exists := jm.jobs[dep]; !exists { + jobLog.Printf("Validation failed: job %s depends on non-existent job %s", jobName, dep) + return fmt.Errorf("job '%s' depends on non-existent job '%s'", jobName, dep) + } + } + } + + // Check for cycles using DFS + return jm.detectCycles() +} + +// ValidateDuplicateSteps checks that no job has duplicate steps. +// This detects compiler bugs where the same step is added multiple times. +func (jm *JobManager) ValidateDuplicateSteps() error { + jobLog.Printf("Validating for duplicate steps in %d jobs", len(jm.jobs)) + + for jobName, job := range jm.jobs { + if len(job.Steps) == 0 { + continue + } + + // Track seen steps to detect duplicates + seen := make(map[string]int) + + for i, step := range job.Steps { + // job.Steps entries may be either complete step blocks (multi-line) or + // individual YAML line fragments. Only elements that begin with the step + // leader "- " represent a new step definition; property lines (e.g., + // "continue-on-error:", "name:" inside a "with:" block) start with + // plain indentation and should not be treated as step definitions. + if !strings.HasPrefix(strings.TrimSpace(step), "-") { + continue + } + + // Extract step name from YAML for comparison + stepName := extractStepName(step) + if stepName == "" { + // Steps without names can't be checked for duplicates + continue + } + + if firstIndex, exists := seen[stepName]; exists { + jobLog.Printf("Duplicate step detected in job '%s': step '%s' at positions %d and %d", jobName, stepName, firstIndex, i) + return fmt.Errorf("compiler bug: duplicate step '%s' found in job '%s' (positions %d and %d)", stepName, jobName, firstIndex, i) + } + + seen[stepName] = i + } + } + + jobLog.Print("No duplicate steps detected in any job") + return nil +} diff --git a/pkg/workflow/lock_schema.go b/pkg/workflow/lock_schema.go index 724063c78d..d45d1b39f9 100644 --- a/pkg/workflow/lock_schema.go +++ b/pkg/workflow/lock_schema.go @@ -75,41 +75,6 @@ func ExtractMetadataFromLockFile(content string) (*LockMetadata, bool, error) { return nil, false, nil } -// ValidateLockSchemaCompatibility validates that a lock file's schema is compatible -// Returns an error with actionable guidance if incompatible -func ValidateLockSchemaCompatibility(content string, lockFilePath string) error { - metadata, isLegacy, err := ExtractMetadataFromLockFile(content) - if err != nil { - return fmt.Errorf("failed to extract metadata from %s: %w", lockFilePath, err) - } - - // Legacy files (no schema version) are supported for backward compatibility - if isLegacy { - lockSchemaLog.Printf("Legacy lock file accepted: %s", lockFilePath) - return nil - } - - // Missing metadata entirely is suspicious - if metadata == nil { - return fmt.Errorf("lock file %s is missing required metadata. This file may be corrupted or manually edited.\n\nTo fix this, recompile the workflow:\n gh aw compile %s", - lockFilePath, - strings.TrimSuffix(lockFilePath, ".lock.yml")+".md") - } - - // Check schema compatibility - if !IsSchemaVersionSupported(metadata.SchemaVersion) { - // Future version detected - return fmt.Errorf("lock file %s uses unsupported schema version '%s'.\n\nThis file was generated by a newer version of gh-aw that uses incompatible features.\n\nSupported versions: %s\n\nTo fix this:\n 1. Upgrade gh-aw: gh extension upgrade gh-aw\n 2. Or downgrade the lock file by editing the source .md file and recompiling:\n gh aw compile %s", - lockFilePath, - metadata.SchemaVersion, - formatSupportedVersions(), - strings.TrimSuffix(lockFilePath, ".lock.yml")+".md") - } - - lockSchemaLog.Printf("Lock file schema validated: %s (version=%s)", lockFilePath, metadata.SchemaVersion) - return nil -} - // formatSupportedVersions formats the list of supported versions for error messages func formatSupportedVersions() string { versions := make([]string, len(SupportedSchemaVersions)) diff --git a/pkg/workflow/lock_validation.go b/pkg/workflow/lock_validation.go new file mode 100644 index 0000000000..c518fe3383 --- /dev/null +++ b/pkg/workflow/lock_validation.go @@ -0,0 +1,59 @@ +// This file provides validation for workflow lock file schema compatibility. +// +// # Lock File Schema Validation +// +// This file validates that lock files use a schema version that the current +// build of gh-aw supports. It provides actionable error messages when a +// lock file was generated by an incompatible version. +// +// # Validation Functions +// +// - ValidateLockSchemaCompatibility() - Validates lock file schema version +// +// # When to Add Validation Here +// +// Add validation to this file when: +// - Adding new lock file format constraints +// - Adding migration validation for schema upgrades + +package workflow + +import ( + "fmt" + "strings" +) + +// ValidateLockSchemaCompatibility validates that a lock file's schema is compatible. +// Returns an error with actionable guidance if incompatible. +func ValidateLockSchemaCompatibility(content string, lockFilePath string) error { + metadata, isLegacy, err := ExtractMetadataFromLockFile(content) + if err != nil { + return fmt.Errorf("failed to extract metadata from %s: %w", lockFilePath, err) + } + + // Legacy files (no schema version) are supported for backward compatibility + if isLegacy { + lockSchemaLog.Printf("Legacy lock file accepted: %s", lockFilePath) + return nil + } + + // Missing metadata entirely is suspicious + if metadata == nil { + return fmt.Errorf("lock file %s is missing required metadata. This file may be corrupted or manually edited.\n\nTo fix this, recompile the workflow:\n gh aw compile %s", + lockFilePath, + strings.TrimSuffix(lockFilePath, ".lock.yml")+".md") + } + + // Check schema compatibility + if !IsSchemaVersionSupported(metadata.SchemaVersion) { + // Future version detected + return fmt.Errorf("lock file %s uses unsupported schema version '%s'.\n\nThis file was generated by a newer version of gh-aw that uses incompatible features.\n\nSupported versions: %s\n\nTo fix this:\n 1. Upgrade gh-aw: gh extension upgrade gh-aw\n 2. Or downgrade the lock file by editing the source .md file and recompiling:\n gh aw compile %s", + lockFilePath, + metadata.SchemaVersion, + formatSupportedVersions(), + strings.TrimSuffix(lockFilePath, ".lock.yml")+".md") + } + + lockSchemaLog.Printf("Lock file schema validated: %s (version=%s)", lockFilePath, metadata.SchemaVersion) + return nil +} diff --git a/pkg/workflow/permissions_validation.go b/pkg/workflow/permissions_validation.go index 6a02d3f7eb..5574300155 100644 --- a/pkg/workflow/permissions_validation.go +++ b/pkg/workflow/permissions_validation.go @@ -4,6 +4,7 @@ import ( _ "embed" "encoding/json" "fmt" + "maps" "slices" "sort" "strings" @@ -318,3 +319,159 @@ func formatMissingPermissionsMessage(result *PermissionsValidationResult) string return strings.Join(lines, "\n") } + +// ValidateIncludedPermissions validates that the main workflow permissions satisfy the imported +// workflow requirements. This function is specifically used when merging included/imported workflow +// files to ensure the main workflow has sufficient permissions to support all imported files. +// +// Use ValidatePermissions (in permissions_validator.go) for general permission validation against +// GitHub MCP toolsets. Use ValidateIncludedPermissions (this function) when validating permissions +// from included/imported workflow files. +func (c *Compiler) ValidateIncludedPermissions(topPermissionsYAML string, importedPermissionsJSON string) error { + permissionsValidationLog.Print("Validating included workflow permissions") + + // If no imported permissions, no validation needed + if importedPermissionsJSON == "" || importedPermissionsJSON == "{}" { + permissionsValidationLog.Print("No included workflow permissions to validate") + return nil + } + + // Parse top-level permissions + var topPerms *Permissions + if topPermissionsYAML != "" { + topPerms = NewPermissionsParser(topPermissionsYAML).ToPermissions() + } else { + topPerms = NewPermissions() + } + + // Track missing permissions + missingPermissions := make(map[PermissionScope]PermissionLevel) + insufficientPermissions := make(map[PermissionScope]struct { + required PermissionLevel + current PermissionLevel + }) + + // Split by newlines to handle multiple JSON objects from different imports + lines := strings.Split(importedPermissionsJSON, "\n") + permissionsValidationLog.Printf("Processing %d permission definition lines", len(lines)) + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" || line == "{}" { + continue + } + + // Parse JSON line to permissions map + var importedPermsMap map[string]any + if err := json.Unmarshal([]byte(line), &importedPermsMap); err != nil { + permissionsValidationLog.Printf("Skipping malformed permission entry: %q (error: %v)", line, err) + continue + } + + // Check each permission from the imported map + for scopeStr, levelValue := range importedPermsMap { + scope := PermissionScope(scopeStr) + + // Parse the level - it might be a string or already unmarshaled + var requiredLevel PermissionLevel + if levelStr, ok := levelValue.(string); ok { + requiredLevel = PermissionLevel(levelStr) + } else { + // Skip invalid level values + continue + } + + // Get current level for this scope + currentLevel, exists := topPerms.Get(scope) + + // Validate that the main workflow has sufficient permissions + if !exists || currentLevel == PermissionNone { + // Permission is missing entirely + missingPermissions[scope] = requiredLevel + permissionsValidationLog.Printf("Missing permission: %s: %s", scope, requiredLevel) + } else if !isPermissionSufficient(currentLevel, requiredLevel) { + // Permission exists but is insufficient + insufficientPermissions[scope] = struct { + required PermissionLevel + current PermissionLevel + }{requiredLevel, currentLevel} + permissionsValidationLog.Printf("Insufficient permission: %s: has %s, needs %s", scope, currentLevel, requiredLevel) + } + } + } + + // If there are missing or insufficient permissions, return an error + if len(missingPermissions) > 0 || len(insufficientPermissions) > 0 { + var errorMsg strings.Builder + errorMsg.WriteString("ERROR: Imported workflows require permissions that are not granted in the main workflow.\n\n") + errorMsg.WriteString("The permission set must be explicitly declared in the main workflow.\n\n") + + if len(missingPermissions) > 0 { + errorMsg.WriteString("Missing permissions:\n") + // Sort for consistent output + var scopes []PermissionScope + for scope := range missingPermissions { + scopes = append(scopes, scope) + } + SortPermissionScopes(scopes) + for _, scope := range scopes { + level := missingPermissions[scope] + fmt.Fprintf(&errorMsg, " - %s: %s\n", scope, level) + } + errorMsg.WriteString("\n") + } + + if len(insufficientPermissions) > 0 { + errorMsg.WriteString("Insufficient permissions:\n") + // Sort for consistent output + var scopes []PermissionScope + for scope := range insufficientPermissions { + scopes = append(scopes, scope) + } + SortPermissionScopes(scopes) + for _, scope := range scopes { + info := insufficientPermissions[scope] + fmt.Fprintf(&errorMsg, " - %s: has %s, requires %s\n", scope, info.current, info.required) + } + errorMsg.WriteString("\n") + } + + errorMsg.WriteString("Suggested fix: Add the required permissions to your main workflow frontmatter:\n") + errorMsg.WriteString("permissions:\n") + + // Combine all required permissions for the suggestion + allRequired := make(map[PermissionScope]PermissionLevel) + maps.Copy(allRequired, missingPermissions) + for scope, info := range insufficientPermissions { + allRequired[scope] = info.required + } + + var scopes []PermissionScope + for scope := range allRequired { + scopes = append(scopes, scope) + } + SortPermissionScopes(scopes) + for _, scope := range scopes { + level := allRequired[scope] + fmt.Fprintf(&errorMsg, " %s: %s\n", scope, level) + } + + return fmt.Errorf("%s", errorMsg.String()) + } + + permissionsValidationLog.Print("All included workflow permissions are satisfied by main workflow") + return nil +} + +// isPermissionSufficient checks if the current permission level is sufficient for the required level. +// write > read > none +func isPermissionSufficient(current, required PermissionLevel) bool { + if current == required { + return true + } + // write satisfies read requirement + if current == PermissionWrite && required == PermissionRead { + return true + } + return false +} diff --git a/pkg/workflow/repo_memory.go b/pkg/workflow/repo_memory.go index 35cba0d8a3..2e37ed9e92 100644 --- a/pkg/workflow/repo_memory.go +++ b/pkg/workflow/repo_memory.go @@ -1,25 +1,16 @@ -// This file provides repository memory configuration and validation. +// This file provides repository memory configuration and generation. // // This file handles: // - Repo-memory configuration structures and defaults // - Repo-memory tool configuration extraction and parsing // - Generation of per-memory GitHub token secrets -// - Domain-specific validation for repo-memory configurations // -// # Validation Functions -// -// This file contains domain-specific validation functions for repo-memory: -// - validateNoDuplicateMemoryIDs() - Ensures unique memory identifiers -// -// These validation functions are co-located with repo-memory logic following the -// principle that domain-specific validation belongs in domain files. See validation.go -// for the validation architecture documentation. +// See repo_memory_validation.go for validation functions. package workflow import ( "encoding/json" - "errors" "fmt" "regexp" "strings" @@ -78,34 +69,6 @@ func generateDefaultBranchName(memoryID string, branchPrefix string) string { return fmt.Sprintf("%s/%s", branchPrefix, memoryID) } -// validateBranchPrefix validates that the branch prefix meets requirements -func validateBranchPrefix(prefix string) error { - if prefix == "" { - return nil // Empty means use default - } - - // Check length (4-32 characters) - if len(prefix) < 4 { - return fmt.Errorf("branch-prefix must be at least 4 characters long, got %d", len(prefix)) - } - if len(prefix) > 32 { - return fmt.Errorf("branch-prefix must be at most 32 characters long, got %d", len(prefix)) - } - - // Check for alphanumeric and branch-friendly characters (alphanumeric, hyphens, underscores) - // Use pre-compiled regex from package level for performance - if !branchPrefixValidPattern.MatchString(prefix) { - return fmt.Errorf("branch-prefix must contain only alphanumeric characters, hyphens, and underscores, got '%s'", prefix) - } - - // Cannot be "copilot" - if strings.ToLower(prefix) == "copilot" { - return errors.New("branch-prefix cannot be 'copilot' (reserved)") - } - - return nil -} - // extractRepoMemoryConfig extracts repo-memory configuration from tools section. // workflowID is used to qualify the default branch name (e.g. "memory/{workflowID}"). func (c *Compiler) extractRepoMemoryConfig(toolsConfig *ToolsConfig, workflowID string) (*RepoMemoryConfig, error) { @@ -506,18 +469,6 @@ func (c *Compiler) extractRepoMemoryConfig(toolsConfig *ToolsConfig, workflowID return nil, nil } -// validateNoDuplicateMemoryIDs checks for duplicate memory IDs and returns an error if found -func validateNoDuplicateMemoryIDs(memories []RepoMemoryEntry) error { - seen := make(map[string]bool) - for _, memory := range memories { - if seen[memory.ID] { - return fmt.Errorf("duplicate memory ID found: '%s'. Each memory must have a unique ID", memory.ID) - } - seen[memory.ID] = true - } - return nil -} - // generateRepoMemoryArtifactUpload generates steps to upload repo-memory directories as artifacts // This runs at the end of the agent job (always condition) to save the state func generateRepoMemoryArtifactUpload(builder *strings.Builder, data *WorkflowData) { diff --git a/pkg/workflow/repo_memory_validation.go b/pkg/workflow/repo_memory_validation.go new file mode 100644 index 0000000000..fccc5a4a54 --- /dev/null +++ b/pkg/workflow/repo_memory_validation.go @@ -0,0 +1,61 @@ +// This file provides validation for repo-memory configuration. +// +// # Repo Memory Validation +// +// This file validates that repo-memory entries have unique IDs and that +// branch prefix configuration meets naming requirements. +// +// # Validation Functions +// +// - validateBranchPrefix() - Validates branch prefix length, format, and reserved names +// - validateNoDuplicateMemoryIDs() - Ensures each memory entry has a unique ID +// +// # When to Add Validation Here +// +// Add validation to this file when: +// - Adding new repo-memory configuration constraints +// - Adding new branch naming rules + +package workflow + +import ( + "errors" + "fmt" + "strings" +) + +// validateBranchPrefix validates that the branch prefix meets requirements +func validateBranchPrefix(prefix string) error { + if prefix == "" { + return nil // Empty means use default + } + + // Check length (4-32 characters) + if len(prefix) < 4 { + return fmt.Errorf("branch-prefix must be at least 4 characters long, got %d", len(prefix)) + } + if len(prefix) > 32 { + return fmt.Errorf("branch-prefix must be at most 32 characters long, got %d", len(prefix)) + } + + // Check for alphanumeric and branch-friendly characters (alphanumeric, hyphens, underscores) + // Use pre-compiled regex from package level for performance + if !branchPrefixValidPattern.MatchString(prefix) { + return fmt.Errorf("branch-prefix must contain only alphanumeric characters, hyphens, and underscores, got '%s'", prefix) + } + + // Cannot be "copilot" + if strings.ToLower(prefix) == "copilot" { + return errors.New("branch-prefix cannot be 'copilot' (reserved)") + } + + return nil +} + +// validateNoDuplicateMemoryIDs checks for duplicate memory IDs and returns an error if found. +// Uses the generic validateNoDuplicateIDs helper for consistent duplicate detection. +func validateNoDuplicateMemoryIDs(memories []RepoMemoryEntry) error { + return validateNoDuplicateIDs(memories, func(m RepoMemoryEntry) string { return m.ID }, func(id string) error { + return fmt.Errorf("duplicate memory ID found: '%s'. Each memory must have a unique ID", id) + }) +} diff --git a/pkg/workflow/templatables.go b/pkg/workflow/templatables.go index 79db1ae6d6..f91dad0104 100644 --- a/pkg/workflow/templatables.go +++ b/pkg/workflow/templatables.go @@ -29,7 +29,6 @@ package workflow import ( "fmt" - "slices" "strconv" "strings" @@ -70,27 +69,6 @@ func preprocessBoolFieldAsString(configData map[string]any, fieldName string, lo return nil } -// validateStringEnumField checks that a config field, if present, contains one -// of the allowed string values. Non-string values and unrecognised strings are -// removed from the map (treated as absent) and a warning is logged. Use this -// for fields that are pure string enums with no boolean shorthand. -func validateStringEnumField(configData map[string]any, fieldName string, allowed []string, log *logger.Logger) { - if configData == nil { - return - } - val, exists := configData[fieldName] - if !exists || val == nil { - return - } - strVal, ok := val.(string) - if !ok || !slices.Contains(allowed, strVal) { - if log != nil { - log.Printf("Invalid %s value %v (must be one of %v), ignoring", fieldName, val, allowed) - } - delete(configData, fieldName) - } -} - // buildTemplatableBoolEnvVar returns a YAML environment variable entry for a // templatable boolean field. If value is a GitHub Actions expression it is // embedded unquoted so that GitHub Actions can evaluate it at runtime; diff --git a/pkg/workflow/tools_validation.go b/pkg/workflow/tools_validation.go index 4b9ac9a6ec..343c18cc5a 100644 --- a/pkg/workflow/tools_validation.go +++ b/pkg/workflow/tools_validation.go @@ -2,7 +2,11 @@ package workflow import ( "errors" + "fmt" + "sort" "strings" + + "github.com/github/gh-aw/pkg/parser" ) var toolsValidationLog = newValidationLogger("tools") @@ -225,3 +229,103 @@ func isValidOwnerOrRepo(s string) bool { // injected by the compiler when safe-outputs needs them (see compiler_safe_outputs.go). // The validation was misleading - it would fail even though the compiler would add the // necessary git commands during compilation. + +// ValidateGitHubToolsAgainstToolsets validates that all allowed GitHub tools have their +// corresponding toolsets enabled in the configuration. +func ValidateGitHubToolsAgainstToolsets(allowedTools []string, enabledToolsets []string) error { + githubToolToToolsetLog.Printf("Validating GitHub tools against toolsets: allowed_tools=%d, enabled_toolsets=%d", len(allowedTools), len(enabledToolsets)) + + if len(allowedTools) == 0 { + githubToolToToolsetLog.Print("No tools to validate, skipping") + // No specific tools restricted, validation not needed + return nil + } + + // Create a set of enabled toolsets for fast lookup + enabledSet := make(map[string]bool) + for _, toolset := range enabledToolsets { + enabledSet[toolset] = true + } + githubToolToToolsetLog.Printf("Enabled toolsets: %v", enabledToolsets) + + // Track missing toolsets and which tools need them + missingToolsets := make(map[string][]string) // toolset -> list of tools that need it + + // Track unknown tools for suggestions + var unknownTools []string + var suggestions []string + + for _, tool := range allowedTools { + // Skip wildcard - it means "allow all tools" + if tool == "*" { + continue + } + + requiredToolset, exists := GitHubToolToToolsetMap[tool] + if !exists { + githubToolToToolsetLog.Printf("Tool %s not found in mapping, checking for typo", tool) + + // Get all valid tool names for suggestion + validTools := make([]string, 0, len(GitHubToolToToolsetMap)) + for validTool := range GitHubToolToToolsetMap { + validTools = append(validTools, validTool) + } + sort.Strings(validTools) + + // Try to find close matches + matches := parser.FindClosestMatches(tool, validTools, 1) + if len(matches) > 0 { + githubToolToToolsetLog.Printf("Found suggestion for unknown tool %s: %s", tool, matches[0]) + unknownTools = append(unknownTools, tool) + suggestions = append(suggestions, fmt.Sprintf("%s → %s", tool, matches[0])) + } else { + githubToolToToolsetLog.Printf("No suggestion found for unknown tool: %s", tool) + unknownTools = append(unknownTools, tool) + } + // Tool not in our mapping - this could be a new tool or a typo + // We'll skip validation for unknown tools to avoid false positives + continue + } + + if !enabledSet[requiredToolset] { + githubToolToToolsetLog.Printf("Tool %s requires missing toolset: %s", tool, requiredToolset) + missingToolsets[requiredToolset] = append(missingToolsets[requiredToolset], tool) + } + } + + // Report unknown tools with suggestions if any were found + if len(unknownTools) > 0 { + githubToolToToolsetLog.Printf("Found %d unknown tools", len(unknownTools)) + var errMsg strings.Builder + errMsg.WriteString(fmt.Sprintf("Unknown GitHub tool(s): %s\n\n", formatList(unknownTools))) + + if len(suggestions) > 0 { + errMsg.WriteString("Did you mean:\n") + for _, s := range suggestions { + errMsg.WriteString(fmt.Sprintf(" %s\n", s)) + } + errMsg.WriteString("\n") + } + + // Show a few examples of valid tools + validTools := make([]string, 0, len(GitHubToolToToolsetMap)) + for tool := range GitHubToolToToolsetMap { + validTools = append(validTools, tool) + } + sort.Strings(validTools) + + exampleCount := min(10, len(validTools)) + errMsg.WriteString(fmt.Sprintf("Valid GitHub tools include: %s\n\n", formatList(validTools[:exampleCount]))) + errMsg.WriteString("See all tools: https://github.com/github/gh-aw/blob/main/pkg/workflow/data/github_tool_to_toolset.json") + + return fmt.Errorf("%s", errMsg.String()) + } + + if len(missingToolsets) > 0 { + githubToolToToolsetLog.Printf("Validation failed: missing %d toolsets", len(missingToolsets)) + return NewGitHubToolsetValidationError(missingToolsets) + } + + githubToolToToolsetLog.Print("Validation successful: all tools have required toolsets") + return nil +} diff --git a/pkg/workflow/validation_helpers.go b/pkg/workflow/validation_helpers.go index b75a59c3b2..65b22f61db 100644 --- a/pkg/workflow/validation_helpers.go +++ b/pkg/workflow/validation_helpers.go @@ -26,6 +26,7 @@ package workflow import ( "errors" "fmt" + "slices" "strings" "github.com/github/gh-aw/pkg/logger" @@ -114,3 +115,38 @@ func validateTargetRepoSlug(targetRepoSlug string, log *logger.Logger) bool { } return false } + +// validateStringEnumField checks that a config field, if present, contains one +// of the allowed string values. Non-string values and unrecognised strings are +// removed from the map (treated as absent) and a warning is logged. Use this +// for fields that are pure string enums with no boolean shorthand. +func validateStringEnumField(configData map[string]any, fieldName string, allowed []string, log *logger.Logger) { + if configData == nil { + return + } + val, exists := configData[fieldName] + if !exists || val == nil { + return + } + strVal, ok := val.(string) + if !ok || !slices.Contains(allowed, strVal) { + if log != nil { + log.Printf("Invalid %s value %v (must be one of %v), ignoring", fieldName, val, allowed) + } + delete(configData, fieldName) + } +} + +// validateNoDuplicateIDs checks that all items have unique IDs extracted by idFunc. +// The onDuplicate callback creates the error to return when a duplicate is found. +func validateNoDuplicateIDs[T any](items []T, idFunc func(T) string, onDuplicate func(string) error) error { + seen := make(map[string]bool) + for _, item := range items { + id := idFunc(item) + if seen[id] { + return onDuplicate(id) + } + seen[id] = true + } + return nil +}