Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions pkg/sliceutil/sliceutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,36 @@ func ContainsAny(s string, substrings ...string) bool {
func ContainsIgnoreCase(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}

// Filter returns a new slice containing only elements that match the predicate.
// This is a pure function that does not modify the input slice.
func Filter[T any](slice []T, predicate func(T) bool) []T {
result := make([]T, 0, len(slice))
for _, item := range slice {
if predicate(item) {
result = append(result, item)
}
}
return result
}

// Map transforms each element in a slice using the provided function.
// This is a pure function that does not modify the input slice.
func Map[T, U any](slice []T, transform func(T) U) []U {
result := make([]U, len(slice))
for i, item := range slice {
result[i] = transform(item)
}
return result
}

// MapToSlice converts a map's keys to a slice.
// The order of elements is not guaranteed as map iteration order is undefined.
// This is a pure function that does not modify the input map.
func MapToSlice[K comparable, V any](m map[K]V) []K {
result := make([]K, 0, len(m))
for key := range m {
result = append(result, key)
}
return result
}
82 changes: 39 additions & 43 deletions pkg/workflow/action_pins.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/githubnext/gh-aw/pkg/console"
"github.com/githubnext/gh-aw/pkg/logger"
"github.com/githubnext/gh-aw/pkg/sliceutil"
)

var actionPinsLog = logger.New("workflow:action_pins")
Expand Down Expand Up @@ -85,7 +86,7 @@ func getActionPins() []ActionPin {
actionPinsLog.Printf("Found %d key/version mismatches in action_pins.json", mismatchCount)
}

