Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 47 additions & 13 deletions pkg/workflow/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/github/gh-aw/pkg/console"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/stringutil"
"github.com/goccy/go-yaml"
)

var log = logger.New("workflow:compiler")
Expand Down Expand Up @@ -426,37 +427,70 @@ func (c *Compiler) generateAndValidateYAML(workflowData *WorkflowData, markdownP
return "", formattedErr
}

// Template injection validation and GitHub Actions schema validation both require a
// parsed representation of the compiled YAML. Parse it once here and share the
// result between the two validators to avoid redundant yaml.Unmarshal calls.
//
// Fast-path: if the YAML contains no unsafe context expressions we can skip the
// parse (and template-injection check) entirely for the common case.
needsTemplateCheck := unsafeContextRegex.MatchString(yamlContent)
needsSchemaCheck := !c.skipValidation

var parsedWorkflow map[string]any
if needsTemplateCheck || needsSchemaCheck {
log.Print("Parsing compiled YAML for validation")
if parseErr := yaml.Unmarshal([]byte(yamlContent), &parsedWorkflow); parseErr != nil {
// If parsing fails here the subsequent validators would also fail; keep going
// so we surface the root error from the right validator.
parsedWorkflow = nil
}
}

// Validate for template injection vulnerabilities - detect unsafe expression usage in run: commands
log.Print("Validating for template injection vulnerabilities")
if err := validateNoTemplateInjection(yamlContent); err != nil {
// Store error first so we can write invalid YAML before returning
formattedErr := formatCompilerError(markdownPath, "error", err.Error(), err)
// Write the invalid YAML to a .invalid.yml file for inspection
invalidFile := strings.TrimSuffix(lockFile, ".lock.yml") + ".invalid.yml"
if writeErr := os.WriteFile(invalidFile, []byte(yamlContent), 0644); writeErr == nil {
fmt.Fprintln(os.Stderr, console.FormatWarningMessage("Workflow with template injection risks written to: "+console.ToRelativePath(invalidFile)))
if needsTemplateCheck {
log.Print("Validating for template injection vulnerabilities")
var templateErr error
if parsedWorkflow != nil {
templateErr = validateNoTemplateInjectionFromParsed(parsedWorkflow)
} else {
templateErr = validateNoTemplateInjection(yamlContent)
}
if templateErr != nil {
// Store error first so we can write invalid YAML before returning
formattedErr := formatCompilerError(markdownPath, "error", templateErr.Error(), templateErr)
// Write the invalid YAML to a .invalid.yml file for inspection
invalidFile := strings.TrimSuffix(lockFile, ".lock.yml") + ".invalid.yml"
if writeErr := os.WriteFile(invalidFile, []byte(yamlContent), 0644); writeErr == nil {
fmt.Fprintln(os.Stderr, console.FormatWarningMessage("Workflow with template injection risks written to: "+console.ToRelativePath(invalidFile)))
}
return "", formattedErr
}
return "", formattedErr
}

// Validate against GitHub Actions schema (unless skipped)
if !c.skipValidation {
if needsSchemaCheck {
log.Print("Validating workflow against GitHub Actions schema")
if err := c.validateGitHubActionsSchema(yamlContent); err != nil {
var schemaErr error
if parsedWorkflow != nil {
schemaErr = c.validateGitHubActionsSchemaFromParsed(parsedWorkflow)
} else {
schemaErr = c.validateGitHubActionsSchema(yamlContent)
}
if schemaErr != nil {
// Try to point at the exact line of the failing field in the source markdown.
// extractSchemaErrorField unwraps the error chain to find the top-level field
// name (e.g. "timeout-minutes"), which findFrontmatterFieldLine then locates in
// the source frontmatter so the error is IDE-navigable.
fieldLine := 1
if fieldName := extractSchemaErrorField(err); fieldName != "" {
if fieldName := extractSchemaErrorField(schemaErr); fieldName != "" {
frontmatterLines := strings.Split(workflowData.FrontmatterYAML, "\n")
if line := findFrontmatterFieldLine(frontmatterLines, 2, fieldName); line > 0 {
fieldLine = line
}
}
// Store error first so we can write invalid YAML before returning
formattedErr := formatCompilerErrorWithPosition(markdownPath, fieldLine, 1, "error",
fmt.Sprintf("invalid workflow: %v", err), err)
fmt.Sprintf("invalid workflow: %v", schemaErr), schemaErr)
// Write the invalid YAML to a .invalid.yml file for inspection
invalidFile := strings.TrimSuffix(lockFile, ".lock.yml") + ".invalid.yml"
if writeErr := os.WriteFile(invalidFile, []byte(yamlContent), 0644); writeErr == nil {
Expand Down
7 changes: 7 additions & 0 deletions pkg/workflow/schema_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ func (c *Compiler) validateGitHubActionsSchema(yamlContent string) error {
return fmt.Errorf("failed to parse YAML for schema validation: %w", err)
}

return c.validateGitHubActionsSchemaFromParsed(workflowData)
}

// validateGitHubActionsSchemaFromParsed validates pre-parsed workflow data against the
// GitHub Actions schema. Callers that already hold a parsed representation of the
// compiled YAML should use this variant to avoid an extra yaml.Unmarshal call.
func (c *Compiler) validateGitHubActionsSchemaFromParsed(workflowData any) error {
// Get the cached compiled schema
schema, err := getCompiledSchema()
if err != nil {
Expand Down
14 changes: 10 additions & 4 deletions pkg/workflow/secret_extraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ var secretLog = logger.New("workflow:secret_extraction")
// Matches: ${{ secrets.SECRET_NAME }} or ${{ secrets.SECRET_NAME || 'default' }}
var secretExprPattern = regexp.MustCompile(`\$\{\{\s*secrets\.([A-Z_][A-Z0-9_]*)\s*(?:\|\|.*?)?\s*\}\}`)

// Pre-compiled regex patterns for ExtractSecretsFromValue (performance optimization)
var (
// secretsExprFindPattern matches all ${{ ... }} expressions in a value
secretsExprFindPattern = regexp.MustCompile(`\$\{\{[^}]+\}\}`)
// secretsNamePattern extracts the secret variable name from an expression
secretsNamePattern = regexp.MustCompile(`secrets\.([A-Z_][A-Z0-9_]*)`)
)

// SecretExpression represents a parsed secret expression
type SecretExpression struct {
VarName string // The secret variable name (e.g., "DD_API_KEY")
Expand Down Expand Up @@ -47,15 +55,13 @@ func ExtractSecretsFromValue(value string) map[string]string {

// Find all ${{ ... }} expressions in the value
// Pattern matches from ${{ to }} allowing nested content
exprPattern := regexp.MustCompile(`\$\{\{[^}]+\}\}`)
expressions := exprPattern.FindAllString(value, -1)
expressions := secretsExprFindPattern.FindAllString(value, -1)

// For each expression, check if it contains secrets.VARIABLE_NAME
// This handles both simple cases like "${{ secrets.TOKEN }}"
// and complex sub-expressions like "${{ github.workflow && secrets.TOKEN }}"
secretPattern := regexp.MustCompile(`secrets\.([A-Z_][A-Z0-9_]*)`)
for _, expr := range expressions {
matches := secretPattern.FindAllStringSubmatch(expr, -1)
matches := secretsNamePattern.FindAllStringSubmatch(expr, -1)
for _, match := range matches {
if len(match) >= 2 {
varName := match[1]
Expand Down
18 changes: 18 additions & 0 deletions pkg/workflow/template_injection_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ var (
func validateNoTemplateInjection(yamlContent string) error {
templateInjectionValidationLog.Print("Validating compiled YAML for template injection risks")

// Fast-path: if the YAML contains no unsafe context expressions at all, skip the
// expensive full YAML parse. The unsafe patterns we detect are:
// ${{ github.event.* }}, ${{ steps.*.outputs.* }}, ${{ inputs.* }}
// If none of those strings appear anywhere in the compiled YAML, there can be
// no violations.
if !unsafeContextRegex.MatchString(yamlContent) {
templateInjectionValidationLog.Print("No unsafe context expressions found – skipping template injection check")
return nil
}

// Parse YAML to walk the tree and extract run fields
var workflow map[string]any
if err := yaml.Unmarshal([]byte(yamlContent), &workflow); err != nil {
Expand All @@ -82,6 +92,14 @@ func validateNoTemplateInjection(yamlContent string) error {
return nil
}

return validateNoTemplateInjectionFromParsed(workflow)
}

// validateNoTemplateInjectionFromParsed checks a pre-parsed workflow map for template
// injection vulnerabilities. It is called by validateNoTemplateInjection (which
// handles the YAML parse) and may also be called directly when the caller already
// holds a parsed representation of the compiled YAML, avoiding a redundant parse.
func validateNoTemplateInjectionFromParsed(workflow map[string]any) error {
// Extract all run blocks from the workflow
runBlocks := extractRunBlocks(workflow)
templateInjectionValidationLog.Printf("Found %d run blocks to scan", len(runBlocks))
Expand Down
24 changes: 16 additions & 8 deletions pkg/workflow/yaml.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,20 @@ import (
"sort"
"strconv"
"strings"
"sync"

"github.com/github/gh-aw/pkg/logger"
"github.com/goccy/go-yaml"
)

var yamlLog = logger.New("workflow:yaml")

// yamlNullPattern matches `: null` at the end of a line (pre-compiled for performance)
var yamlNullPattern = regexp.MustCompile(`:\s*null\s*$`)

// unquoteYAMLKeyCache caches compiled regexes for UnquoteYAMLKey by key name
var unquoteYAMLKeyCache sync.Map

// UnquoteYAMLKey removes quotes from a YAML key at the start of a line.
//
// The YAML marshaler automatically adds quotes around YAML reserved words and keywords
Expand Down Expand Up @@ -131,9 +138,14 @@ func UnquoteYAMLKey(yamlStr string, key string) string {
// Pattern: (start of line or newline) + (optional whitespace) + quoted key + colon
pattern := `(^|\n)([ \t]*)"` + regexp.QuoteMeta(key) + `":`

// Replacement: keep the line start and whitespace, but remove quotes from the key
// Need to use ReplaceAllStringFunc to properly construct the replacement
re := regexp.MustCompile(pattern)
// Use cached compiled regex to avoid recompiling on every call
var re *regexp.Regexp
if cached, ok := unquoteYAMLKeyCache.Load(key); ok {
re = cached.(*regexp.Regexp)
} else {
re = regexp.MustCompile(pattern)
unquoteYAMLKeyCache.Store(key, re)
}
return re.ReplaceAllStringFunc(yamlStr, func(match string) string {
// Find the submatch groups
submatches := re.FindStringSubmatch(match)
Expand Down Expand Up @@ -296,14 +308,10 @@ func OrderMapFields(data map[string]any, priorityFields []string) yaml.MapSlice
func CleanYAMLNullValues(yamlStr string) string {
yamlLog.Print("Cleaning null values from YAML")

// Create a regex pattern that matches `: null` at the end of a line
// Pattern: colon + optional whitespace + "null" + optional whitespace + end of line
pattern := regexp.MustCompile(`:\s*null\s*$`)

// Split into lines, process each line, and rejoin
lines := strings.Split(yamlStr, "\n")
for i, line := range lines {
lines[i] = pattern.ReplaceAllString(line, ":")
lines[i] = yamlNullPattern.ReplaceAllString(line, ":")
}

return strings.Join(lines, "\n")
Expand Down
Loading