diff --git a/pkg/workflow/agentic_engine.go b/pkg/workflow/agentic_engine.go index 02183b0108f..8ae6eea6c03 100644 --- a/pkg/workflow/agentic_engine.go +++ b/pkg/workflow/agentic_engine.go @@ -378,7 +378,10 @@ func (e *BaseEngine) GetSecretValidationStep(workflowData *WorkflowData) GitHubA } // GetFirewallLogsCollectionStep returns an empty slice by default. -// Engines that need to copy session or firewall state files before secret redaction should override this. +// Firewall logs are written to a known location (/tmp/gh-aw/sandbox/firewall/logs/) +// and do not require a separate collection step. The method is still called from +// compiler_yaml_main_job.go to maintain a consistent interface; engines that need +// to copy session or firewall state files before secret redaction should override it. func (e *BaseEngine) GetFirewallLogsCollectionStep(workflowData *WorkflowData) []GitHubActionStep { return []GitHubActionStep{} } diff --git a/pkg/workflow/allowed_domains_sanitization_test.go b/pkg/workflow/allowed_domains_sanitization_test.go index aca1868927b..3f6d1b16e74 100644 --- a/pkg/workflow/allowed_domains_sanitization_test.go +++ b/pkg/workflow/allowed_domains_sanitization_test.go @@ -10,8 +10,8 @@ import ( "testing" "github.com/github/gh-aw/pkg/stringutil" - "github.com/github/gh-aw/pkg/testutil" + "github.com/stretchr/testify/require" ) // extractQuotedCSV returns the comma-separated domain list embedded inside @@ -492,7 +492,8 @@ func TestComputeAllowedDomainsForSanitization(t *testing.T) { } // Call the function - domainsStr := compiler.computeAllowedDomainsForSanitization(data) + domainsStr, err := compiler.computeAllowedDomainsForSanitization(data) + require.NoError(t, err, "computeAllowedDomainsForSanitization should not return an error for valid test data") // Verify expected domains are present (substring match is fine here since domain names // in a CSV string that are exact entries won't appear as substrings of other entries diff --git a/pkg/workflow/claude_engine.go b/pkg/workflow/claude_engine.go index 24b565de341..39fc889422b 100644 --- a/pkg/workflow/claude_engine.go +++ b/pkg/workflow/claude_engine.go @@ -463,13 +463,6 @@ func (e *ClaudeEngine) GetLogParserScriptId() string { return "parse_claude_log" } -// GetFirewallLogsCollectionStep returns the step for collecting firewall logs (before secret redaction) -// No longer needed since we know where the logs are in the sandbox folder structure -func (e *ClaudeEngine) GetFirewallLogsCollectionStep(workflowData *WorkflowData) []GitHubActionStep { - // Collection step removed - firewall logs are now at a known location - return []GitHubActionStep{} -} - // GetSquidLogsSteps returns the steps for uploading and parsing Squid logs (after secret redaction) func (e *ClaudeEngine) GetSquidLogsSteps(workflowData *WorkflowData) []GitHubActionStep { return defaultGetSquidLogsSteps(workflowData, claudeLog) diff --git a/pkg/workflow/codex_engine.go b/pkg/workflow/codex_engine.go index 30e4cd0d4fa..ea5bdba0f83 100644 --- a/pkg/workflow/codex_engine.go +++ b/pkg/workflow/codex_engine.go @@ -380,15 +380,6 @@ mkdir -p "$CODEX_HOME/logs" return steps } -// GetFirewallLogsCollectionStep returns the step for collecting firewall logs (before secret redaction). -// This method is part of the firewall integration interface. It returns an empty slice because -// firewall logs are written to a known location (/tmp/gh-aw/sandbox/firewall/logs/) and don't need -// a separate collection step. The method is still called from compiler_yaml_main_job.go to maintain -// consistent behavior with other engines that may need log collection steps. -func (e *CodexEngine) GetFirewallLogsCollectionStep(workflowData *WorkflowData) []GitHubActionStep { - return []GitHubActionStep{} -} - // GetSquidLogsSteps returns the steps for uploading and parsing Squid logs (after secret redaction) func (e *CodexEngine) GetSquidLogsSteps(workflowData *WorkflowData) []GitHubActionStep { return defaultGetSquidLogsSteps(workflowData, codexEngineLog) diff --git a/pkg/workflow/compile_config_test.go b/pkg/workflow/compile_config_test.go index f6d74713081..00cc157fed8 100644 --- a/pkg/workflow/compile_config_test.go +++ b/pkg/workflow/compile_config_test.go @@ -198,7 +198,10 @@ func TestSafeOutputsConfigGeneration(t *testing.T) { // Use the compiler's generateOutputCollectionStep to verify config is not in env vars var yamlBuilder strings.Builder - compiler.generateOutputCollectionStep(&yamlBuilder, workflowData) + err := compiler.generateOutputCollectionStep(&yamlBuilder, workflowData) + if err != nil { + t.Fatalf("generateOutputCollectionStep returned unexpected error: %v", err) + } generatedYAML := yamlBuilder.String() // Config should NOT be in environment variables anymore - it's in a file diff --git a/pkg/workflow/compiler_activation_job_builder.go b/pkg/workflow/compiler_activation_job_builder.go index e790b1c36a1..7d00f187cd9 100644 --- a/pkg/workflow/compiler_activation_job_builder.go +++ b/pkg/workflow/compiler_activation_job_builder.go @@ -238,9 +238,17 @@ func (c *Compiler) addActivationRepositoryAndOutputSteps(ctx *activationJobBuild ctx.steps = append(ctx.steps, fmt.Sprintf(" uses: %s\n", getCachedActionPin("actions/github-script", data))) var domainsStr string if data.SafeOutputs != nil && len(data.SafeOutputs.AllowedDomains) > 0 { - domainsStr = c.computeExpandedAllowedDomainsForSanitization(data) + expanded, err := c.computeExpandedAllowedDomainsForSanitization(data) + if err != nil { + return err + } + domainsStr = expanded } else { - domainsStr = c.computeAllowedDomainsForSanitization(data) + computed, err := c.computeAllowedDomainsForSanitization(data) + if err != nil { + return err + } + domainsStr = computed } var envLines []string if len(data.Bots) > 0 { diff --git a/pkg/workflow/compiler_safe_outputs_job.go b/pkg/workflow/compiler_safe_outputs_job.go index 38a70a1332f..e4041220390 100644 --- a/pkg/workflow/compiler_safe_outputs_job.go +++ b/pkg/workflow/compiler_safe_outputs_job.go @@ -201,7 +201,10 @@ func (c *Compiler) buildConsolidatedSafeOutputsJob(data *WorkflowData, mainJobNa // Critical for workflows that create projects and then add issues/PRs to those projects if hasHandlerManagerTypes { consolidatedSafeOutputsJobLog.Print("Using handler manager for safe outputs") - handlerManagerSteps := c.buildHandlerManagerStep(data) + handlerManagerSteps, err := c.buildHandlerManagerStep(data) + if err != nil { + return nil, nil, err + } steps = append(steps, handlerManagerSteps...) safeOutputStepNames = append(safeOutputStepNames, "process_safe_outputs") diff --git a/pkg/workflow/compiler_safe_outputs_steps.go b/pkg/workflow/compiler_safe_outputs_steps.go index f219bf0d65b..1b43756e514 100644 --- a/pkg/workflow/compiler_safe_outputs_steps.go +++ b/pkg/workflow/compiler_safe_outputs_steps.go @@ -134,7 +134,7 @@ func (c *Compiler) buildSharedPRCheckoutSteps(data *WorkflowData) []string { // buildHandlerManagerStep builds a single step that uses the safe output handler manager // to dispatch messages to appropriate handlers. This replaces multiple individual steps // with a single dispatcher step that processes all safe output types. -func (c *Compiler) buildHandlerManagerStep(data *WorkflowData) []string { +func (c *Compiler) buildHandlerManagerStep(data *WorkflowData) ([]string, error) { consolidatedSafeOutputsStepsLog.Print("Building handler manager step") var steps []string @@ -154,9 +154,17 @@ func (c *Compiler) buildHandlerManagerStep(data *WorkflowData) []string { var domainsStr string if data.SafeOutputs != nil && len(data.SafeOutputs.AllowedDomains) > 0 { // allowed-domains: additional domains unioned with engine/network base set; supports ecosystem identifiers - domainsStr = c.computeExpandedAllowedDomainsForSanitization(data) + expanded, err := c.computeExpandedAllowedDomainsForSanitization(data) + if err != nil { + return nil, err + } + domainsStr = expanded } else { - domainsStr = c.computeAllowedDomainsForSanitization(data) + computed, err := c.computeAllowedDomainsForSanitization(data) + if err != nil { + return nil, err + } + domainsStr = computed } if domainsStr != "" { steps = append(steps, fmt.Sprintf(" GH_AW_ALLOWED_DOMAINS: %q\n", domainsStr)) @@ -324,5 +332,5 @@ func (c *Compiler) buildHandlerManagerStep(data *WorkflowData) []string { steps = append(steps, " const { main } = require('"+SetupActionDestination+"/safe_output_handler_manager.cjs');\n") steps = append(steps, " await main();\n") - return steps + return steps, nil } diff --git a/pkg/workflow/compiler_safe_outputs_steps_test.go b/pkg/workflow/compiler_safe_outputs_steps_test.go index 9b280aecbc5..9ad9e942ff1 100644 --- a/pkg/workflow/compiler_safe_outputs_steps_test.go +++ b/pkg/workflow/compiler_safe_outputs_steps_test.go @@ -434,7 +434,8 @@ func TestBuildHandlerManagerStep(t *testing.T) { ParsedFrontmatter: tt.parsedFrontmatter, } - steps := compiler.buildHandlerManagerStep(workflowData) + steps, err := compiler.buildHandlerManagerStep(workflowData) + require.NoError(t, err) require.NotEmpty(t, steps) diff --git a/pkg/workflow/compiler_yaml.go b/pkg/workflow/compiler_yaml.go index 4c56c96ef6a..2669d9b21fc 100644 --- a/pkg/workflow/compiler_yaml.go +++ b/pkg/workflow/compiler_yaml.go @@ -830,7 +830,7 @@ func (c *Compiler) generateCreateAwInfo(yaml *strings.Builder, data *WorkflowDat yaml.WriteString(" await main(core, context);\n") } -func (c *Compiler) generateOutputCollectionStep(yaml *strings.Builder, data *WorkflowData) { +func (c *Compiler) generateOutputCollectionStep(yaml *strings.Builder, data *WorkflowData) error { // Copy the raw safe-output NDJSON to a /tmp/gh-aw/ path so it can be included in the // unified agent artifact together with all other /tmp/gh-aw/ outputs. yaml.WriteString(" - name: Copy Safe Outputs\n") @@ -857,10 +857,18 @@ func (c *Compiler) generateOutputCollectionStep(yaml *strings.Builder, data *Wor var domainsStr string if data.SafeOutputs != nil && len(data.SafeOutputs.AllowedDomains) > 0 { // allowed-domains: additional domains unioned with engine/network base set; supports ecosystem identifiers - domainsStr = c.computeExpandedAllowedDomainsForSanitization(data) + expanded, err := c.computeExpandedAllowedDomainsForSanitization(data) + if err != nil { + return err + } + domainsStr = expanded } else { // Fall back to computing from network configuration (same as firewall) - domainsStr = c.computeAllowedDomainsForSanitization(data) + computed, err := c.computeAllowedDomainsForSanitization(data) + if err != nil { + return err + } + domainsStr = computed } if domainsStr != "" { fmt.Fprintf(yaml, " GH_AW_ALLOWED_DOMAINS: %q\n", domainsStr) @@ -892,6 +900,7 @@ func (c *Compiler) generateOutputCollectionStep(yaml *strings.Builder, data *Wor yaml.WriteString(" const { main } = require('${{ runner.temp }}/gh-aw/actions/collect_ndjson_output.cjs');\n") yaml.WriteString(" await main();\n") + return nil } // processMarkdownBody applies the standard post-processing pipeline to a markdown body: diff --git a/pkg/workflow/compiler_yaml_main_job.go b/pkg/workflow/compiler_yaml_main_job.go index f50e3dbbd4e..20a46956c13 100644 --- a/pkg/workflow/compiler_yaml_main_job.go +++ b/pkg/workflow/compiler_yaml_main_job.go @@ -461,7 +461,9 @@ func (c *Compiler) generateMainJobSteps(yaml *strings.Builder, data *WorkflowDat // Add output collection step only if safe-outputs feature is used (GH_AW_SAFE_OUTPUTS functionality) if data.SafeOutputs != nil { - c.generateOutputCollectionStep(yaml, data) + if err := c.generateOutputCollectionStep(yaml, data); err != nil { + return err + } } // Merge engine-declared output files into the unified artifact instead of creating a diff --git a/pkg/workflow/crush_engine.go b/pkg/workflow/crush_engine.go index f0468e03391..c72e94cb762 100644 --- a/pkg/workflow/crush_engine.go +++ b/pkg/workflow/crush_engine.go @@ -140,12 +140,18 @@ func (e *CrushEngine) GetExecutionSteps(workflowData *WorkflowData, logFile stri if modelConfigured { model = workflowData.EngineConfig.Model } - allowedDomains := GetCrushAllowedDomainsWithToolsAndRuntimes( + // The model was validated by validateUniversalLLMConsumerModel before reaching here, + // so a malformed model (e.g. leading slash) must never occur. Panic is the correct + // response to an internal invariant violation. + allowedDomains, err := GetCrushAllowedDomainsWithToolsAndRuntimes( model, workflowData.NetworkPermissions, workflowData.Tools, workflowData.Runtimes, ) + if err != nil { + panic(fmt.Sprintf("BUG: invalid model %q reached domain computation (should have been caught by validation): %v", model, err)) + } npmPathSetup := GetNpmBinPathSetup() crushCommandWithPath := fmt.Sprintf("%s && %s", npmPathSetup, crushCommand) diff --git a/pkg/workflow/crush_engine_test.go b/pkg/workflow/crush_engine_test.go index 27104fef359..11d3673ecbd 100644 --- a/pkg/workflow/crush_engine_test.go +++ b/pkg/workflow/crush_engine_test.go @@ -446,20 +446,40 @@ func TestCrushEngineFirewallIntegration(t *testing.T) { func TestExtractProviderFromModel(t *testing.T) { t.Run("standard provider/model format", func(t *testing.T) { - assert.Equal(t, "anthropic", extractCrushProviderFromModel("anthropic/claude-sonnet-4-20250514")) - assert.Equal(t, "openai", extractCrushProviderFromModel("openai/gpt-4.1")) - assert.Equal(t, "google", extractCrushProviderFromModel("google/gemini-2.5-pro")) + provider, err := extractProviderFromModel("anthropic/claude-sonnet-4-20250514") + require.NoError(t, err) + assert.Equal(t, "anthropic", provider) + + provider, err = extractProviderFromModel("openai/gpt-4.1") + require.NoError(t, err) + assert.Equal(t, "openai", provider) + + provider, err = extractProviderFromModel("google/gemini-2.5-pro") + require.NoError(t, err) + assert.Equal(t, "google", provider) }) - t.Run("empty model defaults to copilot", func(t *testing.T) { - assert.Equal(t, "copilot", extractCrushProviderFromModel("")) + t.Run("empty model returns empty provider", func(t *testing.T) { + provider, err := extractProviderFromModel("") + require.NoError(t, err) + assert.Empty(t, provider) }) - t.Run("no slash defaults to copilot", func(t *testing.T) { - assert.Equal(t, "copilot", extractCrushProviderFromModel("claude-sonnet-4-20250514")) + t.Run("no slash returns empty provider", func(t *testing.T) { + provider, err := extractProviderFromModel("claude-sonnet-4-20250514") + require.NoError(t, err) + assert.Empty(t, provider) }) t.Run("case insensitive provider", func(t *testing.T) { - assert.Equal(t, "openai", extractCrushProviderFromModel("OpenAI/gpt-4.1")) + provider, err := extractProviderFromModel("OpenAI/gpt-4.1") + require.NoError(t, err) + assert.Equal(t, "openai", provider) + }) + + t.Run("leading slash returns error", func(t *testing.T) { + _, err := extractProviderFromModel("/gpt-4.1") + require.Error(t, err, "Leading slash (empty provider prefix) must return an error") + assert.Contains(t, err.Error(), "provider prefix is empty") }) } diff --git a/pkg/workflow/domains.go b/pkg/workflow/domains.go index 9034f683574..64b7db9290d 100644 --- a/pkg/workflow/domains.go +++ b/pkg/workflow/domains.go @@ -124,7 +124,7 @@ var CrushBaseDefaultDomains = []string{ } // crushProviderDomains maps provider prefixes to their API domains. -// Used by extractCrushProviderFromModel() and GetCrushDefaultDomains(). +// Used by extractProviderFromModel() and GetCrushDefaultDomains(). var crushProviderDomains = map[string]string{ "copilot": "api.githubcopilot.com", "anthropic": "api.anthropic.com", @@ -161,7 +161,7 @@ var OpenCodeBaseDefaultDomains = []string{ } // openCodeProviderDomains maps provider prefixes to their API domains. -// Used by extractOpenCodeProviderFromModel() and GetOpenCodeDefaultDomains(). +// Used by extractProviderFromModel() and GetOpenCodeDefaultDomains(). var openCodeProviderDomains = map[string]string{ "copilot": "api.githubcopilot.com", "anthropic": "api.anthropic.com", @@ -186,24 +186,36 @@ var OpenCodeDefaultDomains = []string{ "registry.npmjs.org", // npm package downloads } -// extractOpenCodeProviderFromModel extracts the provider name from an OpenCode model string. -// OpenCode uses "provider/model" format (e.g., "anthropic/claude-sonnet-4-20250514"). -// Returns the provider prefix, or "copilot" as default if no slash is found. -func extractOpenCodeProviderFromModel(model string) string { +// extractProviderFromModel parses "provider/model" format and returns the +// lowercase provider prefix. Returns ("", nil) when no model is given or the +// format contains no slash (no provider prefix detected). Returns an error when +// the format is explicitly malformed – a leading slash like "/gpt-4.1" means +// the provider prefix is intentionally empty, which is always invalid. +// Both OpenCode and Crush use this same "provider/model" convention. +func extractProviderFromModel(model string) (string, error) { if model == "" { - return "copilot" + return "", nil } parts := strings.SplitN(model, "/", 2) if len(parts) < 2 { - return "copilot" + // No slash: no "provider/model" format; no provider to extract. + return "", nil } - return strings.ToLower(parts[0]) + provider := strings.ToLower(parts[0]) + if provider == "" { + return "", fmt.Errorf("invalid engine.model %q: provider prefix is empty; use provider/model format (for example: openai/gpt-4.1, anthropic/claude-sonnet-4)", model) + } + return provider, nil } // GetOpenCodeDefaultDomains returns the default domains for OpenCode based on the model provider. // It starts with OpenCodeBaseDefaultDomains and adds the provider-specific API domain. -func GetOpenCodeDefaultDomains(model string) []string { - provider := extractOpenCodeProviderFromModel(model) +// Returns an error if the model string is malformed (e.g. a leading slash). +func GetOpenCodeDefaultDomains(model string) ([]string, error) { + provider, err := extractProviderFromModel(model) + if err != nil { + return nil, err + } domains := make([]string, 0, len(OpenCodeBaseDefaultDomains)+1) domains = append(domains, OpenCodeBaseDefaultDomains...) @@ -211,33 +223,24 @@ func GetOpenCodeDefaultDomains(model string) []string { domains = append(domains, domain) } - return domains + return domains, nil } // GetOpenCodeAllowedDomainsWithToolsAndRuntimes merges OpenCode default domains with NetworkPermissions, HTTP MCP server domains, and runtime ecosystem domains. // Pass the selected model so provider-specific API domains are included. -func GetOpenCodeAllowedDomainsWithToolsAndRuntimes(model string, network *NetworkPermissions, tools map[string]any, runtimes map[string]any) string { +// Returns an error if the model string is malformed (e.g. a leading slash). +func GetOpenCodeAllowedDomainsWithToolsAndRuntimes(model string, network *NetworkPermissions, tools map[string]any, runtimes map[string]any) (string, error) { return GetAllowedDomainsForEngineWithModel(constants.OpenCodeEngine, model, network, tools, runtimes) } -// extractCrushProviderFromModel extracts the provider name from a Crush model string. -// Crush uses "provider/model" format (e.g., "anthropic/claude-sonnet-4-20250514"). -// Returns the provider prefix, or "copilot" as default if no slash is found. -func extractCrushProviderFromModel(model string) string { - if model == "" { - return "copilot" - } - parts := strings.SplitN(model, "/", 2) - if len(parts) < 2 { - return "copilot" - } - return strings.ToLower(parts[0]) -} - // GetCrushDefaultDomains returns the default domains for Crush based on the model provider. // It starts with CrushBaseDefaultDomains and adds the provider-specific API domain. -func GetCrushDefaultDomains(model string) []string { - provider := extractCrushProviderFromModel(model) +// Returns an error if the model string is malformed (e.g. a leading slash). +func GetCrushDefaultDomains(model string) ([]string, error) { + provider, err := extractProviderFromModel(model) + if err != nil { + return nil, err + } domains := make([]string, 0, len(CrushBaseDefaultDomains)+1) domains = append(domains, CrushBaseDefaultDomains...) @@ -245,13 +248,14 @@ func GetCrushDefaultDomains(model string) []string { domains = append(domains, domain) } - return domains + return domains, nil } // GetCrushAllowedDomainsWithToolsAndRuntimes merges Crush default domains with NetworkPermissions, HTTP MCP server domains, and runtime ecosystem domains. // Pass the selected model (e.g. "anthropic/claude-sonnet-4-20250514") so provider-specific // API domains are included. Returns a deduplicated, sorted, comma-separated string suitable for AWF's --allow-domains flag. -func GetCrushAllowedDomainsWithToolsAndRuntimes(model string, network *NetworkPermissions, tools map[string]any, runtimes map[string]any) string { +// Returns an error if the model string is malformed (e.g. a leading slash). +func GetCrushAllowedDomainsWithToolsAndRuntimes(model string, network *NetworkPermissions, tools map[string]any, runtimes map[string]any) (string, error) { return GetAllowedDomainsForEngineWithModel(constants.CrushEngine, model, network, tools, runtimes) } @@ -686,7 +690,8 @@ var engineDefaultDomains = map[constants.EngineName][]string{ // resolved via GetOpenCodeDefaultDomains(model) / GetCrushDefaultDomains(model) // rather than the static engineDefaultDomains map. // Falls back to an empty default domain list for unknown engines. -func getDefaultDomainsForEngine(engine constants.EngineName, model string) []string { +// Returns an error if the model string is malformed (e.g. a leading slash). +func getDefaultDomainsForEngine(engine constants.EngineName, model string) ([]string, error) { if engine == constants.OpenCodeEngine { return GetOpenCodeDefaultDomains(model) } @@ -694,7 +699,7 @@ func getDefaultDomainsForEngine(engine constants.EngineName, model string) []str return GetCrushDefaultDomains(model) } - return engineDefaultDomains[engine] + return engineDefaultDomains[engine], nil } // GetAllowedDomainsForEngineWithModel merges the engine's default domains with @@ -703,8 +708,13 @@ func getDefaultDomainsForEngine(engine constants.EngineName, model string) []str // selected model so the correct default domains are included. // Returns a deduplicated, sorted, comma-separated string suitable for AWF's // --allow-domains flag. -func GetAllowedDomainsForEngineWithModel(engine constants.EngineName, model string, network *NetworkPermissions, tools map[string]any, runtimes map[string]any) string { - return mergeDomainsWithNetworkToolsAndRuntimes(getDefaultDomainsForEngine(engine, model), network, tools, runtimes) +// Returns an error if the model string is malformed (e.g. a leading slash). +func GetAllowedDomainsForEngineWithModel(engine constants.EngineName, model string, network *NetworkPermissions, tools map[string]any, runtimes map[string]any) (string, error) { + defaults, err := getDefaultDomainsForEngine(engine, model) + if err != nil { + return "", err + } + return mergeDomainsWithNetworkToolsAndRuntimes(defaults, network, tools, runtimes), nil } // GetAllowedDomainsForEngine merges the engine's default domains with NetworkPermissions, @@ -714,7 +724,9 @@ func GetAllowedDomainsForEngineWithModel(engine constants.EngineName, model stri // For model/provider-specific engines such as Crush, prefer // GetAllowedDomainsForEngineWithModel so provider domains are included. func GetAllowedDomainsForEngine(engine constants.EngineName, network *NetworkPermissions, tools map[string]any, runtimes map[string]any) string { - return GetAllowedDomainsForEngineWithModel(engine, "", network, tools, runtimes) + // Empty model never triggers provider-format validation, so no error is possible here. + result, _ := GetAllowedDomainsForEngineWithModel(engine, "", network, tools, runtimes) + return result } // GetCopilotAllowedDomainsWithToolsAndRuntimes merges Copilot default domains with NetworkPermissions, HTTP MCP server domains, and runtime ecosystem domains @@ -865,12 +877,13 @@ func mergeAPITargetDomains(domainsStr string, apiTarget string) string { // The result is cached in data.CachedAllowedDomainsStr after the first call so that // repeated calls (e.g. from the activation job, safe-outputs steps, and agent run step) // do not recompute the same domain list. -func (c *Compiler) computeAllowedDomainsForSanitization(data *WorkflowData) string { +// Returns an error if the engine's model is malformed (e.g. a leading slash). +func (c *Compiler) computeAllowedDomainsForSanitization(data *WorkflowData) (string, error) { // Return cached result if available (engine/network/tools/runtimes do not change during compilation). // CachedAllowedDomainsComputed is used as the sentinel so that a legitimately empty domain // list is not confused with "not yet computed". if data.CachedAllowedDomainsComputed { - return data.CachedAllowedDomainsStr + return data.CachedAllowedDomainsStr, nil } // Determine which engine is being used @@ -898,13 +911,21 @@ func (c *Compiler) computeAllowedDomainsForSanitization(data *WorkflowData) stri if data.EngineConfig != nil { model = data.EngineConfig.Model } - base = GetOpenCodeAllowedDomainsWithToolsAndRuntimes(model, data.NetworkPermissions, data.Tools, data.Runtimes) + var err error + base, err = GetOpenCodeAllowedDomainsWithToolsAndRuntimes(model, data.NetworkPermissions, data.Tools, data.Runtimes) + if err != nil { + return "", err + } case "crush": model := "" if data.EngineConfig != nil { model = data.EngineConfig.Model } - base = GetCrushAllowedDomainsWithToolsAndRuntimes(model, data.NetworkPermissions, data.Tools, data.Runtimes) + var err error + base, err = GetCrushAllowedDomainsWithToolsAndRuntimes(model, data.NetworkPermissions, data.Tools, data.Runtimes) + if err != nil { + return "", err + } default: // For other engines, use network permissions only domains := GetAllowedDomains(data.NetworkPermissions) @@ -927,7 +948,7 @@ func (c *Compiler) computeAllowedDomainsForSanitization(data *WorkflowData) stri // Set the boolean sentinel first so that an empty result is also treated as cached. data.CachedAllowedDomainsComputed = true data.CachedAllowedDomainsStr = base - return base + return base, nil } // expandAllowedDomains expands a list of domain entries (which may include ecosystem @@ -952,9 +973,13 @@ func expandAllowedDomains(entries []string) []string { // unioning the engine/network base set with the safe-outputs.allowed-domains entries. // It always includes "localhost" and "github.com" in the result. // The allowed-domains entries support ecosystem identifiers (same syntax as network.allowed). -func (c *Compiler) computeExpandedAllowedDomainsForSanitization(data *WorkflowData) string { +// Returns an error if the engine's model is malformed (e.g. a leading slash). +func (c *Compiler) computeExpandedAllowedDomainsForSanitization(data *WorkflowData) (string, error) { // Start from the base set (engine defaults + network.allowed + tools + runtimes) - base := c.computeAllowedDomainsForSanitization(data) + base, err := c.computeAllowedDomainsForSanitization(data) + if err != nil { + return "", err + } domainMap := make(map[string]bool) @@ -982,5 +1007,5 @@ func (c *Compiler) computeExpandedAllowedDomainsForSanitization(data *WorkflowDa domainMap["github.com"] = true // Produce a sorted, comma-separated result - return strings.Join(slices.Sorted(maps.Keys(domainMap)), ",") + return strings.Join(slices.Sorted(maps.Keys(domainMap)), ","), nil } diff --git a/pkg/workflow/domains_test.go b/pkg/workflow/domains_test.go index 2c7112d02b9..8b33f860cb0 100644 --- a/pkg/workflow/domains_test.go +++ b/pkg/workflow/domains_test.go @@ -1094,7 +1094,10 @@ func TestComputeExpandedAllowedDomainsForSanitization(t *testing.T) { AllowedDomains: []string{"extra-domain.com"}, }, } - result := compiler.computeExpandedAllowedDomainsForSanitization(data) + result, err := compiler.computeExpandedAllowedDomainsForSanitization(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !strings.Contains(result, "extra-domain.com") { t.Error("Expected extra-domain.com in result") } @@ -1113,7 +1116,10 @@ func TestComputeExpandedAllowedDomainsForSanitization(t *testing.T) { AllowedDomains: []string{"extra-domain.com"}, }, } - result := compiler.computeExpandedAllowedDomainsForSanitization(data) + result, err := compiler.computeExpandedAllowedDomainsForSanitization(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !strings.Contains(result, "localhost") { t.Error("Expected localhost to always be in allowed-domains result") } @@ -1126,7 +1132,10 @@ func TestComputeExpandedAllowedDomainsForSanitization(t *testing.T) { AllowedDomains: []string{"extra-domain.com"}, }, } - result := compiler.computeExpandedAllowedDomainsForSanitization(data) + result, err := compiler.computeExpandedAllowedDomainsForSanitization(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !strings.Contains(result, "github.com") { t.Error("Expected github.com to always be in allowed-domains result") } @@ -1139,7 +1148,10 @@ func TestComputeExpandedAllowedDomainsForSanitization(t *testing.T) { AllowedDomains: []string{"python", "dev-tools"}, }, } - result := compiler.computeExpandedAllowedDomainsForSanitization(data) + result, err := compiler.computeExpandedAllowedDomainsForSanitization(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } if !strings.Contains(result, "pypi.org") { t.Error("Expected pypi.org from python ecosystem in result") } diff --git a/pkg/workflow/opencode_engine.go b/pkg/workflow/opencode_engine.go index 33ace2fdd57..196bf54f2ea 100644 --- a/pkg/workflow/opencode_engine.go +++ b/pkg/workflow/opencode_engine.go @@ -119,12 +119,18 @@ func (e *OpenCodeEngine) GetExecutionSteps(workflowData *WorkflowData, logFile s if modelConfigured { model = workflowData.EngineConfig.Model } - allowedDomains := GetOpenCodeAllowedDomainsWithToolsAndRuntimes( + // The model was validated by validateUniversalLLMConsumerModel before reaching here, + // so a malformed model (e.g. leading slash) must never occur. Panic is the correct + // response to an internal invariant violation. + allowedDomains, err := GetOpenCodeAllowedDomainsWithToolsAndRuntimes( model, workflowData.NetworkPermissions, workflowData.Tools, workflowData.Runtimes, ) + if err != nil { + panic(fmt.Sprintf("BUG: invalid model %q reached domain computation (should have been caught by validation): %v", model, err)) + } npmPathSetup := GetNpmBinPathSetup() openCodeCommandWithPath := fmt.Sprintf("%s && %s", npmPathSetup, openCodeCommand) diff --git a/pkg/workflow/safe_outputs_actions_test.go b/pkg/workflow/safe_outputs_actions_test.go index d309daf19ca..dcf6592169a 100644 --- a/pkg/workflow/safe_outputs_actions_test.go +++ b/pkg/workflow/safe_outputs_actions_test.go @@ -372,7 +372,8 @@ func TestHandlerManagerStepIncludesActionsEnvVar(t *testing.T) { }, } - steps := compiler.buildHandlerManagerStep(workflowData) + steps, err := compiler.buildHandlerManagerStep(workflowData) + require.NoError(t, err) fullYAML := strings.Join(steps, "") assert.Contains(t, fullYAML, "GH_AW_SAFE_OUTPUT_ACTIONS", "Should include GH_AW_SAFE_OUTPUT_ACTIONS env var") @@ -388,7 +389,8 @@ func TestHandlerManagerStepNoActionsEnvVar(t *testing.T) { }, } - steps := compiler.buildHandlerManagerStep(workflowData) + steps, err := compiler.buildHandlerManagerStep(workflowData) + require.NoError(t, err) fullYAML := strings.Join(steps, "") assert.NotContains(t, fullYAML, "GH_AW_SAFE_OUTPUT_ACTIONS", "Should not include GH_AW_SAFE_OUTPUT_ACTIONS when no actions") diff --git a/pkg/workflow/safe_outputs_cross_repo_config_test.go b/pkg/workflow/safe_outputs_cross_repo_config_test.go index 7603f416511..3a9fdedca7c 100644 --- a/pkg/workflow/safe_outputs_cross_repo_config_test.go +++ b/pkg/workflow/safe_outputs_cross_repo_config_test.go @@ -489,7 +489,8 @@ func TestHandlerManagerStepPerOutputTokenInHandlerConfig(t *testing.T) { SafeOutputs: tt.safeOutputs, } - steps := compiler.buildHandlerManagerStep(workflowData) + steps, err := compiler.buildHandlerManagerStep(workflowData) + require.NoError(t, err) stepsContent := strings.Join(steps, "") // Verify tokens appear somewhere in the step content (handler config JSON) diff --git a/pkg/workflow/safe_outputs_handler_manager_token_test.go b/pkg/workflow/safe_outputs_handler_manager_token_test.go index 59bcf912661..c001ab7b4ab 100644 --- a/pkg/workflow/safe_outputs_handler_manager_token_test.go +++ b/pkg/workflow/safe_outputs_handler_manager_token_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestHandlerManagerGitHubTokenEnvVarForCrossRepo verifies that GITHUB_TOKEN is exposed as @@ -154,7 +155,8 @@ func TestHandlerManagerGitHubTokenEnvVarForCrossRepo(t *testing.T) { SafeOutputs: compiler.extractSafeOutputsConfig(tt.frontmatter), } - steps := compiler.buildHandlerManagerStep(workflowData) + steps, err := compiler.buildHandlerManagerStep(workflowData) + require.NoError(t, err) yamlStr := strings.Join(steps, "") if tt.shouldHaveGitHubToken { @@ -300,7 +302,8 @@ func TestHandlerManagerProjectGitHubTokenEnvVar(t *testing.T) { } // Build the handler manager step - steps := compiler.buildHandlerManagerStep(workflowData) + steps, err := compiler.buildHandlerManagerStep(workflowData) + require.NoError(t, err) yamlStr := strings.Join(steps, "") if tt.shouldHaveToken { diff --git a/pkg/workflow/safe_scripts_test.go b/pkg/workflow/safe_scripts_test.go index 5e39125805a..e02332a0ac2 100644 --- a/pkg/workflow/safe_scripts_test.go +++ b/pkg/workflow/safe_scripts_test.go @@ -306,7 +306,8 @@ func TestHandlerManagerStepIncludesScriptsEnvVar(t *testing.T) { }, } - steps := compiler.buildHandlerManagerStep(workflowData) + steps, err := compiler.buildHandlerManagerStep(workflowData) + require.NoError(t, err) fullYAML := strings.Join(steps, "") assert.Contains(t, fullYAML, "GH_AW_SAFE_OUTPUT_SCRIPTS", "Should include GH_AW_SAFE_OUTPUT_SCRIPTS env var") @@ -323,7 +324,8 @@ func TestHandlerManagerStepNoScriptsEnvVar(t *testing.T) { }, } - steps := compiler.buildHandlerManagerStep(workflowData) + steps, err := compiler.buildHandlerManagerStep(workflowData) + require.NoError(t, err) fullYAML := strings.Join(steps, "") assert.NotContains(t, fullYAML, "GH_AW_SAFE_OUTPUT_SCRIPTS", "Should not include GH_AW_SAFE_OUTPUT_SCRIPTS env var when no scripts")