diff --git a/pkg/workflow/compiler.go b/pkg/workflow/compiler.go index 37caf1d84bc..50d26961523 100644 --- a/pkg/workflow/compiler.go +++ b/pkg/workflow/compiler.go @@ -177,6 +177,12 @@ func (c *Compiler) validateWorkflowData(workflowData *WorkflowData, markdownPath return formatCompilerError(markdownPath, "error", err.Error(), err) } + // Validate safe-outputs max configuration + log.Printf("Validating safe-outputs max fields") + if err := validateSafeOutputsMax(workflowData.SafeOutputs); err != nil { + return formatCompilerError(markdownPath, "error", err.Error(), err) + } + // Validate safe-outputs allowed-domains configuration log.Printf("Validating safe-outputs allowed-domains") if err := c.validateSafeOutputsAllowedDomains(workflowData.SafeOutputs); err != nil { diff --git a/pkg/workflow/safe_outputs_max_validation_test.go b/pkg/workflow/safe_outputs_max_validation_test.go new file mode 100644 index 00000000000..c99d24cd9c6 --- /dev/null +++ b/pkg/workflow/safe_outputs_max_validation_test.go @@ -0,0 +1,235 @@ +//go:build !integration + +package workflow + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateSafeOutputsMax(t *testing.T) { + t.Run("nil config is valid", func(t *testing.T) { + err := validateSafeOutputsMax(nil) + assert.NoError(t, err, "nil config should be valid") + }) + + t.Run("config with no max fields is valid", func(t *testing.T) { + config := &SafeOutputsConfig{} + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "config with no max fields should be valid") + }) + + t.Run("max of 1 is valid", func(t *testing.T) { + config := &SafeOutputsConfig{ + AddComments: &AddCommentsConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: strPtr("1")}, + }, + } + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "max: 1 should be valid") + }) + + t.Run("max of 5 is valid", func(t *testing.T) { + config := &SafeOutputsConfig{ + CreateIssues: &CreateIssuesConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: strPtr("5")}, + }, + } + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "max: 5 should be valid") + }) + + t.Run("max of -1 is valid (unlimited)", func(t *testing.T) { + config := &SafeOutputsConfig{ + AddComments: &AddCommentsConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: strPtr("-1")}, + }, + } + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "max: -1 should be valid (means unlimited per spec)") + }) + + t.Run("max of 0 is invalid", func(t *testing.T) { + config := &SafeOutputsConfig{ + AddComments: &AddCommentsConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: strPtr("0")}, + }, + } + err := validateSafeOutputsMax(config) + require.Error(t, err, "max: 0 should be invalid") + assert.Contains(t, err.Error(), "max must be a positive integer or -1", "error should explain valid values") + assert.Contains(t, err.Error(), "add-comment", "error should mention the field name") + }) + + t.Run("max of -2 is invalid", func(t *testing.T) { + config := &SafeOutputsConfig{ + CreateIssues: &CreateIssuesConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: strPtr("-2")}, + }, + } + err := validateSafeOutputsMax(config) + require.Error(t, err, "max: -2 should be invalid") + assert.Contains(t, err.Error(), "max must be a positive integer or -1", "error should explain valid values") + }) + + t.Run("max as GitHub Actions expression is skipped", func(t *testing.T) { + config := &SafeOutputsConfig{ + AddComments: &AddCommentsConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: strPtr("${{ inputs.max }}")}, + }, + } + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "GitHub Actions expression should be skipped") + }) + + t.Run("nil max is valid", func(t *testing.T) { + config := &SafeOutputsConfig{ + AddComments: &AddCommentsConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: nil}, + }, + } + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "nil max should be valid") + }) + + t.Run("dispatch_repository tool max of 0 is invalid", func(t *testing.T) { + maxVal := "0" + config := &SafeOutputsConfig{ + DispatchRepository: &DispatchRepositoryConfig{ + Tools: map[string]*DispatchRepositoryToolConfig{ + "my-tool": {Max: &maxVal}, + }, + }, + } + err := validateSafeOutputsMax(config) + require.Error(t, err, "dispatch_repository max: 0 should be invalid") + assert.Contains(t, err.Error(), "max must be a positive integer or -1", "error should explain valid values") + assert.Contains(t, err.Error(), "my-tool", "error should mention the tool name") + assert.Contains(t, err.Error(), "dispatch_repository", "error should use underscore form") + }) + + t.Run("dispatch_repository tool max of -1 is valid (unlimited)", func(t *testing.T) { + maxVal := "-1" + config := &SafeOutputsConfig{ + DispatchRepository: &DispatchRepositoryConfig{ + Tools: map[string]*DispatchRepositoryToolConfig{ + "my-tool": {Max: &maxVal}, + }, + }, + } + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "dispatch_repository max: -1 should be valid") + }) + + t.Run("dispatch_repository tool max of 1 is valid", func(t *testing.T) { + maxVal := "1" + config := &SafeOutputsConfig{ + DispatchRepository: &DispatchRepositoryConfig{ + Tools: map[string]*DispatchRepositoryToolConfig{ + "my-tool": {Max: &maxVal}, + }, + }, + } + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "dispatch_repository max: 1 should be valid") + }) + + t.Run("dispatch_repository tool max as expression is skipped", func(t *testing.T) { + maxVal := "${{ inputs.max }}" + config := &SafeOutputsConfig{ + DispatchRepository: &DispatchRepositoryConfig{ + Tools: map[string]*DispatchRepositoryToolConfig{ + "my-tool": {Max: &maxVal}, + }, + }, + } + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "GitHub Actions expression for dispatch_repository should be skipped") + }) + + t.Run("multiple configs with one invalid returns error", func(t *testing.T) { + config := &SafeOutputsConfig{ + AddComments: &AddCommentsConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: strPtr("3")}, + }, + CreateIssues: &CreateIssuesConfig{ + BaseSafeOutputConfig: BaseSafeOutputConfig{Max: strPtr("0")}, + }, + } + err := validateSafeOutputsMax(config) + require.Error(t, err, "config with one invalid max should return error") + assert.Contains(t, err.Error(), "max must be a positive integer or -1", "error should explain valid values") + }) +} + +func TestValidateSafeOutputsMaxIntegration(t *testing.T) { + compiler := &Compiler{} + + t.Run("max of 0 is rejected during config extraction via compiler", func(t *testing.T) { + frontmatter := map[string]any{ + "safe-outputs": map[string]any{ + "add-comment": map[string]any{ + "max": 0, + }, + }, + } + + config := compiler.extractSafeOutputsConfig(frontmatter) + require.NotNil(t, config, "config should be extracted") + + err := validateSafeOutputsMax(config) + require.Error(t, err, "max: 0 should fail validation") + assert.Contains(t, err.Error(), "max must be a positive integer or -1", "error message should explain valid values") + }) + + t.Run("max of -2 is rejected during config extraction via compiler", func(t *testing.T) { + frontmatter := map[string]any{ + "safe-outputs": map[string]any{ + "create-issue": map[string]any{ + "max": -2, + }, + }, + } + + config := compiler.extractSafeOutputsConfig(frontmatter) + require.NotNil(t, config, "config should be extracted") + + err := validateSafeOutputsMax(config) + require.Error(t, err, "max: -2 should fail validation") + assert.Contains(t, err.Error(), "max must be a positive integer or -1", "error message should explain valid values") + }) + + t.Run("max of -1 passes validation (unlimited)", func(t *testing.T) { + frontmatter := map[string]any{ + "safe-outputs": map[string]any{ + "add-comment": map[string]any{ + "max": -1, + }, + }, + } + + config := compiler.extractSafeOutputsConfig(frontmatter) + require.NotNil(t, config, "config should be extracted") + + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "max: -1 should pass validation (unlimited per spec)") + }) + + t.Run("max of 1 passes validation", func(t *testing.T) { + frontmatter := map[string]any{ + "safe-outputs": map[string]any{ + "add-comment": map[string]any{ + "max": 1, + }, + }, + } + + config := compiler.extractSafeOutputsConfig(frontmatter) + require.NotNil(t, config, "config should be extracted") + + err := validateSafeOutputsMax(config) + assert.NoError(t, err, "max: 1 should pass validation") + }) +} diff --git a/pkg/workflow/safe_outputs_validation.go b/pkg/workflow/safe_outputs_validation.go index cb466de0650..1cefdfebae9 100644 --- a/pkg/workflow/safe_outputs_validation.go +++ b/pkg/workflow/safe_outputs_validation.go @@ -2,7 +2,10 @@ package workflow import ( "fmt" + "reflect" "regexp" + "sort" + "strconv" "strings" "github.com/github/gh-aw/pkg/stringutil" @@ -405,3 +408,112 @@ func isGitHubExpression(s string) bool { // and there must be something between them return openIndex >= 0 && closeIndex > openIndex+3 } + +var safeOutputsMaxValidationLog = newValidationLogger("safe_outputs_max") + +// isInvalidMaxValue returns true if n is not a valid max field value. +// Valid values are positive integers (n > 0) or -1 (unlimited). +// Invalid values are 0 and negative integers except -1. +func isInvalidMaxValue(n int) bool { + if n == -1 { + return false // -1 = unlimited, explicitly allowed by spec + } + return n <= 0 +} + +// maxInvalidErrSuffix is the common suffix of max validation error messages. +const maxInvalidErrSuffix = "\n\nThe max field controls how many times this safe output can be triggered.\nProvide a positive integer (e.g., max: 1 or max: 5) or -1 for unlimited" + +// validateSafeOutputsMax validates that all max fields in safe-outputs configs hold valid values. +// Valid values are positive integers (n > 0) or -1 (unlimited per spec). +// 0 and other negative values are rejected. +// GitHub Actions expressions (e.g. "${{ inputs.max }}") are not evaluable at compile time +// and are therefore skipped. +func validateSafeOutputsMax(config *SafeOutputsConfig) error { + if config == nil { + return nil + } + + safeOutputsMaxValidationLog.Print("Validating safe-outputs max fields") + + val := reflect.ValueOf(config).Elem() + + // Iterate over sorted field names for deterministic error reporting. + sortedFieldNames := make([]string, 0, len(safeOutputFieldMapping)) + for fieldName := range safeOutputFieldMapping { + sortedFieldNames = append(sortedFieldNames, fieldName) + } + sort.Strings(sortedFieldNames) + + // Validate max on all named safe output fields that embed BaseSafeOutputConfig + for _, fieldName := range sortedFieldNames { + toolName := safeOutputFieldMapping[fieldName] + field := val.FieldByName(fieldName) + if !field.IsValid() || field.IsNil() { + continue + } + + elem := field.Elem() + baseCfgField := elem.FieldByName("BaseSafeOutputConfig") + if !baseCfgField.IsValid() { + continue + } + + maxField := baseCfgField.FieldByName("Max") + if !maxField.IsValid() || maxField.IsNil() { + continue + } + + maxPtr, ok := maxField.Interface().(*string) + if !ok || maxPtr == nil || isExpressionString(*maxPtr) { + continue + } + + n, err := strconv.Atoi(*maxPtr) + if err != nil { + continue + } + + if isInvalidMaxValue(n) { + toolDisplayName := strings.ReplaceAll(toolName, "_", "-") + safeOutputsMaxValidationLog.Printf("Invalid max value %d for %s", n, toolDisplayName) + return fmt.Errorf( + "safe-outputs.%s: max must be a positive integer or -1 (unlimited), got %d%s", + toolDisplayName, n, maxInvalidErrSuffix, + ) + } + } + + // Validate max on dispatch_repository tools (different structure: map of tools). + // Use sorted tool names for deterministic error reporting. + if config.DispatchRepository != nil { + sortedToolNames := make([]string, 0, len(config.DispatchRepository.Tools)) + for toolName := range config.DispatchRepository.Tools { + sortedToolNames = append(sortedToolNames, toolName) + } + sort.Strings(sortedToolNames) + + for _, toolName := range sortedToolNames { + tool := config.DispatchRepository.Tools[toolName] + if tool == nil || tool.Max == nil || isExpressionString(*tool.Max) { + continue + } + + n, err := strconv.Atoi(*tool.Max) + if err != nil { + continue + } + + if isInvalidMaxValue(n) { + safeOutputsMaxValidationLog.Printf("Invalid max value %d for dispatch_repository tool %s", n, toolName) + return fmt.Errorf( + "safe-outputs.dispatch_repository.%s: max must be a positive integer or -1 (unlimited), got %d%s", + toolName, n, maxInvalidErrSuffix, + ) + } + } + } + + safeOutputsMaxValidationLog.Print("Safe-outputs max fields validation passed") + return nil +}