diff --git a/actions/setup/js/generate_aw_info.cjs b/actions/setup/js/generate_aw_info.cjs index e984be8ee46..f943ede658f 100644 --- a/actions/setup/js/generate_aw_info.cjs +++ b/actions/setup/js/generate_aw_info.cjs @@ -86,6 +86,24 @@ async function main(core, ctx) { awInfo.cli_version = cliVersion; } + // Include custom token weights when set (engine.token-weights in workflow frontmatter). + // Deep structure validation is intentionally minimal here: the JSON schema and Go parser + // already validate the structure at compile time. We only verify the top-level type to + // guard against unexpected env-var values at runtime. + const tokenWeightsEnv = process.env.GH_AW_INFO_TOKEN_WEIGHTS; + if (tokenWeightsEnv) { + try { + const tokenWeights = JSON.parse(tokenWeightsEnv); + if (tokenWeights !== null && typeof tokenWeights === "object" && !Array.isArray(tokenWeights)) { + awInfo.token_weights = tokenWeights; + } else { + core.warning(`GH_AW_INFO_TOKEN_WEIGHTS must be a JSON object, ignoring`); + } + } catch { + core.warning(`Failed to parse GH_AW_INFO_TOKEN_WEIGHTS: ${tokenWeightsEnv}`); + } + } + // Include aw_context when the workflow was triggered via workflow_dispatch with // the aw_context input set by a calling agentic workflow's dispatch_workflow handler. // Validates JSON format and structure before populating the context key in aw_info.json. diff --git a/pkg/cli/effective_tokens.go b/pkg/cli/effective_tokens.go index 152adc397d6..51e27fb5a71 100644 --- a/pkg/cli/effective_tokens.go +++ b/pkg/cli/effective_tokens.go @@ -30,10 +30,12 @@ package cli import ( _ "embed" "encoding/json" + "maps" "math" "strings" "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/types" ) var effectiveTokensLog = logger.New("cli:effective_tokens") @@ -196,28 +198,105 @@ func computeModelEffectiveTokens(model string, inputTokens, outputTokens, cacheR // entry and computes the TotalEffectiveTokens aggregate on the summary. // It is a no-op when summary is nil. func populateEffectiveTokens(summary *TokenUsageSummary) { + populateEffectiveTokensWithCustomWeights(summary, nil) +} + +// populateEffectiveTokensWithCustomWeights is like populateEffectiveTokens but +// merges custom into the built-in weights before computing effective tokens. +// Custom weights take precedence over the defaults loaded from model_multipliers.json. +// It is a no-op when summary is nil. +func populateEffectiveTokensWithCustomWeights(summary *TokenUsageSummary, custom *types.TokenWeights) { if summary == nil { return } + multipliers, classWeights := resolveEffectiveWeights(custom) + total := 0 for model, usage := range summary.ByModel { if usage == nil { continue } - eff := computeModelEffectiveTokens( - model, - usage.InputTokens, - usage.OutputTokens, - usage.CacheReadTokens, - usage.CacheWriteTokens, - ) + eff := computeModelEffectiveTokensWithWeights(model, usage.InputTokens, usage.OutputTokens, + usage.CacheReadTokens, usage.CacheWriteTokens, multipliers, classWeights) usage.EffectiveTokens = eff total += eff } summary.TotalEffectiveTokens = total if effectiveTokensLog.Enabled() { - effectiveTokensLog.Printf("Effective tokens: total=%d models=%d", total, len(summary.ByModel)) + effectiveTokensLog.Printf("Effective tokens: total=%d models=%d custom=%v", total, len(summary.ByModel), custom != nil) + } +} + +// resolveEffectiveWeights merges optional custom weights with the built-in defaults. +// The returned multipliers map is a copy so callers may not modify loadedMultipliers. +func resolveEffectiveWeights(custom *types.TokenWeights) (map[string]float64, tokenClassWeights) { + initMultipliers() + + // Copy the base multipliers to avoid mutating the shared global + merged := make(map[string]float64, len(loadedMultipliers)) + maps.Copy(merged, loadedMultipliers) + classWeights := loadedTokenWeights + + if custom == nil { + return merged, classWeights + } + + // Override/add per-model multipliers (normalise keys to lowercase) + for model, mult := range custom.Multipliers { + merged[strings.ToLower(strings.TrimSpace(model))] = mult + } + + // Override per-token-class weights where non-zero values are provided + if tcw := custom.TokenClassWeights; tcw != nil { + if tcw.Input != 0 { + classWeights.Input = tcw.Input + } + if tcw.CachedInput != 0 { + classWeights.CachedInput = tcw.CachedInput + } + if tcw.Output != 0 { + classWeights.Output = tcw.Output + } + if tcw.Reasoning != 0 { + classWeights.Reasoning = tcw.Reasoning + } + if tcw.CacheWrite != 0 { + classWeights.CacheWrite = tcw.CacheWrite + } } + + return merged, classWeights +} + +// computeModelEffectiveTokensWithWeights computes effective tokens using caller-provided +// multiplier table and token class weights instead of the global defaults. +func computeModelEffectiveTokensWithWeights(model string, inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int, multipliers map[string]float64, w tokenClassWeights) int { + base := w.Input*float64(inputTokens) + + w.CachedInput*float64(cacheReadTokens) + + w.Output*float64(outputTokens) + + w.CacheWrite*float64(cacheWriteTokens) + if base == 0 { + return 0 + } + + key := strings.ToLower(strings.TrimSpace(model)) + mult := 1.0 + if key != "" { + if m, ok := multipliers[key]; ok { + mult = m + } else { + // Longest prefix match + best := "" + for name, m := range multipliers { + if strings.HasPrefix(key, name) && len(name) > len(best) { + best = name + mult = m + } + } + } + } + + return int(math.Round(base * mult)) } diff --git a/pkg/cli/effective_tokens_test.go b/pkg/cli/effective_tokens_test.go index d00bdde0ccc..b1cf61c2059 100644 --- a/pkg/cli/effective_tokens_test.go +++ b/pkg/cli/effective_tokens_test.go @@ -5,6 +5,7 @@ package cli import ( "testing" + "github.com/github/gh-aw/pkg/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -223,3 +224,94 @@ func TestModelMultipliersJSONEmbedded(t *testing.T) { require.NotNil(t, loadedMultipliers, "multipliers should be loaded from embedded JSON") assert.NotEmpty(t, loadedMultipliers, "should have at least one multiplier entry") } + +func TestResolveEffectiveWeightsNoCustom(t *testing.T) { + loadedMultipliers = nil + + multipliers, classWeights := resolveEffectiveWeights(nil) + + assert.NotEmpty(t, multipliers, "should have built-in multipliers") + assert.InDelta(t, 1.0, classWeights.Input, 1e-9, "default input weight") + assert.InDelta(t, 0.1, classWeights.CachedInput, 1e-9, "default cached input weight") + assert.InDelta(t, 4.0, classWeights.Output, 1e-9, "default output weight") +} + +func TestResolveEffectiveWeightsCustomMultipliers(t *testing.T) { + loadedMultipliers = nil + + custom := &types.TokenWeights{ + Multipliers: map[string]float64{ + "my-custom-model": 2.5, + "claude-sonnet-4.5": 1.5, // override existing + }, + } + multipliers, classWeights := resolveEffectiveWeights(custom) + + assert.InDelta(t, 2.5, multipliers["my-custom-model"], 1e-9, "custom model multiplier") + assert.InDelta(t, 1.5, multipliers["claude-sonnet-4.5"], 1e-9, "overridden model multiplier") + // Built-in models not mentioned in custom should remain + assert.InDelta(t, 0.1, multipliers["claude-haiku-4.5"], 1e-9, "unmodified built-in multiplier") + // Class weights unchanged when not specified + assert.InDelta(t, 4.0, classWeights.Output, 1e-9, "output weight unchanged") +} + +func TestResolveEffectiveWeightsCustomClassWeights(t *testing.T) { + loadedMultipliers = nil + + custom := &types.TokenWeights{ + TokenClassWeights: &types.TokenClassWeights{ + Output: 6.0, + CachedInput: 0.05, + }, + } + _, classWeights := resolveEffectiveWeights(custom) + + assert.InDelta(t, 6.0, classWeights.Output, 1e-9, "custom output weight") + assert.InDelta(t, 0.05, classWeights.CachedInput, 1e-9, "custom cached input weight") + // Unset fields keep their defaults + assert.InDelta(t, 1.0, classWeights.Input, 1e-9, "input weight unchanged") + assert.InDelta(t, 4.0, classWeights.Reasoning, 1e-9, "reasoning weight unchanged") +} + +func TestPopulateEffectiveTokensWithCustomWeights(t *testing.T) { + loadedMultipliers = nil + + summary := &TokenUsageSummary{ + ByModel: map[string]*ModelTokenUsage{ + "my-custom-model": { + InputTokens: 1000, + OutputTokens: 200, + }, + "claude-sonnet-4.5": { + InputTokens: 500, + OutputTokens: 100, + }, + }, + } + + custom := &types.TokenWeights{ + Multipliers: map[string]float64{ + "my-custom-model": 3.0, + }, + } + + populateEffectiveTokensWithCustomWeights(summary, custom) + + // my-custom-model: base = 1.0*1000 + 4.0*200 = 1800; ET = 3.0 * 1800 = 5400 + customModel := summary.ByModel["my-custom-model"] + require.NotNil(t, customModel, "custom model should be present") + assert.Equal(t, 5400, customModel.EffectiveTokens, "custom model effective tokens at 3.0x") + + // claude-sonnet-4.5: base = 1.0*500 + 4.0*100 = 900; ET = 1.0 * 900 = 900 + sonnet := summary.ByModel["claude-sonnet-4.5"] + require.NotNil(t, sonnet, "sonnet should be present") + assert.Equal(t, 900, sonnet.EffectiveTokens, "sonnet effective tokens at 1x") + + assert.Equal(t, 6300, summary.TotalEffectiveTokens, "total = custom + sonnet") +} + +func TestPopulateEffectiveTokensWithCustomWeightsNilSummary(t *testing.T) { + assert.NotPanics(t, func() { + populateEffectiveTokensWithCustomWeights(nil, nil) + }) +} diff --git a/pkg/cli/logs_models.go b/pkg/cli/logs_models.go index c39490ebc9d..1c7ab54ab34 100644 --- a/pkg/cli/logs_models.go +++ b/pkg/cli/logs_models.go @@ -5,6 +5,7 @@ import ( "time" "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/types" "github.com/github/gh-aw/pkg/workflow" ) @@ -267,18 +268,19 @@ type AwContext struct { // AwInfo represents the structure of aw_info.json files type AwInfo struct { - EngineID string `json:"engine_id"` - EngineName string `json:"engine_name"` - Model string `json:"model"` - Version string `json:"version"` - CLIVersion string `json:"cli_version,omitempty"` // gh-aw CLI version - WorkflowName string `json:"workflow_name"` - Staged bool `json:"staged"` - AwfVersion string `json:"awf_version,omitempty"` // AWF firewall version (new name) - FirewallVersion string `json:"firewall_version,omitempty"` // AWF firewall version (old name, for backward compatibility) - Steps AwInfoSteps `json:"steps,omitzero"` // Steps metadata - CreatedAt string `json:"created_at"` - Context *AwContext `json:"context,omitempty"` // aw_context data passed via workflow_dispatch inputs + EngineID string `json:"engine_id"` + EngineName string `json:"engine_name"` + Model string `json:"model"` + Version string `json:"version"` + CLIVersion string `json:"cli_version,omitempty"` // gh-aw CLI version + WorkflowName string `json:"workflow_name"` + Staged bool `json:"staged"` + AwfVersion string `json:"awf_version,omitempty"` // AWF firewall version (new name) + FirewallVersion string `json:"firewall_version,omitempty"` // AWF firewall version (old name, for backward compatibility) + Steps AwInfoSteps `json:"steps,omitzero"` // Steps metadata + CreatedAt string `json:"created_at"` + Context *AwContext `json:"context,omitempty"` // aw_context data passed via workflow_dispatch inputs + TokenWeights *types.TokenWeights `json:"token_weights,omitempty"` // Custom model cost data (from engine.token-weights) // Additional fields that might be present RunID any `json:"run_id,omitempty"` RunNumber any `json:"run_number,omitempty"` diff --git a/pkg/cli/token_usage.go b/pkg/cli/token_usage.go index 1fa37dbb6a4..f02ff9e7c77 100644 --- a/pkg/cli/token_usage.go +++ b/pkg/cli/token_usage.go @@ -11,6 +11,7 @@ import ( "github.com/github/gh-aw/pkg/logger" "github.com/github/gh-aw/pkg/timeutil" + "github.com/github/gh-aw/pkg/types" ) var tokenUsageLog = logger.New("cli:token_usage") @@ -75,8 +76,10 @@ type ModelTokenUsageRow struct { // tokenUsageJSONLPath is the relative path within the firewall logs directory const tokenUsageJSONLPath = "api-proxy-logs/token-usage.jsonl" -// parseTokenUsageFile parses a token-usage.jsonl file and returns the aggregated summary -func parseTokenUsageFile(filePath string) (*TokenUsageSummary, error) { +// parseTokenUsageFile parses a token-usage.jsonl file and returns the aggregated summary. +// Custom weights, when non-nil, override the built-in model multipliers and token class +// weights for effective token computation. +func parseTokenUsageFile(filePath string, customWeights *types.TokenWeights) (*TokenUsageSummary, error) { tokenUsageLog.Printf("Parsing token usage file: %s", filePath) file, err := os.Open(filePath) @@ -155,8 +158,8 @@ func parseTokenUsageFile(filePath string) (*TokenUsageSummary, error) { lineNum, summary.TotalInputTokens, summary.TotalOutputTokens, summary.TotalCacheReadTokens, summary.TotalCacheWriteTokens, summary.TotalRequests) - // Compute effective tokens using per-model multipliers - populateEffectiveTokens(summary) + // Compute effective tokens using per-model multipliers (with optional custom overrides) + populateEffectiveTokensWithCustomWeights(summary, customWeights) return summary, nil } @@ -210,7 +213,9 @@ func findTokenUsageFile(runDir string) string { return "" } -// analyzeTokenUsage finds and parses the token-usage.jsonl file from a run directory +// analyzeTokenUsage finds and parses the token-usage.jsonl file from a run directory. +// It automatically reads custom token weights from aw_info.json when present and +// applies them to the effective token computation. func analyzeTokenUsage(runDir string, verbose bool) (*TokenUsageSummary, error) { tokenUsageLog.Printf("Analyzing token usage in: %s", runDir) @@ -226,7 +231,24 @@ func analyzeTokenUsage(runDir string, verbose bool) (*TokenUsageSummary, error) } } - return parseTokenUsageFile(filePath) + // Try to load custom token weights from aw_info.json for this run + customWeights := extractCustomTokenWeightsFromDir(runDir) + + return parseTokenUsageFile(filePath, customWeights) +} + +// extractCustomTokenWeightsFromDir reads aw_info.json from a run directory and returns +// any custom token weights embedded there at compile time. Returns nil when not found. +func extractCustomTokenWeightsFromDir(runDir string) *types.TokenWeights { + awInfoPath := findAwInfoPath(runDir) + if awInfoPath == "" { + return nil + } + awInfo, err := parseAwInfo(awInfoPath, false) + if err != nil || awInfo == nil { + return nil + } + return awInfo.TokenWeights } // TotalTokens returns the sum of all token types diff --git a/pkg/cli/token_usage_test.go b/pkg/cli/token_usage_test.go index ae43465ae12..fe218ec926f 100644 --- a/pkg/cli/token_usage_test.go +++ b/pkg/cli/token_usage_test.go @@ -20,7 +20,7 @@ func TestParseTokenUsageFile(t *testing.T) { content := `{"timestamp":"2026-04-01T17:56:38.042Z","request_id":"abc-123","provider":"anthropic","model":"claude-sonnet-4-6","path":"/v1/messages","status":200,"streaming":true,"input_tokens":100,"output_tokens":200,"cache_read_tokens":5000,"cache_write_tokens":3000,"duration_ms":2500,"response_bytes":1500}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644), "should write test file") - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should parse without error") require.NotNil(t, summary, "should return non-nil summary") @@ -49,7 +49,7 @@ func TestParseTokenUsageFile(t *testing.T) { {"timestamp":"2026-04-01T17:58:00.000Z","request_id":"3","provider":"anthropic","model":"claude-haiku-4-5","path":"/v1/messages","status":200,"streaming":false,"input_tokens":769,"output_tokens":86,"cache_read_tokens":0,"cache_write_tokens":0,"duration_ms":700,"response_bytes":500}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644), "should write test file") - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should parse without error") require.NotNil(t, summary, "should return non-nil summary") @@ -75,7 +75,7 @@ func TestParseTokenUsageFile(t *testing.T) { filePath := filepath.Join(tmpDir, "token-usage.jsonl") require.NoError(t, os.WriteFile(filePath, []byte(""), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should not error on empty file") assert.Nil(t, summary, "should return nil for empty file") }) @@ -85,7 +85,7 @@ func TestParseTokenUsageFile(t *testing.T) { filePath := filepath.Join(tmpDir, "token-usage.jsonl") require.NoError(t, os.WriteFile(filePath, []byte("\n\n\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should not error on blank-only file") assert.Nil(t, summary, "should return nil for blank-only file") }) @@ -99,7 +99,7 @@ func TestParseTokenUsageFile(t *testing.T) { also not json` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should not error on mixed content") require.NotNil(t, summary, "should return summary from valid lines") assert.Equal(t, 1, summary.TotalRequests, "should count only valid entries") @@ -107,7 +107,7 @@ also not json` }) t.Run("file not found returns error", func(t *testing.T) { - _, err := parseTokenUsageFile("/nonexistent/path/token-usage.jsonl") + _, err := parseTokenUsageFile("/nonexistent/path/token-usage.jsonl", nil) assert.Error(t, err, "should error on missing file") }) @@ -118,7 +118,7 @@ also not json` content := `{"timestamp":"2026-04-01T17:56:38.042Z","request_id":"1","provider":"anthropic","model":"","path":"/v1/messages","status":200,"streaming":true,"input_tokens":50,"output_tokens":25,"cache_read_tokens":0,"cache_write_tokens":0,"duration_ms":500,"response_bytes":200}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should parse without error") require.NotNil(t, summary, "should return non-nil summary") require.Contains(t, summary.ByModel, "unknown", "should use 'unknown' for empty model") @@ -266,7 +266,7 @@ func TestCacheEfficiency(t *testing.T) { content := `{"provider":"anthropic","model":"sonnet","input_tokens":100,"output_tokens":50,"cache_read_tokens":0,"cache_write_tokens":0,"duration_ms":100}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err) require.NotNil(t, summary) assert.InDelta(t, 0.0, summary.CacheEfficiency, 0.001, "cache efficiency should be 0 with no cache reads") @@ -278,7 +278,7 @@ func TestCacheEfficiency(t *testing.T) { content := `{"provider":"anthropic","model":"sonnet","input_tokens":100,"output_tokens":50,"cache_read_tokens":9900,"cache_write_tokens":0,"duration_ms":100}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err) require.NotNil(t, summary) assert.InDelta(t, 0.99, summary.CacheEfficiency, 0.001, "cache efficiency should be ~99%") diff --git a/pkg/parser/schemas/main_workflow_schema.json b/pkg/parser/schemas/main_workflow_schema.json index 59d50266bec..474e46cb63c 100644 --- a/pkg/parser/schemas/main_workflow_schema.json +++ b/pkg/parser/schemas/main_workflow_schema.json @@ -9137,6 +9137,54 @@ "description": "Custom API endpoint hostname for the agentic engine. Used for GitHub Enterprise Cloud (GHEC), GitHub Enterprise Server (GHES), or custom AI endpoints. Example: 'api.acme.ghe.com' for GHEC, 'api.enterprise.githubcopilot.com' for GHES, or custom endpoint hostnames.", "examples": ["api.acme.ghe.com", "api.enterprise.githubcopilot.com", "api.custom.endpoint.com"] }, + "token-weights": { + "type": "object", + "description": "Custom model token weights for effective token computation. Overrides or extends the built-in model multipliers from model_multipliers.json. Useful for custom models or adjusted cost ratios.", + "properties": { + "multipliers": { + "type": "object", + "description": "Per-model cost multipliers relative to the reference model (claude-sonnet-4.5 = 1.0). Keys are model names (case-insensitive, prefix matching supported). Values are numeric multipliers.", + "additionalProperties": { + "type": "number", + "minimum": 0 + }, + "examples": [{ "my-custom-model": 2.5, "gpt-5": 3.0 }] + }, + "token-class-weights": { + "type": "object", + "description": "Per-token-class weights applied before the model multiplier. Any specified weight overrides the corresponding default.", + "properties": { + "input": { + "type": "number", + "exclusiveMinimum": 0, + "description": "Weight for input tokens (default: 1.0)" + }, + "cached-input": { + "type": "number", + "exclusiveMinimum": 0, + "description": "Weight for cached input tokens (default: 0.1)" + }, + "output": { + "type": "number", + "exclusiveMinimum": 0, + "description": "Weight for output tokens (default: 4.0)" + }, + "reasoning": { + "type": "number", + "exclusiveMinimum": 0, + "description": "Weight for reasoning tokens (default: 4.0)" + }, + "cache-write": { + "type": "number", + "exclusiveMinimum": 0, + "description": "Weight for cache write tokens (default: 1.0)" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + }, "args": { "type": "array", "items": { diff --git a/pkg/types/token_weights.go b/pkg/types/token_weights.go new file mode 100644 index 00000000000..94957443309 --- /dev/null +++ b/pkg/types/token_weights.go @@ -0,0 +1,26 @@ +package types + +// TokenClassWeights holds per-token-class weights for effective token computation. +// Each field corresponds to one token class; a zero value means "use default". +// The JSON keys use hyphens to match the frontmatter schema +// (engine.token-weights.token-class-weights). +type TokenClassWeights struct { + Input float64 `json:"input,omitempty"` + CachedInput float64 `json:"cached-input,omitempty"` + Output float64 `json:"output,omitempty"` + Reasoning float64 `json:"reasoning,omitempty"` + CacheWrite float64 `json:"cache-write,omitempty"` +} + +// TokenWeights defines custom model cost information for effective token computation. +// It mirrors the structure of model_multipliers.json and allows per-workflow overrides. +// Specified under engine.token-weights in the workflow frontmatter and stored in +// aw_info.json at runtime. +type TokenWeights struct { + // Multipliers maps model names to cost multipliers relative to the reference model. + // Keys are matched case-insensitively with prefix matching as a fallback. + Multipliers map[string]float64 `json:"multipliers,omitempty"` + // TokenClassWeights overrides the per-token-class weights used before the model multiplier. + // A nil pointer means no overrides; individual zero fields mean "use default". + TokenClassWeights *TokenClassWeights `json:"token-class-weights,omitempty"` +} diff --git a/pkg/workflow/compiler_yaml.go b/pkg/workflow/compiler_yaml.go index d81ad684f19..51232690f9d 100644 --- a/pkg/workflow/compiler_yaml.go +++ b/pkg/workflow/compiler_yaml.go @@ -702,6 +702,14 @@ func (c *Compiler) generateCreateAwInfo(yaml *strings.Builder, data *WorkflowDat fmt.Fprintf(yaml, " CUSTOM_GITHUB_TOKEN: %s\n", customToken) } } + // Embed custom token weights when specified in engine.token-weights + if data.EngineConfig != nil && data.EngineConfig.TokenWeights != nil { + if tokenWeightsJSON, err := json.Marshal(data.EngineConfig.TokenWeights); err == nil { + // Escape single quotes for YAML single-quoted scalar safety + escapedTokenWeightsJSON := strings.ReplaceAll(string(tokenWeightsJSON), "'", "''") + fmt.Fprintf(yaml, " GH_AW_INFO_TOKEN_WEIGHTS: '%s'\n", escapedTokenWeightsJSON) + } + } fmt.Fprintf(yaml, " uses: %s\n", GetActionPin("actions/github-script")) yaml.WriteString(" with:\n") yaml.WriteString(" script: |\n") diff --git a/pkg/workflow/engine.go b/pkg/workflow/engine.go index 2473dd47dbc..3e8f024e757 100644 --- a/pkg/workflow/engine.go +++ b/pkg/workflow/engine.go @@ -7,6 +7,7 @@ import ( "github.com/github/gh-aw/pkg/logger" "github.com/github/gh-aw/pkg/stringutil" + "github.com/github/gh-aw/pkg/types" ) var engineLog = logger.New("workflow:engine") @@ -26,6 +27,9 @@ type EngineConfig struct { Args []string Agent string // Agent identifier for copilot --agent flag (copilot engine only) APITarget string // Custom API endpoint hostname (e.g., "api.acme.ghe.com" or "api.enterprise.githubcopilot.com") + // TokenWeights provides custom model cost data for effective token computation. + // When set, overrides or extends the built-in model_multipliers.json values. + TokenWeights *types.TokenWeights // Inline definition fields (populated when engine.runtime is specified in frontmatter) IsInlineDefinition bool // true when the engine is defined inline via engine.runtime + optional engine.provider @@ -284,6 +288,14 @@ func (c *Compiler) ExtractEngineConfig(frontmatter map[string]any) (string, *Eng } } + // Extract optional 'token-weights' field (custom model cost data) + if tokenWeightsRaw, hasTokenWeights := engineObj["token-weights"]; hasTokenWeights { + if tw := parseEngineTokenWeights(tokenWeightsRaw); tw != nil { + config.TokenWeights = tw + engineLog.Printf("Extracted token-weights: %d multipliers", len(tw.Multipliers)) + } + } + // Return the ID as the engineSetting for backwards compatibility engineLog.Printf("Extracted engine configuration: ID=%s", config.ID) return config.ID, config @@ -398,3 +410,69 @@ func parseRequestShape(requestObj map[string]any) *RequestShape { } return shape } + +// parseEngineTokenWeights converts a raw token-weights config value (from engine.token-weights) +// into a types.TokenWeights. Returns nil when the input is not a usable map or contains +// no recognisable data. Multiplier values of unexpected numeric types (anything other than +// float64, int, or uint64) are silently ignored — this matches the behaviour of the YAML +// parser which produces float64 for JSON-number literals and integers for integer literals. +func parseEngineTokenWeights(raw any) *types.TokenWeights { + obj, ok := raw.(map[string]any) + if !ok { + return nil + } + + tw := &types.TokenWeights{} + + // Parse multipliers: map of model name → float64 + if multipliersRaw, ok := obj["multipliers"]; ok { + if multipliersMap, ok := multipliersRaw.(map[string]any); ok && len(multipliersMap) > 0 { + tw.Multipliers = make(map[string]float64, len(multipliersMap)) + for model, val := range multipliersMap { + switch v := val.(type) { + case float64: + tw.Multipliers[model] = v + case int: + tw.Multipliers[model] = float64(v) + case uint64: + tw.Multipliers[model] = float64(v) + } + } + } + } + + // Parse token-class-weights + if tcwRaw, ok := obj["token-class-weights"]; ok { + if tcwMap, ok := tcwRaw.(map[string]any); ok { + tcw := &types.TokenClassWeights{} + setFloat := func(dst *float64, key string) { + if v, ok := tcwMap[key]; ok { + switch f := v.(type) { + case float64: + *dst = f + case int: + *dst = float64(f) + case uint64: + *dst = float64(f) + } + } + } + setFloat(&tcw.Input, "input") + setFloat(&tcw.CachedInput, "cached-input") + setFloat(&tcw.Output, "output") + setFloat(&tcw.Reasoning, "reasoning") + setFloat(&tcw.CacheWrite, "cache-write") + // Only assign if at least one weight was set + if tcw.Input != 0 || tcw.CachedInput != 0 || tcw.Output != 0 || + tcw.Reasoning != 0 || tcw.CacheWrite != 0 { + tw.TokenClassWeights = tcw + } + } + } + + // Return nil when nothing useful was parsed + if len(tw.Multipliers) == 0 && tw.TokenClassWeights == nil { + return nil + } + return tw +} diff --git a/pkg/workflow/engine_test.go b/pkg/workflow/engine_test.go index d8f65c5c837..2a4d401538a 100644 --- a/pkg/workflow/engine_test.go +++ b/pkg/workflow/engine_test.go @@ -3,7 +3,10 @@ package workflow import ( + "strings" "testing" + + "github.com/github/gh-aw/pkg/types" ) // TestEngineVersionTypeHandling tests that engine.version correctly handles @@ -326,3 +329,200 @@ func TestAPITargetExtraction(t *testing.T) { }) } } + +func TestParseEngineTokenWeights(t *testing.T) { + tests := []struct { + name string + raw any + wantNil bool + wantMultipliers map[string]float64 + wantClassWeights *types.TokenClassWeights + }{ + { + name: "nil input returns nil", + raw: nil, + wantNil: true, + }, + { + name: "non-map input returns nil", + raw: "not-a-map", + wantNil: true, + }, + { + name: "empty map returns nil", + raw: map[string]any{}, + wantNil: true, + }, + { + name: "multipliers only", + raw: map[string]any{ + "multipliers": map[string]any{ + "my-model": float64(2.5), + "gpt-5": float64(3.0), + }, + }, + wantMultipliers: map[string]float64{ + "my-model": 2.5, + "gpt-5": 3.0, + }, + }, + { + name: "token-class-weights only", + raw: map[string]any{ + "token-class-weights": map[string]any{ + "output": float64(6.0), + }, + }, + wantClassWeights: &types.TokenClassWeights{ + Output: 6.0, + }, + }, + { + name: "both multipliers and token-class-weights", + raw: map[string]any{ + "multipliers": map[string]any{ + "custom-model": float64(1.5), + }, + "token-class-weights": map[string]any{ + "input": float64(1.0), + "cached-input": float64(0.05), + "output": float64(5.0), + "reasoning": float64(5.0), + "cache-write": float64(1.0), + }, + }, + wantMultipliers: map[string]float64{"custom-model": 1.5}, + wantClassWeights: &types.TokenClassWeights{ + Input: 1.0, + CachedInput: 0.05, + Output: 5.0, + Reasoning: 5.0, + CacheWrite: 1.0, + }, + }, + { + name: "integer multiplier values are accepted", + raw: map[string]any{ + "multipliers": map[string]any{ + "int-model": int(2), + }, + }, + wantMultipliers: map[string]float64{"int-model": 2.0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseEngineTokenWeights(tt.raw) + if tt.wantNil { + if got != nil { + t.Errorf("expected nil, got %+v", got) + } + return + } + if got == nil { + t.Fatal("expected non-nil result") + } + if tt.wantMultipliers != nil { + for model, want := range tt.wantMultipliers { + if got.Multipliers[model] != want { + t.Errorf("multiplier[%q] = %v, want %v", model, got.Multipliers[model], want) + } + } + } + if tt.wantClassWeights != nil { + if got.TokenClassWeights == nil { + t.Fatal("expected TokenClassWeights to be set") + } + want := tt.wantClassWeights + tcw := got.TokenClassWeights + if want.Input != 0 && tcw.Input != want.Input { + t.Errorf("Input weight = %v, want %v", tcw.Input, want.Input) + } + if want.CachedInput != 0 && tcw.CachedInput != want.CachedInput { + t.Errorf("CachedInput weight = %v, want %v", tcw.CachedInput, want.CachedInput) + } + if want.Output != 0 && tcw.Output != want.Output { + t.Errorf("Output weight = %v, want %v", tcw.Output, want.Output) + } + if want.Reasoning != 0 && tcw.Reasoning != want.Reasoning { + t.Errorf("Reasoning weight = %v, want %v", tcw.Reasoning, want.Reasoning) + } + if want.CacheWrite != 0 && tcw.CacheWrite != want.CacheWrite { + t.Errorf("CacheWrite weight = %v, want %v", tcw.CacheWrite, want.CacheWrite) + } + } + }) + } +} + +func TestExtractEngineConfigTokenWeights(t *testing.T) { + compiler := NewCompiler() + + frontmatter := map[string]any{ + "engine": map[string]any{ + "id": "claude", + "token-weights": map[string]any{ + "multipliers": map[string]any{ + "my-custom-model": float64(2.5), + }, + "token-class-weights": map[string]any{ + "output": float64(6.0), + }, + }, + }, + } + + _, config := compiler.ExtractEngineConfig(frontmatter) + if config == nil { + t.Fatal("Expected non-nil config") + } + if config.TokenWeights == nil { + t.Fatal("Expected TokenWeights to be set") + } + if config.TokenWeights.Multipliers["my-custom-model"] != 2.5 { + t.Errorf("Expected multiplier 2.5, got %v", config.TokenWeights.Multipliers["my-custom-model"]) + } + if config.TokenWeights.TokenClassWeights == nil { + t.Fatal("Expected TokenClassWeights to be set") + } + if config.TokenWeights.TokenClassWeights.Output != 6.0 { + t.Errorf("Expected output weight 6.0, got %v", config.TokenWeights.TokenClassWeights.Output) + } +} + +func TestTokenWeightsSingleQuoteEscapingInYAML(t *testing.T) { + compiler := NewCompiler() + registry := GetGlobalEngineRegistry() + engine, err := registry.GetEngine("claude") + if err != nil { + t.Fatalf("Failed to get claude engine: %v", err) + } + + // Model name containing a single quote — must not break YAML single-quoted scalar + workflowData := &WorkflowData{ + Name: "Test Workflow", + EngineConfig: &EngineConfig{ + ID: "claude", + TokenWeights: &types.TokenWeights{ + Multipliers: map[string]float64{ + "bob's-model": 2.0, // Single quote in key + }, + }, + }, + } + + var out strings.Builder + compiler.generateCreateAwInfo(&out, workflowData, engine) + output := out.String() + + // The generated YAML must not contain an un-escaped single quote inside a single-quoted value. + // In YAML, a single quote inside a single-quoted scalar is represented as ”. + if !strings.Contains(output, "bob''s-model") { + t.Errorf("Expected single quote to be escaped as '' in YAML output, got:\n%s", output) + } + // There must be no dangling unescaped single quote inside the GH_AW_INFO_TOKEN_WEIGHTS value + if strings.Contains(output, "bob's-model") { + t.Errorf("Unescaped single quote found in YAML output:\n%s", output) + } +}