Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 204 additions & 6 deletions pkg/workflow/allowed_domains_sanitization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package workflow
import (
"os"
"path/filepath"
"slices"
"strings"
"testing"

Expand All @@ -13,6 +14,22 @@ import (
"github.com/github/gh-aw/pkg/testutil"
)

// extractQuotedCSV returns the comma-separated domain list embedded inside
// the first pair of double-quotes in line. Used to enable exact-entry checks
// (avoiding substring false-positives like "corp.example.com" matching "copilot.corp.example.com").
func extractQuotedCSV(line string) string {
start := strings.Index(line, `"`)
if start < 0 {
return line
}
rest := line[start+1:]
end := strings.Index(rest, `"`)
if end < 0 {
return rest
}
return rest[:end]
}

// TestAllowedDomainsFromNetworkConfig tests that GH_AW_ALLOWED_DOMAINS is computed
// from network configuration for sanitization
func TestAllowedDomainsFromNetworkConfig(t *testing.T) {
Expand Down Expand Up @@ -374,10 +391,12 @@ Test that empty allowed-domains falls back to network config.
// TestComputeAllowedDomainsForSanitization tests the computeAllowedDomainsForSanitization function
func TestComputeAllowedDomainsForSanitization(t *testing.T) {
tests := []struct {
name string
engineID string
networkPerms *NetworkPermissions
expectedDomains []string
name string
engineID string
apiTarget string
networkPerms *NetworkPermissions
expectedDomains []string
unexpectedDomains []string
}{
{
name: "Copilot with custom domains",
Expand Down Expand Up @@ -434,6 +453,30 @@ func TestComputeAllowedDomainsForSanitization(t *testing.T) {
"example.com",
},
},
{
name: "Copilot with GHES api-target includes api and base domains",
engineID: "copilot",
apiTarget: "api.acme.ghe.com",
networkPerms: nil,
expectedDomains: []string{
"api.acme.ghe.com", // GHES API domain
"acme.ghe.com", // GHES base domain (derived from api-target)
"api.github.com", // Copilot default
"github.com", // Copilot default
},
},
{
name: "non-api prefix api-target only adds the configured hostname",
engineID: "copilot",
apiTarget: "copilot.corp.example.com",
networkPerms: nil,
expectedDomains: []string{
"copilot.corp.example.com", // configured hostname
},
unexpectedDomains: []string{
"corp.example.com", // base hostname should NOT be added for non-api. prefix
},
},
}

for _, tt := range tests {
Expand All @@ -442,20 +485,175 @@ func TestComputeAllowedDomainsForSanitization(t *testing.T) {
compiler := NewCompiler()
data := &WorkflowData{
EngineConfig: &EngineConfig{
ID: tt.engineID,
ID: tt.engineID,
APITarget: tt.apiTarget,
},
NetworkPermissions: tt.networkPerms,
}

// Call the function
domainsStr := compiler.computeAllowedDomainsForSanitization(data)

// Verify expected domains are present
// Verify expected domains are present (substring match is fine here since domain names
// in a CSV string that are exact entries won't appear as substrings of other entries
// when checking expected ones – we only need exact match for the negative "not present" check)
for _, expectedDomain := range tt.expectedDomains {
if !strings.Contains(domainsStr, expectedDomain) {
t.Errorf("Expected domain '%s' not found in result: %s", expectedDomain, domainsStr)
}
}

// Verify unexpected domains are absent using exact membership (not substring)
// to avoid false positives where "corp.example.com" matches "copilot.corp.example.com"
parts := strings.Split(domainsStr, ",")
for _, unexpectedDomain := range tt.unexpectedDomains {
if slices.Contains(parts, unexpectedDomain) {
t.Errorf("Unexpected domain '%s' found in result: %s", unexpectedDomain, domainsStr)
}
}
})
}
}

// TestAPITargetDomainsInCompiledWorkflow is a regression test verifying that when engine.api-target
// is configured, both --allow-domains (AWF firewall flag) and GH_AW_ALLOWED_DOMAINS (sanitization
// env var) in the compiled lock file contain the api-target hostname and its derived base hostname.
func TestAPITargetDomainsInCompiledWorkflow(t *testing.T) {
tests := []struct {
name string
workflow string
expectedDomains []string
unexpectedDomains []string
}{
{
name: "GHES api-target adds api and base domains to allow-domains and GH_AW_ALLOWED_DOMAINS",
workflow: `---
on: push
permissions:
contents: read
issues: read
pull-requests: read
engine:
id: copilot
api-target: api.acme.ghe.com
strict: false
safe-outputs:
create-issue:
---

# Test Workflow

Test workflow with GHES api-target.
`,
expectedDomains: []string{
"api.acme.ghe.com", // GHES API domain
"acme.ghe.com", // GHES base domain derived from api-target
"api.github.com", // Copilot default
"github.com", // Copilot default
},
},
{
name: "non-api prefix api-target only adds the configured hostname",
workflow: `---
on: push
permissions:
contents: read
issues: read
pull-requests: read
engine:
id: copilot
api-target: copilot.corp.example.com
strict: false
safe-outputs:
create-issue:
---

# Test Workflow

Test workflow with non-api prefix api-target.
`,
expectedDomains: []string{
"copilot.corp.example.com", // configured hostname
},
unexpectedDomains: []string{
"corp.example.com", // base hostname should NOT be added for non-api. prefix
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := testutil.TempDir(t, "api-target-domains-test")
testFile := filepath.Join(tmpDir, "test-workflow.md")
if err := os.WriteFile(testFile, []byte(tt.workflow), 0644); err != nil {
t.Fatal(err)
}

compiler := NewCompiler()
if err := compiler.CompileWorkflow(testFile); err != nil {
t.Fatalf("Failed to compile workflow: %v", err)
}

lockFile := stringutil.MarkdownToLockFile(testFile)
lockContent, err := os.ReadFile(lockFile)
if err != nil {
t.Fatalf("Failed to read lock file: %v", err)
}
lockStr := string(lockContent)

// Check --allow-domains in AWF command contains expected domains
allowDomainsIdx := strings.Index(lockStr, "--allow-domains")
if allowDomainsIdx < 0 {
t.Fatal("--allow-domains flag not found in compiled lock file")
}
// Extract the line with --allow-domains for more targeted checking
allowDomainsEnd := strings.Index(lockStr[allowDomainsIdx:], "\n")
if allowDomainsEnd < 0 {
allowDomainsEnd = len(lockStr) - allowDomainsIdx
}
allowDomainsLine := lockStr[allowDomainsIdx : allowDomainsIdx+allowDomainsEnd]

for _, domain := range tt.expectedDomains {
if !strings.Contains(allowDomainsLine, domain) {
t.Errorf("Expected domain %q not found in --allow-domains.\nLine: %s", domain, allowDomainsLine)
}
}
// Use exact CSV membership for "not present" checks to avoid false positives
// (e.g. "corp.example.com" would substring-match "copilot.corp.example.com")
allowedDomainsCSV := extractQuotedCSV(allowDomainsLine)
allowedParts := strings.Split(allowedDomainsCSV, ",")
for _, domain := range tt.unexpectedDomains {
if slices.Contains(allowedParts, domain) {
t.Errorf("Unexpected domain %q found in --allow-domains.\nLine: %s", domain, allowDomainsLine)
}
}

// Check GH_AW_ALLOWED_DOMAINS env var contains expected domains
lines := strings.Split(lockStr, "\n")
var domainsLine string
for _, line := range lines {
if strings.Contains(line, "GH_AW_ALLOWED_DOMAINS:") {
domainsLine = line
break
}
}
if domainsLine == "" {
t.Fatal("GH_AW_ALLOWED_DOMAINS not found in compiled lock file")
}

for _, domain := range tt.expectedDomains {
if !strings.Contains(domainsLine, domain) {
t.Errorf("Expected domain %q not found in GH_AW_ALLOWED_DOMAINS.\nLine: %s", domain, domainsLine)
}
}
// Use exact CSV membership for "not present" checks
allowedDomainsEnvCSV := extractQuotedCSV(domainsLine)
allowedEnvParts := strings.Split(allowedDomainsEnvCSV, ",")
for _, domain := range tt.unexpectedDomains {
if slices.Contains(allowedEnvParts, domain) {
t.Errorf("Unexpected domain %q found in GH_AW_ALLOWED_DOMAINS.\nLine: %s", domain, domainsLine)
}
}
})
}
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/workflow/claude_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ func (e *ClaudeEngine) GetExecutionSteps(workflowData *WorkflowData, logFile str
// Build the AWF-wrapped command using helper function
// Get allowed domains (Claude defaults + network permissions + HTTP MCP server URLs + runtime ecosystem domains)
allowedDomains := GetClaudeAllowedDomainsWithToolsAndRuntimes(workflowData.NetworkPermissions, workflowData.Tools, workflowData.Runtimes)
// Add GHES/custom API target domains to the firewall allow-list when engine.api-target is set
if workflowData.EngineConfig != nil && workflowData.EngineConfig.APITarget != "" {
allowedDomains = mergeAPITargetDomains(allowedDomains, workflowData.EngineConfig.APITarget)
}

// Build AWF command with all configuration
// AWF v0.15.0+ uses chroot mode by default, providing transparent access to host binaries
Expand Down
4 changes: 4 additions & 0 deletions pkg/workflow/codex_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ func (e *CodexEngine) GetExecutionSteps(workflowData *WorkflowData, logFile stri
// Build AWF-wrapped command using helper function
// Get allowed domains (Codex defaults + network permissions + HTTP MCP server URLs + runtime ecosystem domains)
allowedDomains := GetCodexAllowedDomainsWithToolsAndRuntimes(workflowData.NetworkPermissions, workflowData.Tools, workflowData.Runtimes)
// Add GHES/custom API target domains to the firewall allow-list when engine.api-target is set
if workflowData.EngineConfig != nil && workflowData.EngineConfig.APITarget != "" {
allowedDomains = mergeAPITargetDomains(allowedDomains, workflowData.EngineConfig.APITarget)
}

// Build the command with agent file handling if specified
// INSTRUCTION reading is done inside the AWF command to avoid Docker Compose interpolation
Expand Down
4 changes: 4 additions & 0 deletions pkg/workflow/copilot_engine_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ func (e *CopilotEngine) GetExecutionSteps(workflowData *WorkflowData, logFile st
// Build AWF-wrapped command using helper function - no mkdir needed, AWF handles it
// Get allowed domains (copilot defaults + network permissions + HTTP MCP server URLs + runtime ecosystem domains)
allowedDomains := GetCopilotAllowedDomainsWithToolsAndRuntimes(workflowData.NetworkPermissions, workflowData.Tools, workflowData.Runtimes)
// Add GHES/custom API target domains to the firewall allow-list when engine.api-target is set
if workflowData.EngineConfig != nil && workflowData.EngineConfig.APITarget != "" {
allowedDomains = mergeAPITargetDomains(allowedDomains, workflowData.EngineConfig.APITarget)
}

// AWF v0.15.0+ uses chroot mode by default, providing transparent access to host binaries
// AWF v0.15.0+ with --env-all handles PATH natively (chroot mode is default):
Expand Down
74 changes: 69 additions & 5 deletions pkg/workflow/domains.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,62 @@ func formatBlockedDomains(network *NetworkPermissions) string {
return strings.Join(blockedDomains, ",")
}

// GetAPITargetDomains returns the set of domains to add to the allow-list when engine.api-target is set.
// For a GHES instance with api-target "api.acme.ghe.com", this returns both the API domain
// ("api.acme.ghe.com") and the base hostname ("acme.ghe.com") so that both the GitHub web UI
// and API requests pass through the firewall without manual lock file edits.
// Returns nil for empty apiTarget.
func GetAPITargetDomains(apiTarget string) []string {
if apiTarget == "" {
return nil
}

domains := []string{apiTarget}

// Derive the base hostname by stripping the first subdomain label, but only for
// API-style hostnames that start with "api.".
// e.g., "api.acme.ghe.com" → "acme.ghe.com"
// Only add the base hostname if it still looks like a multi-label hostname (contains a dot).
if strings.HasPrefix(apiTarget, "api.") {
if idx := strings.Index(apiTarget, "."); idx > 0 {
baseHost := apiTarget[idx+1:]
if strings.Contains(baseHost, ".") && baseHost != apiTarget {
domains = append(domains, baseHost)
}
}
}

return domains
}

// mergeAPITargetDomains merges the api-target domains into an existing comma-separated domain string.
// When engine.api-target is set, both the API hostname and its base hostname are added to the allow-list.
// Returns the original string unchanged when apiTarget is empty.
func mergeAPITargetDomains(domainsStr string, apiTarget string) string {
extraDomains := GetAPITargetDomains(apiTarget)
if len(extraDomains) == 0 {
return domainsStr
}

domainMap := make(map[string]bool)
for d := range strings.SplitSeq(domainsStr, ",") {
d = strings.TrimSpace(d)
if d != "" {
domainMap[d] = true
}
}
for _, d := range extraDomains {
domainMap[d] = true
}

result := make([]string, 0, len(domainMap))
for d := range domainMap {
result = append(result, d)
}
sort.Strings(result)
return strings.Join(result, ",")
}

// computeAllowedDomainsForSanitization computes the allowed domains for sanitization
// based on the engine and network configuration, matching what's provided to the firewall
func (c *Compiler) computeAllowedDomainsForSanitization(data *WorkflowData) string {
Expand All @@ -655,20 +711,28 @@ func (c *Compiler) computeAllowedDomainsForSanitization(data *WorkflowData) stri

// Compute domains based on engine type, including tools and runtimes to match
// what's provided to the actual firewall at runtime
var base string
switch engineID {
case "copilot":
return GetCopilotAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes)
base = GetCopilotAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes)
case "codex":
return GetCodexAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes)
base = GetCodexAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes)
case "claude":
return GetClaudeAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes)
base = GetClaudeAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes)
case "gemini":
return GetGeminiAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes)
base = GetGeminiAllowedDomainsWithToolsAndRuntimes(data.NetworkPermissions, data.Tools, data.Runtimes)
default:
// For other engines, use network permissions only
domains := GetAllowedDomains(data.NetworkPermissions)
return strings.Join(domains, ",")
base = strings.Join(domains, ",")
}

// Add GHES/custom API target domains so GH_AW_ALLOWED_DOMAINS stays in sync with --allow-domains
if data.EngineConfig != nil && data.EngineConfig.APITarget != "" {
base = mergeAPITargetDomains(base, data.EngineConfig.APITarget)
}

return base
Comment on lines +730 to +735
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot apply changes based on this feedback

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added regression tests in commit 0872ff5:

  • Extended TestComputeAllowedDomainsForSanitization with two new table-driven cases: one asserting api.acme.ghe.com + acme.ghe.com appear in the sanitization domain set, and one asserting the base hostname is not added for non-api. prefixed targets.
  • Added TestAPITargetDomainsInCompiledWorkflow (integration, //go:build integration) that compiles a full workflow with engine.api-target configured and verifies both --allow-domains in the AWF command and GH_AW_ALLOWED_DOMAINS in the env block contain the expected domains. Uses exact CSV-membership checks (via extractQuotedCSV helper) for the "not present" assertions to avoid substring false-positives.

}

// expandAllowedDomains expands a list of domain entries (which may include ecosystem
Expand Down
Loading