// Convert map to sorted slice
// Convert map values to slice - immutable initialization
pins := make([]ActionPin, 0, len(data.Entries))
for _, pin := range data.Entries {
pins = append(pins, pin)
Expand All @@ -109,16 +110,23 @@ func getActionPins() []ActionPin {
return cachedActionPins
}

// sortPinsByVersion sorts action pins by version in descending order (highest first)
// Uses Go's standard library sort with custom comparison function
func sortPinsByVersion(pins []ActionPin) {
sort.Slice(pins, func(i, j int) bool {
// sortPinsByVersion sorts action pins by version in descending order (highest first).
// This function returns a new sorted slice without modifying the input.
// This is an immutable operation for better safety and clarity.
func sortPinsByVersion(pins []ActionPin) []ActionPin {
// Create a copy to avoid mutating the input
result := make([]ActionPin, len(pins))
copy(result, pins)

sort.Slice(result, func(i, j int) bool {
// Strip 'v' prefix for comparison
v1 := strings.TrimPrefix(pins[i].Version, "v")
v2 := strings.TrimPrefix(pins[j].Version, "v")
v1 := strings.TrimPrefix(result[i].Version, "v")
v2 := strings.TrimPrefix(result[j].Version, "v")
// Return true if v1 > v2 to get descending order
return compareVersions(v1, v2) > 0
})

return result
}

// GetActionPin returns the pinned action reference for a given action repository
Expand All @@ -128,24 +136,21 @@ func sortPinsByVersion(pins []ActionPin) {
func GetActionPin(actionRepo string) string {
actionPins := getActionPins()

// Find all pins matching the repo
var matchingPins []ActionPin
for _, pin := range actionPins {
if pin.Repo == actionRepo {
matchingPins = append(matchingPins, pin)
}
}
// Find all pins matching the repo - using functional filter
matchingPins := sliceutil.Filter(actionPins, func(pin ActionPin) bool {
return pin.Repo == actionRepo
})

if len(matchingPins) == 0 {
// If no pin exists, return empty string to signal that this action is not pinned
return ""
}

// Sort matching pins by version (descending - latest first)
sortPinsByVersion(matchingPins)
// Sort matching pins by version (descending - latest first) - immutable operation
sortedPins := sortPinsByVersion(matchingPins)

// Return the latest version (first after sorting)
latestPin := matchingPins[0]
latestPin := sortedPins[0]
return formatActionReference(actionRepo, latestPin.SHA, latestPin.Version)
}

Expand Down Expand Up @@ -185,22 +190,19 @@ func GetActionPinWithData(actionRepo, version string, data *WorkflowData) (strin
actionPinsLog.Printf("Falling back to hardcoded pins for %s@%s", actionRepo, version)
actionPins := getActionPins()

// Find all pins matching the repo
var matchingPins []ActionPin
for _, pin := range actionPins {
if pin.Repo == actionRepo {
matchingPins = append(matchingPins, pin)
}
}
// Find all pins matching the repo - using functional filter
matchingPins := sliceutil.Filter(actionPins, func(pin ActionPin) bool {
return pin.Repo == actionRepo
})

if len(matchingPins) == 0 {
// No pins found for this repo, will handle below
actionPinsLog.Printf("No hardcoded pins found for %s", actionRepo)
} else {
actionPinsLog.Printf("Found %d hardcoded pin(s) for %s", len(matchingPins), actionRepo)

// Sort matching pins by version (descending - highest first)
sortPinsByVersion(matchingPins)
// Sort matching pins by version (descending - highest first) - immutable operation
matchingPins = sortPinsByVersion(matchingPins)

// First, try to find an exact version match (for version tags)
for _, pin := range matchingPins {
Expand Down Expand Up @@ -228,13 +230,10 @@ func GetActionPinWithData(actionRepo, version string, data *WorkflowData) (strin
// Semver compatibility means respecting major version boundaries
// (e.g., v5 -> highest v5.x.x, not v6.x.x)
if !data.StrictMode && len(matchingPins) > 0 {
// Filter for semver-compatible pins (matching major version)
var compatiblePins []ActionPin
for _, pin := range matchingPins {
if isSemverCompatible(pin.Version, version) {
compatiblePins = append(compatiblePins, pin)
}
}
// Filter for semver-compatible pins (matching major version) - using functional filter
compatiblePins := sliceutil.Filter(matchingPins, func(pin ActionPin) bool {
return isSemverCompatible(pin.Version, version)
})

// If we found compatible pins, use the highest one (first after sorting)
// Otherwise fall back to the highest overall pin
Expand Down Expand Up @@ -417,21 +416,18 @@ func ApplyActionPinsToTypedSteps(steps []*WorkflowStep, data *WorkflowData) []*W
func GetActionPinByRepo(repo string) (ActionPin, bool) {
actionPins := getActionPins()

// Find all pins matching the repo
var matchingPins []ActionPin
for _, pin := range actionPins {
if pin.Repo == repo {
matchingPins = append(matchingPins, pin)
}
}
// Find all pins matching the repo - using functional filter
matchingPins := sliceutil.Filter(actionPins, func(pin ActionPin) bool {
return pin.Repo == repo
})

if len(matchingPins) == 0 {
return ActionPin{}, false
}

// Sort matching pins by version (descending - latest first)
sortPinsByVersion(matchingPins)
// Sort matching pins by version (descending - latest first) - immutable operation
sortedPins := sortPinsByVersion(matchingPins)

// Return the latest version (first after sorting)
return matchingPins[0], true
return sortedPins[0], true
}
7 changes: 2 additions & 5 deletions pkg/workflow/action_pins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -831,11 +831,8 @@ func TestSortPinsByVersion(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Make a copy to avoid modifying the test case
result := make([]ActionPin, len(tt.input))
copy(result, tt.input)

sortPinsByVersion(result)
// sortPinsByVersion now returns a new sorted slice (immutable operation)
result := sortPinsByVersion(tt.input)

if len(result) != len(tt.expected) {
t.Errorf("sortPinsByVersion() length = %d, want %d", len(result), len(tt.expected))
Expand Down
8 changes: 3 additions & 5 deletions pkg/workflow/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strings"

"github.com/githubnext/gh-aw/pkg/logger"
"github.com/githubnext/gh-aw/pkg/sliceutil"
)

var envLog = logger.New("workflow:env")
Expand All @@ -20,11 +21,8 @@ func writeHeadersToYAML(yaml *strings.Builder, headers map[string]string, indent

envLog.Printf("Writing %d headers to YAML", len(headers))

// Sort keys for deterministic output
keys := make([]string, 0, len(headers))
for key := range headers {
keys = append(keys, key)
}
// Sort keys for deterministic output - using functional helper
keys := sliceutil.MapToSlice(headers)
sort.Strings(keys)

// Write each header with proper comma placement
Expand Down
19 changes: 7 additions & 12 deletions pkg/workflow/mcp_config_custom.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/githubnext/gh-aw/pkg/console"
"github.com/githubnext/gh-aw/pkg/logger"
"github.com/githubnext/gh-aw/pkg/parser"
"github.com/githubnext/gh-aw/pkg/sliceutil"
"github.com/githubnext/gh-aw/pkg/types"
)

Expand Down Expand Up @@ -316,10 +317,8 @@ func renderSharedMCPConfig(yaml *strings.Builder, toolName string, toolConfig ma
case "env":
if renderer.Format == "toml" {
fmt.Fprintf(yaml, "%senv = { ", renderer.IndentLevel)
envKeys := make([]string, 0, len(mcpConfig.Env))
for key := range mcpConfig.Env {
envKeys = append(envKeys, key)
}
// Using functional helper to extract map keys
envKeys := sliceutil.MapToSlice(mcpConfig.Env)
sort.Strings(envKeys)
for i, envKey := range envKeys {
if i > 0 {
Expand Down Expand Up @@ -390,10 +389,8 @@ func renderSharedMCPConfig(yaml *strings.Builder, toolName string, toolConfig ma
// TOML format for HTTP headers (Codex style)
if len(mcpConfig.Headers) > 0 {
fmt.Fprintf(yaml, "%shttp_headers = { ", renderer.IndentLevel)
headerKeys := make([]string, 0, len(mcpConfig.Headers))
for key := range mcpConfig.Headers {
headerKeys = append(headerKeys, key)
}
// Using functional helper to extract map keys
headerKeys := sliceutil.MapToSlice(mcpConfig.Headers)
sort.Strings(headerKeys)
for i, headerKey := range headerKeys {
if i > 0 {
Expand All @@ -409,10 +406,8 @@ func renderSharedMCPConfig(yaml *strings.Builder, toolName string, toolConfig ma
comma = ""
}
fmt.Fprintf(yaml, "%s\"headers\": {\n", renderer.IndentLevel)
headerKeys := make([]string, 0, len(mcpConfig.Headers))
for key := range mcpConfig.Headers {
headerKeys = append(headerKeys, key)
}
// Using functional helper to extract map keys
headerKeys := sliceutil.MapToSlice(mcpConfig.Headers)
sort.Strings(headerKeys)
for headerIndex, headerKey := range headerKeys {
headerComma := ","
Expand Down
31 changes: 10 additions & 21 deletions pkg/workflow/mcp_setup_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import (

"github.com/githubnext/gh-aw/pkg/constants"
"github.com/githubnext/gh-aw/pkg/logger"
"github.com/githubnext/gh-aw/pkg/sliceutil"
)

var mcpSetupGeneratorLog = logger.New("workflow:mcp_setup_generator")
Expand Down Expand Up @@ -329,10 +330,7 @@ func (c *Compiler) generateMCPSetup(yaml *strings.Builder, tools map[string]any,
yaml.WriteString(" run: |\n")

// Generate individual tool files (sorted by name for stable code generation)
safeInputToolNames := make([]string, 0, len(workflowData.SafeInputs.Tools))
for toolName := range workflowData.SafeInputs.Tools {
safeInputToolNames = append(safeInputToolNames, toolName)
}
safeInputToolNames := sliceutil.MapToSlice(workflowData.SafeInputs.Tools)
sort.Strings(safeInputToolNames)

for _, toolName := range safeInputToolNames {
Expand Down Expand Up @@ -408,11 +406,8 @@ func (c *Compiler) generateMCPSetup(yaml *strings.Builder, tools map[string]any,

safeInputsSecrets := collectSafeInputsSecrets(workflowData.SafeInputs)
if len(safeInputsSecrets) > 0 {
// Sort env var names for consistent output
envVarNames := make([]string, 0, len(safeInputsSecrets))
for envVarName := range safeInputsSecrets {
envVarNames = append(envVarNames, envVarName)
}
// Sort env var names for consistent output - using functional helper
envVarNames := sliceutil.MapToSlice(safeInputsSecrets)
sort.Strings(envVarNames)

for _, envVarName := range envVarNames {
Expand Down Expand Up @@ -452,10 +447,8 @@ func (c *Compiler) generateMCPSetup(yaml *strings.Builder, tools map[string]any,
yaml.WriteString(" env:\n")

// Sort environment variable names for consistent output
envVarNames := make([]string, 0, len(mcpEnvVars))
for envVarName := range mcpEnvVars {
envVarNames = append(envVarNames, envVarName)
}
// Using functional helper to extract map keys
envVarNames := sliceutil.MapToSlice(mcpEnvVars)
sort.Strings(envVarNames)

// Write environment variables in sorted order
Expand Down Expand Up @@ -522,10 +515,8 @@ func (c *Compiler) generateMCPSetup(yaml *strings.Builder, tools map[string]any,

// Add user-configured environment variables
if len(gatewayConfig.Env) > 0 {
envVarNames := make([]string, 0, len(gatewayConfig.Env))
for envVarName := range gatewayConfig.Env {
envVarNames = append(envVarNames, envVarName)
}
// Using functional helper to extract map keys
envVarNames := sliceutil.MapToSlice(gatewayConfig.Env)
sort.Strings(envVarNames)

for _, envVarName := range envVarNames {
Expand Down Expand Up @@ -612,10 +603,8 @@ func (c *Compiler) generateMCPSetup(yaml *strings.Builder, tools map[string]any,
containerCmd += " -e GH_AW_SAFE_OUTPUTS_API_KEY"
}
if len(gatewayConfig.Env) > 0 {
envVarNames := make([]string, 0, len(gatewayConfig.Env))
for envVarName := range gatewayConfig.Env {
envVarNames = append(envVarNames, envVarName)
}
// Using functional helper to extract map keys
envVarNames := sliceutil.MapToSlice(gatewayConfig.Env)
sort.Strings(envVarNames)
for _, envVarName := range envVarNames {
containerCmd += " -e " + envVarName
Expand Down
8 changes: 3 additions & 5 deletions pkg/workflow/sandbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"

"github.com/githubnext/gh-aw/pkg/logger"
"github.com/githubnext/gh-aw/pkg/sliceutil"
)

var sandboxLog = logger.New("workflow:sandbox")
Expand Down Expand Up @@ -166,11 +167,8 @@ func generateSRTConfigJSON(workflowData *WorkflowData) (string, error) {
}
}

// Convert to slice
allowedDomains := make([]string, 0, len(domainMap))
for domain := range domainMap {
allowedDomains = append(allowedDomains, domain)
}
// Convert map keys to slice - using functional helper
allowedDomains := sliceutil.MapToSlice(domainMap)
SortStrings(allowedDomains)

srtConfig := &SandboxRuntimeConfig{
Expand Down
Loading