// extractQuotedCSV returns the text enclosed by the first pair of double
// quotes in line. Callers split the result on "," to test exact entry
// membership instead of substring matches (where "corp.example.com" would
// wrongly match "copilot.corp.example.com").
//
// Fallbacks: when line contains no double quote the whole line is returned;
// when it contains an opening quote but no closing one, everything after the
// opening quote is returned.
func extractQuotedCSV(line string) string {
	_, afterOpen, found := strings.Cut(line, `"`)
	if !found {
		return line
	}
	// If there is no closing quote, Cut hands back afterOpen unchanged as the
	// "before" part, which is exactly the unterminated-quote fallback.
	quoted, _, _ := strings.Cut(afterOpen, `"`)
	return quoted
}
// TestComputeAllowedDomainsForSanitization tests the computeAllowedDomainsForSanitization function func TestComputeAllowedDomainsForSanitization(t *testing.T) { tests := []struct { - name string - engineID string - networkPerms *NetworkPermissions - expectedDomains []string + name string + engineID string + apiTarget string + networkPerms *NetworkPermissions + expectedDomains []string + unexpectedDomains []string }{ { name: "Copilot with custom domains", @@ -434,6 +453,30 @@ func TestComputeAllowedDomainsForSanitization(t *testing.T) { "example.com", }, }, + { + name: "Copilot with GHES api-target includes api and base domains", + engineID: "copilot", + apiTarget: "api.acme.ghe.com", + networkPerms: nil, + expectedDomains: []string{ + "api.acme.ghe.com", // GHES API domain + "acme.ghe.com", // GHES base domain (derived from api-target) + "api.github.com", // Copilot default + "github.com", // Copilot default + }, + }, + { + name: "non-api prefix api-target only adds the configured hostname", + engineID: "copilot", + apiTarget: "copilot.corp.example.com", + networkPerms: nil, + expectedDomains: []string{ + "copilot.corp.example.com", // configured hostname + }, + unexpectedDomains: []string{ + "corp.example.com", // base hostname should NOT be added for non-api. 
prefix + }, + }, } for _, tt := range tests { @@ -442,7 +485,8 @@ func TestComputeAllowedDomainsForSanitization(t *testing.T) { compiler := NewCompiler() data := &WorkflowData{ EngineConfig: &EngineConfig{ - ID: tt.engineID, + ID: tt.engineID, + APITarget: tt.apiTarget, }, NetworkPermissions: tt.networkPerms, } @@ -450,12 +494,166 @@ func TestComputeAllowedDomainsForSanitization(t *testing.T) { // Call the function domainsStr := compiler.computeAllowedDomainsForSanitization(data) - // Verify expected domains are present + // Verify expected domains are present (substring match is fine here since domain names + // in a CSV string that are exact entries won't appear as substrings of other entries + // when checking expected ones – we only need exact match for the negative "not present" check) for _, expectedDomain := range tt.expectedDomains { if !strings.Contains(domainsStr, expectedDomain) { t.Errorf("Expected domain '%s' not found in result: %s", expectedDomain, domainsStr) } } + + // Verify unexpected domains are absent using exact membership (not substring) + // to avoid false positives where "corp.example.com" matches "copilot.corp.example.com" + parts := strings.Split(domainsStr, ",") + for _, unexpectedDomain := range tt.unexpectedDomains { + if slices.Contains(parts, unexpectedDomain) { + t.Errorf("Unexpected domain '%s' found in result: %s", unexpectedDomain, domainsStr) + } + } + }) + } +} + +// TestAPITargetDomainsInCompiledWorkflow is a regression test verifying that when engine.api-target +// is configured, both --allow-domains (AWF firewall flag) and GH_AW_ALLOWED_DOMAINS (sanitization +// env var) in the compiled lock file contain the api-target hostname and its derived base hostname. 
+func TestAPITargetDomainsInCompiledWorkflow(t *testing.T) { + tests := []struct { + name string + workflow string + expectedDomains []string + unexpectedDomains []string + }{ + { + name: "GHES api-target adds api and base domains to allow-domains and GH_AW_ALLOWED_DOMAINS", + workflow: `--- +on: push +permissions: + contents: read + issues: read + pull-requests: read +engine: + id: copilot + api-target: api.acme.ghe.com +strict: false +safe-outputs: + create-issue: +--- + +# Test Workflow + +Test workflow with GHES api-target. +`, + expectedDomains: []string{ + "api.acme.ghe.com", // GHES API domain + "acme.ghe.com", // GHES base domain derived from api-target + "api.github.com", // Copilot default + "github.com", // Copilot default + }, + }, + { + name: "non-api prefix api-target only adds the configured hostname", + workflow: `--- +on: push +permissions: + contents: read + issues: read + pull-requests: read +engine: + id: copilot + api-target: copilot.corp.example.com +strict: false +safe-outputs: + create-issue: +--- + +# Test Workflow + +Test workflow with non-api prefix api-target. +`, + expectedDomains: []string{ + "copilot.corp.example.com", // configured hostname + }, + unexpectedDomains: []string{ + "corp.example.com", // base hostname should NOT be added for non-api. 
prefix + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := testutil.TempDir(t, "api-target-domains-test") + testFile := filepath.Join(tmpDir, "test-workflow.md") + if err := os.WriteFile(testFile, []byte(tt.workflow), 0644); err != nil { + t.Fatal(err) + } + + compiler := NewCompiler() + if err := compiler.CompileWorkflow(testFile); err != nil { + t.Fatalf("Failed to compile workflow: %v", err) + } + + lockFile := stringutil.MarkdownToLockFile(testFile) + lockContent, err := os.ReadFile(lockFile) + if err != nil { + t.Fatalf("Failed to read lock file: %v", err) + } + lockStr := string(lockContent) + + // Check --allow-domains in AWF command contains expected domains + allowDomainsIdx := strings.Index(lockStr, "--allow-domains") + if allowDomainsIdx < 0 { + t.Fatal("--allow-domains flag not found in compiled lock file") + } + // Extract the line with --allow-domains for more targeted checking + allowDomainsEnd := strings.Index(lockStr[allowDomainsIdx:], "\n") + if allowDomainsEnd < 0 { + allowDomainsEnd = len(lockStr) - allowDomainsIdx + } + allowDomainsLine := lockStr[allowDomainsIdx : allowDomainsIdx+allowDomainsEnd] + + for _, domain := range tt.expectedDomains { + if !strings.Contains(allowDomainsLine, domain) { + t.Errorf("Expected domain %q not found in --allow-domains.\nLine: %s", domain, allowDomainsLine) + } + } + // Use exact CSV membership for "not present" checks to avoid false positives + // (e.g. 
"corp.example.com" would substring-match "copilot.corp.example.com") + allowedDomainsCSV := extractQuotedCSV(allowDomainsLine) + allowedParts := strings.Split(allowedDomainsCSV, ",") + for _, domain := range tt.unexpectedDomains { + if slices.Contains(allowedParts, domain) { + t.Errorf("Unexpected domain %q found in --allow-domains.\nLine: %s", domain, allowDomainsLine) + } + } + + // Check GH_AW_ALLOWED_DOMAINS env var contains expected domains + lines := strings.Split(lockStr, "\n") + var domainsLine string + for _, line := range lines { + if strings.Contains(line, "GH_AW_ALLOWED_DOMAINS:") { + domainsLine = line + break + } + } + if domainsLine == "" { + t.Fatal("GH_AW_ALLOWED_DOMAINS not found in compiled lock file") + } + + for _, domain := range tt.expectedDomains { + if !strings.Contains(domainsLine, domain) { + t.Errorf("Expected domain %q not found in GH_AW_ALLOWED_DOMAINS.\nLine: %s", domain, domainsLine) + } + } + // Use exact CSV membership for "not present" checks + allowedDomainsEnvCSV := extractQuotedCSV(domainsLine) + allowedEnvParts := strings.Split(allowedDomainsEnvCSV, ",") + for _, domain := range tt.unexpectedDomains { + if slices.Contains(allowedEnvParts, domain) { + t.Errorf("Unexpected domain %q found in GH_AW_ALLOWED_DOMAINS.\nLine: %s", domain, domainsLine) + } + } }) } } diff --git a/pkg/workflow/claude_engine.go b/pkg/workflow/claude_engine.go index 33651bf37f5..1aba19274ed 100644 --- a/pkg/workflow/claude_engine.go +++ b/pkg/workflow/claude_engine.go @@ -293,6 +293,10 @@ func (e *ClaudeEngine) GetExecutionSteps(workflowData *WorkflowData, logFile str // Build the AWF-wrapped command using helper function // Get allowed domains (Claude defaults + network permissions + HTTP MCP server URLs + runtime ecosystem domains) allowedDomains := GetClaudeAllowedDomainsWithToolsAndRuntimes(workflowData.NetworkPermissions, workflowData.Tools, workflowData.Runtimes) + // Add GHES/custom API target domains to the firewall allow-list when 
engine.api-target is set + if workflowData.EngineConfig != nil && workflowData.EngineConfig.APITarget != "" { + allowedDomains = mergeAPITargetDomains(allowedDomains, workflowData.EngineConfig.APITarget) + } // Build AWF command with all configuration // AWF v0.15.0+ uses chroot mode by default, providing transparent access to host binaries diff --git a/pkg/workflow/codex_engine.go b/pkg/workflow/codex_engine.go index 3c546c018da..8cfd960e341 100644 --- a/pkg/workflow/codex_engine.go +++ b/pkg/workflow/codex_engine.go @@ -218,6 +218,10 @@ func (e *CodexEngine) GetExecutionSteps(workflowData *WorkflowData, logFile stri // Build AWF-wrapped command using helper function // Get allowed domains (Codex defaults + network permissions + HTTP MCP server URLs + runtime ecosystem domains) allowedDomains := GetCodexAllowedDomainsWithToolsAndRuntimes(workflowData.NetworkPermissions, workflowData.Tools, workflowData.Runtimes) + // Add GHES/custom API target domains to the firewall allow-list when engine.api-target is set + if workflowData.EngineConfig != nil && workflowData.EngineConfig.APITarget != "" { + allowedDomains = mergeAPITargetDomains(allowedDomains, workflowData.EngineConfig.APITarget) + } // Build the command with agent file handling if specified // INSTRUCTION reading is done inside the AWF command to avoid Docker Compose interpolation diff --git a/pkg/workflow/copilot_engine_execution.go b/pkg/workflow/copilot_engine_execution.go index ceeeb0d406e..f706145aa63 100644 --- a/pkg/workflow/copilot_engine_execution.go +++ b/pkg/workflow/copilot_engine_execution.go @@ -179,6 +179,10 @@ func (e *CopilotEngine) GetExecutionSteps(workflowData *WorkflowData, logFile st // Build AWF-wrapped command using helper function - no mkdir needed, AWF handles it // Get allowed domains (copilot defaults + network permissions + HTTP MCP server URLs + runtime ecosystem domains) allowedDomains := GetCopilotAllowedDomainsWithToolsAndRuntimes(workflowData.NetworkPermissions, 
// GetAPITargetDomains returns the set of domains to add to the allow-list when engine.api-target is set.
// For a GHES instance with api-target "api.acme.ghe.com", this returns both the API domain
// ("api.acme.ghe.com") and the base hostname ("acme.ghe.com") so that both the GitHub web UI
// and API requests pass through the firewall without manual lock file edits.
//
// The base hostname is only derived for targets that start with "api." and only
// when what remains is a well-formed multi-label hostname, so inputs such as
// "api.com", "api.com." or "api..com" never add a bare TLD or an entry with an
// empty label. apiTarget itself is always returned first, unmodified.
//
// Returns nil for empty apiTarget.
func GetAPITargetDomains(apiTarget string) []string {
	if apiTarget == "" {
		return nil
	}

	domains := []string{apiTarget}

	// e.g., "api.acme.ghe.com" → "acme.ghe.com"
	baseHost, isAPIHost := strings.CutPrefix(apiTarget, "api.")
	if isAPIHost && isMultiLabelHostname(baseHost) {
		domains = append(domains, baseHost)
	}

	return domains
}

// isMultiLabelHostname reports whether host consists of at least two non-empty
// dot-separated labels. One trailing dot (fully-qualified form) is tolerated
// and ignored for the purpose of counting labels.
func isMultiLabelHostname(host string) bool {
	host = strings.TrimSuffix(host, ".")
	if strings.HasPrefix(host, ".") || strings.HasSuffix(host, ".") || strings.Contains(host, "..") {
		// An empty label somewhere: not a usable hostname.
		return false
	}
	return strings.Contains(host, ".")
}
+// When engine.api-target is set, both the API hostname and its base hostname are added to the allow-list. +// Returns the original string unchanged when apiTarget is empty. +func mergeAPITargetDomains(domainsStr string, apiTarget string) string { + extraDomains := GetAPITargetDomains(apiTarget) + if len(extraDomains) == 0 { + return domainsStr + } + + domainMap := make(map[string]bool) + for d := range strings.SplitSeq(domainsStr, ",") { + d = strings.TrimSpace(d) + if d != "" { + domainMap[d] = true + } + } + for _, d := range extraDomains { + domainMap[d] = true + } + + result := make([]string, 0, len(domainMap)) + for d := range domainMap { + result = append(result, d) + } + sort.Strings(result) + return strings.Join(result, ",") +} + // computeAllowedDomainsForSanitization computes the allowed domains for sanitization // based on the engine and network configuration, matching what's provided to the firewall func (c *Compiler) computeAllowedDomainsForSanitization(data *WorkflowData) string { @@ -655,20 +711,28 @@ func (c *Compiler) computeAllowedDomainsForSanitization(data *WorkflowData) stri // Compute domains based on engine type, including tools and runtimes to match // what's provided to the actual firewall at runtime + var base string switch engineID { case "copilot": - return GetCopilotAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes) + base = GetCopilotAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes) case "codex": - return GetCodexAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes) + base = GetCodexAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes) case "claude": - return GetClaudeAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes) + base = GetClaudeAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes) case "gemini": - return 
GetGeminiAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes) + base = GetGeminiAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes) default: // For other engines, use network permissions only domains := GetAllowedDomains(data.NetworkPermissions) - return strings.Join(domains, ",") + base = strings.Join(domains, ",") } + + // Add GHES/custom API target domains so GH_AW_ALLOWED_DOMAINS stays in sync with --allow-domains + if data.EngineConfig != nil && data.EngineConfig.APITarget != "" { + base = mergeAPITargetDomains(base, data.EngineConfig.APITarget) + } + + return base } // expandAllowedDomains expands a list of domain entries (which may include ecosystem diff --git a/pkg/workflow/domains_test.go b/pkg/workflow/domains_test.go index dac7f4c1156..aa2a024fa85 100644 --- a/pkg/workflow/domains_test.go +++ b/pkg/workflow/domains_test.go @@ -1045,3 +1045,118 @@ func TestDefaultSafeOutputsEcosystem(t *testing.T) { t.Errorf("Expected %d unique domains in default-safe-outputs (union of components), got %d", len(expectedDomains), len(result)) } } + +func TestGetAPITargetDomains(t *testing.T) { + tests := []struct { + name string + apiTarget string + expected []string + }{ + { + name: "empty api-target returns nil", + apiTarget: "", + expected: nil, + }, + { + name: "GHES api-target with api. prefix returns both api and base domains", + apiTarget: "api.acme.ghe.com", + expected: []string{"api.acme.ghe.com", "acme.ghe.com"}, + }, + { + name: "GHES api-target custom domain", + apiTarget: "api.contoso-aw.ghe.com", + expected: []string{"api.contoso-aw.ghe.com", "contoso-aw.ghe.com"}, + }, + { + name: "enterprise githubcopilot.com api-target", + apiTarget: "api.enterprise.githubcopilot.com", + expected: []string{"api.enterprise.githubcopilot.com", "enterprise.githubcopilot.com"}, + }, + { + name: "non-api. 
prefix hostname returns only itself", + apiTarget: "copilot.example.com", + expected: []string{"copilot.example.com"}, + }, + { + name: "single label hostname (no dot) returns only itself", + apiTarget: "localhost", + expected: []string{"localhost"}, + }, + { + name: "two-label hostname does not add TLD alone", + apiTarget: "example.com", + expected: []string{"example.com"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetAPITargetDomains(tt.apiTarget) + if tt.expected == nil { + if result != nil { + t.Errorf("Expected nil, got %v", result) + } + return + } + if len(result) != len(tt.expected) { + t.Errorf("Expected %d domains %v, got %d domains %v", len(tt.expected), tt.expected, len(result), result) + return + } + for _, expected := range tt.expected { + if !slices.Contains(result, expected) { + t.Errorf("Expected domain %q not found in result %v", expected, result) + } + } + }) + } +} + +func TestMergeAPITargetDomains(t *testing.T) { + tests := []struct { + name string + domainsStr string + apiTarget string + wantIn []string + wantNotIn []string + }{ + { + name: "empty api-target leaves domains unchanged", + domainsStr: "github.com,api.github.com", + apiTarget: "", + wantIn: []string{"github.com", "api.github.com"}, + }, + { + name: "GHES api-target adds both api and base domains", + domainsStr: "github.com,api.github.com", + apiTarget: "api.acme.ghe.com", + wantIn: []string{"github.com", "api.github.com", "api.acme.ghe.com", "acme.ghe.com"}, + }, + { + name: "result is sorted and deduplicated", + domainsStr: "api.acme.ghe.com,github.com", + apiTarget: "api.acme.ghe.com", + wantIn: []string{"api.acme.ghe.com", "acme.ghe.com", "github.com"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := mergeAPITargetDomains(tt.domainsStr, tt.apiTarget) + domains := strings.Split(result, ",") + domainSet := make(map[string]bool) + for _, d := range domains { + domainSet[d] = true + } + for _, 
want := range tt.wantIn { + if !domainSet[want] { + t.Errorf("Expected domain %q in result %q, but not found", want, result) + } + } + for _, notWant := range tt.wantNotIn { + if domainSet[notWant] { + t.Errorf("Did not expect domain %q in result %q", notWant, result) + } + } + }) + } +} diff --git a/pkg/workflow/gemini_engine.go b/pkg/workflow/gemini_engine.go index 8987697fdab..6d1676480fe 100644 --- a/pkg/workflow/gemini_engine.go +++ b/pkg/workflow/gemini_engine.go @@ -240,6 +240,10 @@ func (e *GeminiEngine) GetExecutionSteps(workflowData *WorkflowData, logFile str workflowData.Tools, workflowData.Runtimes, ) + // Add GHES/custom API target domains to the firewall allow-list when engine.api-target is set + if workflowData.EngineConfig != nil && workflowData.EngineConfig.APITarget != "" { + allowedDomains = mergeAPITargetDomains(allowedDomains, workflowData.EngineConfig.APITarget) + } npmPathSetup := GetNpmBinPathSetup() geminiCommandWithPath := fmt.Sprintf("%s && %s", npmPathSetup, geminiCommand)