From fbbaca86f45cf9eec18c0d4329d88b9f4530fcc5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 2 Apr 2026 14:42:11 +0000 Subject: [PATCH 1/4] feat: add custom model token weights support in engine frontmatter - Add `token-weights` field to engine schema (JSON Schema) - Add `EngineTokenWeights` and `EngineTokenClassWeights` structs to `EngineConfig` - Parse `token-weights` from frontmatter in `ExtractEngineConfig` - Embed custom weights in compiled YAML via `GH_AW_INFO_TOKEN_WEIGHTS` env var - Write token weights to `aw_info.json` in `generate_aw_info.cjs` - Add `CustomTokenWeights` type and `populateEffectiveTokensWithCustomWeights` in effective_tokens.go - Auto-read custom weights from `aw_info.json` in `analyzeTokenUsage` - Add comprehensive tests for all new functionality Agent-Logs-Url: https://github.com/github/gh-aw/sessions/e8a47829-09aa-429e-b0a5-7b1d0d5c6e7a Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- actions/setup/js/generate_aw_info.cjs | 15 ++ pkg/cli/effective_tokens.go | 115 ++++++++++++- pkg/cli/effective_tokens_test.go | 91 +++++++++++ pkg/cli/logs_models.go | 25 +-- pkg/cli/token_usage.go | 33 +++- pkg/cli/token_usage_test.go | 18 +-- pkg/parser/schemas/main_workflow_schema.json | 48 ++++++ pkg/workflow/compiler_yaml.go | 6 + pkg/workflow/engine.go | 95 +++++++++++ pkg/workflow/engine_test.go | 161 +++++++++++++++++++ 10 files changed, 572 insertions(+), 35 deletions(-) diff --git a/actions/setup/js/generate_aw_info.cjs b/actions/setup/js/generate_aw_info.cjs index e984be8ee46..8f3891fde5e 100644 --- a/actions/setup/js/generate_aw_info.cjs +++ b/actions/setup/js/generate_aw_info.cjs @@ -86,6 +86,21 @@ async function main(core, ctx) { awInfo.cli_version = cliVersion; } + // Include custom token weights when set (engine.token-weights in workflow frontmatter) + const tokenWeightsEnv = process.env.GH_AW_INFO_TOKEN_WEIGHTS; + if (tokenWeightsEnv) { + try { + 
const tokenWeights = JSON.parse(tokenWeightsEnv); + if (tokenWeights !== null && typeof tokenWeights === "object" && !Array.isArray(tokenWeights)) { + awInfo.token_weights = tokenWeights; + } else { + core.warning(`GH_AW_INFO_TOKEN_WEIGHTS must be a JSON object, ignoring`); + } + } catch { + core.warning(`Failed to parse GH_AW_INFO_TOKEN_WEIGHTS: ${tokenWeightsEnv}`); + } + } + // Include aw_context when the workflow was triggered via workflow_dispatch with // the aw_context input set by a calling agentic workflow's dispatch_workflow handler. // Validates JSON format and structure before populating the context key in aw_info.json. diff --git a/pkg/cli/effective_tokens.go b/pkg/cli/effective_tokens.go index 152adc397d6..64029f3f158 100644 --- a/pkg/cli/effective_tokens.go +++ b/pkg/cli/effective_tokens.go @@ -30,6 +30,7 @@ package cli import ( _ "embed" "encoding/json" + "maps" "math" "strings" @@ -50,6 +51,27 @@ type tokenClassWeights struct { CacheWrite float64 `json:"cache_write"` } +// customTokenClassWeights holds per-token-class weight overrides from aw_info.json. +// The JSON keys use hyphens to match the frontmatter schema (engine.token-weights.token-class-weights). +type customTokenClassWeights struct { + Input float64 `json:"input"` + CachedInput float64 `json:"cached-input"` + Output float64 `json:"output"` + Reasoning float64 `json:"reasoning"` + CacheWrite float64 `json:"cache-write"` +} + +// CustomTokenWeights provides per-workflow overrides for effective token computation. +// It is populated from the token_weights field in aw_info.json, which is written at +// compile time from the engine.token-weights frontmatter configuration. +type CustomTokenWeights struct { + // Multipliers maps model names to cost multipliers (relative to reference model). + // Keys are matched case-insensitively with prefix matching as fallback. + Multipliers map[string]float64 `json:"multipliers,omitempty"` + // TokenClassWeights overrides any or all per-token-class weights. 
+ TokenClassWeights *customTokenClassWeights `json:"token-class-weights,omitempty"` +} + // modelMultipliersData is the top-level structure of model_multipliers.json. type modelMultipliersData struct { Version string `json:"version"` @@ -196,28 +218,105 @@ func computeModelEffectiveTokens(model string, inputTokens, outputTokens, cacheR // entry and computes the TotalEffectiveTokens aggregate on the summary. // It is a no-op when summary is nil. func populateEffectiveTokens(summary *TokenUsageSummary) { + populateEffectiveTokensWithCustomWeights(summary, nil) +} + +// populateEffectiveTokensWithCustomWeights is like populateEffectiveTokens but +// merges custom into the built-in weights before computing effective tokens. +// Custom weights take precedence over the defaults loaded from model_multipliers.json. +// It is a no-op when summary is nil. +func populateEffectiveTokensWithCustomWeights(summary *TokenUsageSummary, custom *CustomTokenWeights) { if summary == nil { return } + multipliers, classWeights := resolveEffectiveWeights(custom) + total := 0 for model, usage := range summary.ByModel { if usage == nil { continue } - eff := computeModelEffectiveTokens( - model, - usage.InputTokens, - usage.OutputTokens, - usage.CacheReadTokens, - usage.CacheWriteTokens, - ) + eff := computeModelEffectiveTokensWithWeights(model, usage.InputTokens, usage.OutputTokens, + usage.CacheReadTokens, usage.CacheWriteTokens, multipliers, classWeights) usage.EffectiveTokens = eff total += eff } summary.TotalEffectiveTokens = total if effectiveTokensLog.Enabled() { - effectiveTokensLog.Printf("Effective tokens: total=%d models=%d", total, len(summary.ByModel)) + effectiveTokensLog.Printf("Effective tokens: total=%d models=%d custom=%v", total, len(summary.ByModel), custom != nil) } } + +// resolveEffectiveWeights merges optional custom weights with the built-in defaults. +// The returned multipliers map is a copy so callers may not modify loadedMultipliers. 
+func resolveEffectiveWeights(custom *CustomTokenWeights) (map[string]float64, tokenClassWeights) { + initMultipliers() + + // Copy the base multipliers to avoid mutating the shared global + merged := make(map[string]float64, len(loadedMultipliers)) + maps.Copy(merged, loadedMultipliers) + classWeights := loadedTokenWeights + + if custom == nil { + return merged, classWeights + } + + // Override/add per-model multipliers (normalise keys to lowercase) + for model, mult := range custom.Multipliers { + merged[strings.ToLower(strings.TrimSpace(model))] = mult + } + + // Override per-token-class weights where non-zero values are provided + if tcw := custom.TokenClassWeights; tcw != nil { + if tcw.Input != 0 { + classWeights.Input = tcw.Input + } + if tcw.CachedInput != 0 { + classWeights.CachedInput = tcw.CachedInput + } + if tcw.Output != 0 { + classWeights.Output = tcw.Output + } + if tcw.Reasoning != 0 { + classWeights.Reasoning = tcw.Reasoning + } + if tcw.CacheWrite != 0 { + classWeights.CacheWrite = tcw.CacheWrite + } + } + + return merged, classWeights +} + +// computeModelEffectiveTokensWithWeights computes effective tokens using caller-provided +// multiplier table and token class weights instead of the global defaults. 
+func computeModelEffectiveTokensWithWeights(model string, inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int, multipliers map[string]float64, w tokenClassWeights) int { + base := w.Input*float64(inputTokens) + + w.CachedInput*float64(cacheReadTokens) + + w.Output*float64(outputTokens) + + w.CacheWrite*float64(cacheWriteTokens) + if base == 0 { + return 0 + } + + key := strings.ToLower(strings.TrimSpace(model)) + mult := 1.0 + if key != "" { + if m, ok := multipliers[key]; ok { + mult = m + } else { + // Longest prefix match + best := "" + for name, m := range multipliers { + if strings.HasPrefix(key, name) && len(name) > len(best) { + best = name + mult = m + } + } + } + } + + return int(math.Round(base * mult)) +} diff --git a/pkg/cli/effective_tokens_test.go b/pkg/cli/effective_tokens_test.go index d00bdde0ccc..a13a3bd6a9d 100644 --- a/pkg/cli/effective_tokens_test.go +++ b/pkg/cli/effective_tokens_test.go @@ -223,3 +223,94 @@ func TestModelMultipliersJSONEmbedded(t *testing.T) { require.NotNil(t, loadedMultipliers, "multipliers should be loaded from embedded JSON") assert.NotEmpty(t, loadedMultipliers, "should have at least one multiplier entry") } + +func TestResolveEffectiveWeightsNoCustom(t *testing.T) { + loadedMultipliers = nil + + multipliers, classWeights := resolveEffectiveWeights(nil) + + assert.NotEmpty(t, multipliers, "should have built-in multipliers") + assert.InDelta(t, 1.0, classWeights.Input, 1e-9, "default input weight") + assert.InDelta(t, 0.1, classWeights.CachedInput, 1e-9, "default cached input weight") + assert.InDelta(t, 4.0, classWeights.Output, 1e-9, "default output weight") +} + +func TestResolveEffectiveWeightsCustomMultipliers(t *testing.T) { + loadedMultipliers = nil + + custom := &CustomTokenWeights{ + Multipliers: map[string]float64{ + "my-custom-model": 2.5, + "claude-sonnet-4.5": 1.5, // override existing + }, + } + multipliers, classWeights := resolveEffectiveWeights(custom) + + assert.InDelta(t, 2.5, 
multipliers["my-custom-model"], 1e-9, "custom model multiplier") + assert.InDelta(t, 1.5, multipliers["claude-sonnet-4.5"], 1e-9, "overridden model multiplier") + // Built-in models not mentioned in custom should remain + assert.InDelta(t, 0.1, multipliers["claude-haiku-4.5"], 1e-9, "unmodified built-in multiplier") + // Class weights unchanged when not specified + assert.InDelta(t, 4.0, classWeights.Output, 1e-9, "output weight unchanged") +} + +func TestResolveEffectiveWeightsCustomClassWeights(t *testing.T) { + loadedMultipliers = nil + + custom := &CustomTokenWeights{ + TokenClassWeights: &customTokenClassWeights{ + Output: 6.0, + CachedInput: 0.05, + }, + } + _, classWeights := resolveEffectiveWeights(custom) + + assert.InDelta(t, 6.0, classWeights.Output, 1e-9, "custom output weight") + assert.InDelta(t, 0.05, classWeights.CachedInput, 1e-9, "custom cached input weight") + // Unset fields keep their defaults + assert.InDelta(t, 1.0, classWeights.Input, 1e-9, "input weight unchanged") + assert.InDelta(t, 4.0, classWeights.Reasoning, 1e-9, "reasoning weight unchanged") +} + +func TestPopulateEffectiveTokensWithCustomWeights(t *testing.T) { + loadedMultipliers = nil + + summary := &TokenUsageSummary{ + ByModel: map[string]*ModelTokenUsage{ + "my-custom-model": { + InputTokens: 1000, + OutputTokens: 200, + }, + "claude-sonnet-4.5": { + InputTokens: 500, + OutputTokens: 100, + }, + }, + } + + custom := &CustomTokenWeights{ + Multipliers: map[string]float64{ + "my-custom-model": 3.0, + }, + } + + populateEffectiveTokensWithCustomWeights(summary, custom) + + // my-custom-model: base = 1.0*1000 + 4.0*200 = 1800; ET = 3.0 * 1800 = 5400 + customModel := summary.ByModel["my-custom-model"] + require.NotNil(t, customModel, "custom model should be present") + assert.Equal(t, 5400, customModel.EffectiveTokens, "custom model effective tokens at 3.0x") + + // claude-sonnet-4.5: base = 1.0*500 + 4.0*100 = 900; ET = 1.0 * 900 = 900 + sonnet := 
summary.ByModel["claude-sonnet-4.5"] + require.NotNil(t, sonnet, "sonnet should be present") + assert.Equal(t, 900, sonnet.EffectiveTokens, "sonnet effective tokens at 1x") + + assert.Equal(t, 6300, summary.TotalEffectiveTokens, "total = custom + sonnet") +} + +func TestPopulateEffectiveTokensWithCustomWeightsNilSummary(t *testing.T) { + assert.NotPanics(t, func() { + populateEffectiveTokensWithCustomWeights(nil, nil) + }) +} diff --git a/pkg/cli/logs_models.go b/pkg/cli/logs_models.go index c39490ebc9d..898b4b55cda 100644 --- a/pkg/cli/logs_models.go +++ b/pkg/cli/logs_models.go @@ -267,18 +267,19 @@ type AwContext struct { // AwInfo represents the structure of aw_info.json files type AwInfo struct { - EngineID string `json:"engine_id"` - EngineName string `json:"engine_name"` - Model string `json:"model"` - Version string `json:"version"` - CLIVersion string `json:"cli_version,omitempty"` // gh-aw CLI version - WorkflowName string `json:"workflow_name"` - Staged bool `json:"staged"` - AwfVersion string `json:"awf_version,omitempty"` // AWF firewall version (new name) - FirewallVersion string `json:"firewall_version,omitempty"` // AWF firewall version (old name, for backward compatibility) - Steps AwInfoSteps `json:"steps,omitzero"` // Steps metadata - CreatedAt string `json:"created_at"` - Context *AwContext `json:"context,omitempty"` // aw_context data passed via workflow_dispatch inputs + EngineID string `json:"engine_id"` + EngineName string `json:"engine_name"` + Model string `json:"model"` + Version string `json:"version"` + CLIVersion string `json:"cli_version,omitempty"` // gh-aw CLI version + WorkflowName string `json:"workflow_name"` + Staged bool `json:"staged"` + AwfVersion string `json:"awf_version,omitempty"` // AWF firewall version (new name) + FirewallVersion string `json:"firewall_version,omitempty"` // AWF firewall version (old name, for backward compatibility) + Steps AwInfoSteps `json:"steps,omitzero"` // Steps metadata + CreatedAt string 
`json:"created_at"` + Context *AwContext `json:"context,omitempty"` // aw_context data passed via workflow_dispatch inputs + TokenWeights *CustomTokenWeights `json:"token_weights,omitempty"` // Custom model cost data (from engine.token-weights) // Additional fields that might be present RunID any `json:"run_id,omitempty"` RunNumber any `json:"run_number,omitempty"` diff --git a/pkg/cli/token_usage.go b/pkg/cli/token_usage.go index 1fa37dbb6a4..39780491c77 100644 --- a/pkg/cli/token_usage.go +++ b/pkg/cli/token_usage.go @@ -75,8 +75,10 @@ type ModelTokenUsageRow struct { // tokenUsageJSONLPath is the relative path within the firewall logs directory const tokenUsageJSONLPath = "api-proxy-logs/token-usage.jsonl" -// parseTokenUsageFile parses a token-usage.jsonl file and returns the aggregated summary -func parseTokenUsageFile(filePath string) (*TokenUsageSummary, error) { +// parseTokenUsageFile parses a token-usage.jsonl file and returns the aggregated summary. +// Custom weights, when non-nil, override the built-in model multipliers and token class +// weights for effective token computation. 
+func parseTokenUsageFile(filePath string, customWeights *CustomTokenWeights) (*TokenUsageSummary, error) { tokenUsageLog.Printf("Parsing token usage file: %s", filePath) file, err := os.Open(filePath) @@ -155,8 +157,8 @@ func parseTokenUsageFile(filePath string) (*TokenUsageSummary, error) { lineNum, summary.TotalInputTokens, summary.TotalOutputTokens, summary.TotalCacheReadTokens, summary.TotalCacheWriteTokens, summary.TotalRequests) - // Compute effective tokens using per-model multipliers - populateEffectiveTokens(summary) + // Compute effective tokens using per-model multipliers (with optional custom overrides) + populateEffectiveTokensWithCustomWeights(summary, customWeights) return summary, nil } @@ -210,7 +212,9 @@ func findTokenUsageFile(runDir string) string { return "" } -// analyzeTokenUsage finds and parses the token-usage.jsonl file from a run directory +// analyzeTokenUsage finds and parses the token-usage.jsonl file from a run directory. +// It automatically reads custom token weights from aw_info.json when present and +// applies them to the effective token computation. func analyzeTokenUsage(runDir string, verbose bool) (*TokenUsageSummary, error) { tokenUsageLog.Printf("Analyzing token usage in: %s", runDir) @@ -226,7 +230,24 @@ func analyzeTokenUsage(runDir string, verbose bool) (*TokenUsageSummary, error) } } - return parseTokenUsageFile(filePath) + // Try to load custom token weights from aw_info.json for this run + customWeights := extractCustomTokenWeightsFromDir(runDir) + + return parseTokenUsageFile(filePath, customWeights) +} + +// extractCustomTokenWeightsFromDir reads aw_info.json from a run directory and returns +// any custom token weights embedded there at compile time. Returns nil when not found. 
+func extractCustomTokenWeightsFromDir(runDir string) *CustomTokenWeights { + awInfoPath := findAwInfoPath(runDir) + if awInfoPath == "" { + return nil + } + awInfo, err := parseAwInfo(awInfoPath, false) + if err != nil || awInfo == nil { + return nil + } + return awInfo.TokenWeights } // TotalTokens returns the sum of all token types diff --git a/pkg/cli/token_usage_test.go b/pkg/cli/token_usage_test.go index ae43465ae12..fe218ec926f 100644 --- a/pkg/cli/token_usage_test.go +++ b/pkg/cli/token_usage_test.go @@ -20,7 +20,7 @@ func TestParseTokenUsageFile(t *testing.T) { content := `{"timestamp":"2026-04-01T17:56:38.042Z","request_id":"abc-123","provider":"anthropic","model":"claude-sonnet-4-6","path":"/v1/messages","status":200,"streaming":true,"input_tokens":100,"output_tokens":200,"cache_read_tokens":5000,"cache_write_tokens":3000,"duration_ms":2500,"response_bytes":1500}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644), "should write test file") - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should parse without error") require.NotNil(t, summary, "should return non-nil summary") @@ -49,7 +49,7 @@ func TestParseTokenUsageFile(t *testing.T) { {"timestamp":"2026-04-01T17:58:00.000Z","request_id":"3","provider":"anthropic","model":"claude-haiku-4-5","path":"/v1/messages","status":200,"streaming":false,"input_tokens":769,"output_tokens":86,"cache_read_tokens":0,"cache_write_tokens":0,"duration_ms":700,"response_bytes":500}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644), "should write test file") - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should parse without error") require.NotNil(t, summary, "should return non-nil summary") @@ -75,7 +75,7 @@ func TestParseTokenUsageFile(t *testing.T) { filePath := filepath.Join(tmpDir, "token-usage.jsonl") require.NoError(t, 
os.WriteFile(filePath, []byte(""), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should not error on empty file") assert.Nil(t, summary, "should return nil for empty file") }) @@ -85,7 +85,7 @@ func TestParseTokenUsageFile(t *testing.T) { filePath := filepath.Join(tmpDir, "token-usage.jsonl") require.NoError(t, os.WriteFile(filePath, []byte("\n\n\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should not error on blank-only file") assert.Nil(t, summary, "should return nil for blank-only file") }) @@ -99,7 +99,7 @@ func TestParseTokenUsageFile(t *testing.T) { also not json` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should not error on mixed content") require.NotNil(t, summary, "should return summary from valid lines") assert.Equal(t, 1, summary.TotalRequests, "should count only valid entries") @@ -107,7 +107,7 @@ also not json` }) t.Run("file not found returns error", func(t *testing.T) { - _, err := parseTokenUsageFile("/nonexistent/path/token-usage.jsonl") + _, err := parseTokenUsageFile("/nonexistent/path/token-usage.jsonl", nil) assert.Error(t, err, "should error on missing file") }) @@ -118,7 +118,7 @@ also not json` content := `{"timestamp":"2026-04-01T17:56:38.042Z","request_id":"1","provider":"anthropic","model":"","path":"/v1/messages","status":200,"streaming":true,"input_tokens":50,"output_tokens":25,"cache_read_tokens":0,"cache_write_tokens":0,"duration_ms":500,"response_bytes":200}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err, "should parse without error") require.NotNil(t, 
summary, "should return non-nil summary") require.Contains(t, summary.ByModel, "unknown", "should use 'unknown' for empty model") @@ -266,7 +266,7 @@ func TestCacheEfficiency(t *testing.T) { content := `{"provider":"anthropic","model":"sonnet","input_tokens":100,"output_tokens":50,"cache_read_tokens":0,"cache_write_tokens":0,"duration_ms":100}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err) require.NotNil(t, summary) assert.InDelta(t, 0.0, summary.CacheEfficiency, 0.001, "cache efficiency should be 0 with no cache reads") @@ -278,7 +278,7 @@ func TestCacheEfficiency(t *testing.T) { content := `{"provider":"anthropic","model":"sonnet","input_tokens":100,"output_tokens":50,"cache_read_tokens":9900,"cache_write_tokens":0,"duration_ms":100}` require.NoError(t, os.WriteFile(filePath, []byte(content+"\n"), 0o644)) - summary, err := parseTokenUsageFile(filePath) + summary, err := parseTokenUsageFile(filePath, nil) require.NoError(t, err) require.NotNil(t, summary) assert.InDelta(t, 0.99, summary.CacheEfficiency, 0.001, "cache efficiency should be ~99%") diff --git a/pkg/parser/schemas/main_workflow_schema.json b/pkg/parser/schemas/main_workflow_schema.json index 59d50266bec..885aa5f3efa 100644 --- a/pkg/parser/schemas/main_workflow_schema.json +++ b/pkg/parser/schemas/main_workflow_schema.json @@ -9137,6 +9137,54 @@ "description": "Custom API endpoint hostname for the agentic engine. Used for GitHub Enterprise Cloud (GHEC), GitHub Enterprise Server (GHES), or custom AI endpoints. Example: 'api.acme.ghe.com' for GHEC, 'api.enterprise.githubcopilot.com' for GHES, or custom endpoint hostnames.", "examples": ["api.acme.ghe.com", "api.enterprise.githubcopilot.com", "api.custom.endpoint.com"] }, + "token-weights": { + "type": "object", + "description": "Custom model token weights for effective token computation. 
Overrides or extends the built-in model multipliers from model_multipliers.json. Useful for custom models or adjusted cost ratios.", + "properties": { + "multipliers": { + "type": "object", + "description": "Per-model cost multipliers relative to the reference model (claude-sonnet-4.5 = 1.0). Keys are model names (case-insensitive, prefix matching supported). Values are numeric multipliers.", + "additionalProperties": { + "type": "number", + "minimum": 0 + }, + "examples": [{ "my-custom-model": 2.5, "gpt-5": 3.0 }] + }, + "token-class-weights": { + "type": "object", + "description": "Per-token-class weights applied before the model multiplier. Any specified weight overrides the corresponding default.", + "properties": { + "input": { + "type": "number", + "minimum": 0, + "description": "Weight for input tokens (default: 1.0)" + }, + "cached-input": { + "type": "number", + "minimum": 0, + "description": "Weight for cached input tokens (default: 0.1)" + }, + "output": { + "type": "number", + "minimum": 0, + "description": "Weight for output tokens (default: 4.0)" + }, + "reasoning": { + "type": "number", + "minimum": 0, + "description": "Weight for reasoning tokens (default: 4.0)" + }, + "cache-write": { + "type": "number", + "minimum": 0, + "description": "Weight for cache write tokens (default: 1.0)" + } + }, + "additionalProperties": false + } + }, + "additionalProperties": false + }, "args": { "type": "array", "items": { diff --git a/pkg/workflow/compiler_yaml.go b/pkg/workflow/compiler_yaml.go index d81ad684f19..cc30612ae18 100644 --- a/pkg/workflow/compiler_yaml.go +++ b/pkg/workflow/compiler_yaml.go @@ -702,6 +702,12 @@ func (c *Compiler) generateCreateAwInfo(yaml *strings.Builder, data *WorkflowDat fmt.Fprintf(yaml, " CUSTOM_GITHUB_TOKEN: %s\n", customToken) } } + // Embed custom token weights when specified in engine.token-weights + if data.EngineConfig != nil && data.EngineConfig.TokenWeights != nil { + if tokenWeightsJSON, err := 
json.Marshal(data.EngineConfig.TokenWeights); err == nil { + fmt.Fprintf(yaml, " GH_AW_INFO_TOKEN_WEIGHTS: '%s'\n", string(tokenWeightsJSON)) + } + } fmt.Fprintf(yaml, " uses: %s\n", GetActionPin("actions/github-script")) yaml.WriteString(" with:\n") yaml.WriteString(" script: |\n") diff --git a/pkg/workflow/engine.go b/pkg/workflow/engine.go index 2473dd47dbc..4ffc0693a64 100644 --- a/pkg/workflow/engine.go +++ b/pkg/workflow/engine.go @@ -11,6 +11,27 @@ import ( var engineLog = logger.New("workflow:engine") +// EngineTokenClassWeights holds per-token-class weights for effective token computation. +// Each field corresponds to one token class; a zero value means "use default". +type EngineTokenClassWeights struct { + Input float64 `json:"input,omitempty"` + CachedInput float64 `json:"cached-input,omitempty"` + Output float64 `json:"output,omitempty"` + Reasoning float64 `json:"reasoning,omitempty"` + CacheWrite float64 `json:"cache-write,omitempty"` +} + +// EngineTokenWeights defines custom model cost information for effective token computation. +// It mirrors the structure of model_multipliers.json and allows per-workflow overrides. +// Specified under engine.token-weights in the workflow frontmatter. +type EngineTokenWeights struct { + // Multipliers maps model names to cost multipliers relative to the reference model. + // Keys are matched case-insensitively with prefix matching as a fallback. + Multipliers map[string]float64 `json:"multipliers,omitempty"` + // TokenClassWeights overrides the per-token-class weights used before the model multiplier. 
+ TokenClassWeights *EngineTokenClassWeights `json:"token-class-weights,omitempty"` +} + // EngineConfig represents the parsed engine configuration type EngineConfig struct { ID string @@ -26,6 +47,9 @@ type EngineConfig struct { Args []string Agent string // Agent identifier for copilot --agent flag (copilot engine only) APITarget string // Custom API endpoint hostname (e.g., "api.acme.ghe.com" or "api.enterprise.githubcopilot.com") + // TokenWeights provides custom model cost data for effective token computation. + // When set, overrides or extends the built-in model_multipliers.json values. + TokenWeights *EngineTokenWeights // Inline definition fields (populated when engine.runtime is specified in frontmatter) IsInlineDefinition bool // true when the engine is defined inline via engine.runtime + optional engine.provider @@ -284,6 +308,14 @@ func (c *Compiler) ExtractEngineConfig(frontmatter map[string]any) (string, *Eng } } + // Extract optional 'token-weights' field (custom model cost data) + if tokenWeightsRaw, hasTokenWeights := engineObj["token-weights"]; hasTokenWeights { + if tw := parseEngineTokenWeights(tokenWeightsRaw); tw != nil { + config.TokenWeights = tw + engineLog.Printf("Extracted token-weights: %d multipliers", len(tw.Multipliers)) + } + } + // Return the ID as the engineSetting for backwards compatibility engineLog.Printf("Extracted engine configuration: ID=%s", config.ID) return config.ID, config @@ -398,3 +430,66 @@ func parseRequestShape(requestObj map[string]any) *RequestShape { } return shape } + +// parseEngineTokenWeights converts a raw token-weights config value (from engine.token-weights) +// into an EngineTokenWeights. Returns nil when the input is not a usable map. 
+func parseEngineTokenWeights(raw any) *EngineTokenWeights { + obj, ok := raw.(map[string]any) + if !ok { + return nil + } + + tw := &EngineTokenWeights{} + + // Parse multipliers: map of model name → float64 + if multipliersRaw, ok := obj["multipliers"]; ok { + if multipliersMap, ok := multipliersRaw.(map[string]any); ok && len(multipliersMap) > 0 { + tw.Multipliers = make(map[string]float64, len(multipliersMap)) + for model, val := range multipliersMap { + switch v := val.(type) { + case float64: + tw.Multipliers[model] = v + case int: + tw.Multipliers[model] = float64(v) + case uint64: + tw.Multipliers[model] = float64(v) + } + } + } + } + + // Parse token-class-weights + if tcwRaw, ok := obj["token-class-weights"]; ok { + if tcwMap, ok := tcwRaw.(map[string]any); ok { + tcw := &EngineTokenClassWeights{} + setFloat := func(dst *float64, key string) { + if v, ok := tcwMap[key]; ok { + switch f := v.(type) { + case float64: + *dst = f + case int: + *dst = float64(f) + case uint64: + *dst = float64(f) + } + } + } + setFloat(&tcw.Input, "input") + setFloat(&tcw.CachedInput, "cached-input") + setFloat(&tcw.Output, "output") + setFloat(&tcw.Reasoning, "reasoning") + setFloat(&tcw.CacheWrite, "cache-write") + // Only assign if at least one weight was set + if tcw.Input != 0 || tcw.CachedInput != 0 || tcw.Output != 0 || + tcw.Reasoning != 0 || tcw.CacheWrite != 0 { + tw.TokenClassWeights = tcw + } + } + } + + // Return nil when nothing useful was parsed + if len(tw.Multipliers) == 0 && tw.TokenClassWeights == nil { + return nil + } + return tw +} diff --git a/pkg/workflow/engine_test.go b/pkg/workflow/engine_test.go index d8f65c5c837..ce4771f2370 100644 --- a/pkg/workflow/engine_test.go +++ b/pkg/workflow/engine_test.go @@ -326,3 +326,164 @@ func TestAPITargetExtraction(t *testing.T) { }) } } + +func TestParseEngineTokenWeights(t *testing.T) { + tests := []struct { + name string + raw any + wantNil bool + wantMultipliers map[string]float64 + wantClassWeights 
*EngineTokenClassWeights + }{ + { + name: "nil input returns nil", + raw: nil, + wantNil: true, + }, + { + name: "non-map input returns nil", + raw: "not-a-map", + wantNil: true, + }, + { + name: "empty map returns nil", + raw: map[string]any{}, + wantNil: true, + }, + { + name: "multipliers only", + raw: map[string]any{ + "multipliers": map[string]any{ + "my-model": float64(2.5), + "gpt-5": float64(3.0), + }, + }, + wantMultipliers: map[string]float64{ + "my-model": 2.5, + "gpt-5": 3.0, + }, + }, + { + name: "token-class-weights only", + raw: map[string]any{ + "token-class-weights": map[string]any{ + "output": float64(6.0), + }, + }, + wantClassWeights: &EngineTokenClassWeights{ + Output: 6.0, + }, + }, + { + name: "both multipliers and token-class-weights", + raw: map[string]any{ + "multipliers": map[string]any{ + "custom-model": float64(1.5), + }, + "token-class-weights": map[string]any{ + "input": float64(1.0), + "cached-input": float64(0.05), + "output": float64(5.0), + "reasoning": float64(5.0), + "cache-write": float64(1.0), + }, + }, + wantMultipliers: map[string]float64{"custom-model": 1.5}, + wantClassWeights: &EngineTokenClassWeights{ + Input: 1.0, + CachedInput: 0.05, + Output: 5.0, + Reasoning: 5.0, + CacheWrite: 1.0, + }, + }, + { + name: "integer multiplier values are accepted", + raw: map[string]any{ + "multipliers": map[string]any{ + "int-model": int(2), + }, + }, + wantMultipliers: map[string]float64{"int-model": 2.0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseEngineTokenWeights(tt.raw) + if tt.wantNil { + if got != nil { + t.Errorf("expected nil, got %+v", got) + } + return + } + if got == nil { + t.Fatal("expected non-nil result") + } + if tt.wantMultipliers != nil { + for model, want := range tt.wantMultipliers { + if got.Multipliers[model] != want { + t.Errorf("multiplier[%q] = %v, want %v", model, got.Multipliers[model], want) + } + } + } + if tt.wantClassWeights != nil { + if 
got.TokenClassWeights == nil { + t.Fatal("expected TokenClassWeights to be set") + } + want := tt.wantClassWeights + tcw := got.TokenClassWeights + if want.Input != 0 && tcw.Input != want.Input { + t.Errorf("Input weight = %v, want %v", tcw.Input, want.Input) + } + if want.CachedInput != 0 && tcw.CachedInput != want.CachedInput { + t.Errorf("CachedInput weight = %v, want %v", tcw.CachedInput, want.CachedInput) + } + if want.Output != 0 && tcw.Output != want.Output { + t.Errorf("Output weight = %v, want %v", tcw.Output, want.Output) + } + if want.Reasoning != 0 && tcw.Reasoning != want.Reasoning { + t.Errorf("Reasoning weight = %v, want %v", tcw.Reasoning, want.Reasoning) + } + if want.CacheWrite != 0 && tcw.CacheWrite != want.CacheWrite { + t.Errorf("CacheWrite weight = %v, want %v", tcw.CacheWrite, want.CacheWrite) + } + } + }) + } +} + +func TestExtractEngineConfigTokenWeights(t *testing.T) { + compiler := NewCompiler() + + frontmatter := map[string]any{ + "engine": map[string]any{ + "id": "claude", + "token-weights": map[string]any{ + "multipliers": map[string]any{ + "my-custom-model": float64(2.5), + }, + "token-class-weights": map[string]any{ + "output": float64(6.0), + }, + }, + }, + } + + _, config := compiler.ExtractEngineConfig(frontmatter) + if config == nil { + t.Fatal("Expected non-nil config") + } + if config.TokenWeights == nil { + t.Fatal("Expected TokenWeights to be set") + } + if config.TokenWeights.Multipliers["my-custom-model"] != 2.5 { + t.Errorf("Expected multiplier 2.5, got %v", config.TokenWeights.Multipliers["my-custom-model"]) + } + if config.TokenWeights.TokenClassWeights == nil { + t.Fatal("Expected TokenClassWeights to be set") + } + if config.TokenWeights.TokenClassWeights.Output != 6.0 { + t.Errorf("Expected output weight 6.0, got %v", config.TokenWeights.TokenClassWeights.Output) + } +} From 8ec2a1e78d3be2ec29f3d6823b0b3e3914ef7d9b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> 
Date: Thu, 2 Apr 2026 14:46:04 +0000 Subject: [PATCH 2/4] docs: address code review feedback - improve documentation for token weights Agent-Logs-Url: https://github.com/github/gh-aw/sessions/e8a47829-09aa-429e-b0a5-7b1d0d5c6e7a Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- actions/setup/js/generate_aw_info.cjs | 5 ++++- pkg/cli/effective_tokens.go | 7 ++++++- pkg/workflow/engine.go | 5 ++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/actions/setup/js/generate_aw_info.cjs b/actions/setup/js/generate_aw_info.cjs index 8f3891fde5e..f943ede658f 100644 --- a/actions/setup/js/generate_aw_info.cjs +++ b/actions/setup/js/generate_aw_info.cjs @@ -86,7 +86,10 @@ async function main(core, ctx) { awInfo.cli_version = cliVersion; } - // Include custom token weights when set (engine.token-weights in workflow frontmatter) + // Include custom token weights when set (engine.token-weights in workflow frontmatter). + // Deep structure validation is intentionally minimal here: the JSON schema and Go parser + // already validate the structure at compile time. We only verify the top-level type to + // guard against unexpected env-var values at runtime. const tokenWeightsEnv = process.env.GH_AW_INFO_TOKEN_WEIGHTS; if (tokenWeightsEnv) { try { diff --git a/pkg/cli/effective_tokens.go b/pkg/cli/effective_tokens.go index 64029f3f158..87f55e87714 100644 --- a/pkg/cli/effective_tokens.go +++ b/pkg/cli/effective_tokens.go @@ -53,6 +53,9 @@ type tokenClassWeights struct { // customTokenClassWeights holds per-token-class weight overrides from aw_info.json. // The JSON keys use hyphens to match the frontmatter schema (engine.token-weights.token-class-weights). +// A zero value for any field means "not set — use the built-in default" rather than +// "weight is zero". This matches the convention in model_multipliers.json and means users +// cannot explicitly request a zero weight through this mechanism. 
type customTokenClassWeights struct { Input float64 `json:"input"` CachedInput float64 `json:"cached-input"` @@ -66,9 +69,11 @@ type customTokenClassWeights struct { // compile time from the engine.token-weights frontmatter configuration. type CustomTokenWeights struct { // Multipliers maps model names to cost multipliers (relative to reference model). - // Keys are matched case-insensitively with prefix matching as fallback. + // Keys are normalized to lowercase before being stored; matching is case-insensitive + // with longest-prefix fallback (e.g. "my-model-v2" matches "my-model" if present). Multipliers map[string]float64 `json:"multipliers,omitempty"` // TokenClassWeights overrides any or all per-token-class weights. + // A nil pointer means no overrides; individual zero fields mean "use default". TokenClassWeights *customTokenClassWeights `json:"token-class-weights,omitempty"` } diff --git a/pkg/workflow/engine.go b/pkg/workflow/engine.go index 4ffc0693a64..58fcaef68b3 100644 --- a/pkg/workflow/engine.go +++ b/pkg/workflow/engine.go @@ -432,7 +432,10 @@ func parseRequestShape(requestObj map[string]any) *RequestShape { } // parseEngineTokenWeights converts a raw token-weights config value (from engine.token-weights) -// into an EngineTokenWeights. Returns nil when the input is not a usable map. +// into an EngineTokenWeights. Returns nil when the input is not a usable map or contains +// no recognisable data. Multiplier values of unexpected numeric types (anything other than +// float64, int, or uint64) are silently ignored — this matches the behaviour of the YAML +// parser which produces float64 for JSON-number literals and integers for integer literals. 
func parseEngineTokenWeights(raw any) *EngineTokenWeights { obj, ok := raw.(map[string]any) if !ok { From d9271078ede2d7b3c59c4d7d78be250affca55e4 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:11:56 +0000 Subject: [PATCH 3/4] fix: escape single quotes in YAML token weights env var; use exclusiveMinimum in schema - Escape single quotes in GH_AW_INFO_TOKEN_WEIGHTS value using YAML single-quoted scalar convention ('' for ') to prevent YAML parse failures with model keys containing single quotes (e.g. "bob's-model") - Change token-class-weights schema from minimum:0 to exclusiveMinimum:0 so that zero weight values are rejected by schema validation, consistent with the implementation treating zero as "use default" - Add test TestTokenWeightsSingleQuoteEscapingInYAML to verify the escaping Agent-Logs-Url: https://github.com/github/gh-aw/sessions/bc515783-422c-49bf-af85-8b62c099cb73 Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- pkg/parser/schemas/main_workflow_schema.json | 10 +++--- pkg/workflow/compiler_yaml.go | 4 ++- pkg/workflow/engine_test.go | 37 ++++++++++++++++++++ 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/pkg/parser/schemas/main_workflow_schema.json b/pkg/parser/schemas/main_workflow_schema.json index 885aa5f3efa..474e46cb63c 100644 --- a/pkg/parser/schemas/main_workflow_schema.json +++ b/pkg/parser/schemas/main_workflow_schema.json @@ -9156,27 +9156,27 @@ "properties": { "input": { "type": "number", - "minimum": 0, + "exclusiveMinimum": 0, "description": "Weight for input tokens (default: 1.0)" }, "cached-input": { "type": "number", - "minimum": 0, + "exclusiveMinimum": 0, "description": "Weight for cached input tokens (default: 0.1)" }, "output": { "type": "number", - "minimum": 0, + "exclusiveMinimum": 0, "description": "Weight for output tokens (default: 4.0)" }, "reasoning": { "type": "number", - "minimum": 0, + "exclusiveMinimum": 
0, "description": "Weight for reasoning tokens (default: 4.0)" }, "cache-write": { "type": "number", - "minimum": 0, + "exclusiveMinimum": 0, "description": "Weight for cache write tokens (default: 1.0)" } }, diff --git a/pkg/workflow/compiler_yaml.go b/pkg/workflow/compiler_yaml.go index cc30612ae18..51232690f9d 100644 --- a/pkg/workflow/compiler_yaml.go +++ b/pkg/workflow/compiler_yaml.go @@ -705,7 +705,9 @@ func (c *Compiler) generateCreateAwInfo(yaml *strings.Builder, data *WorkflowDat // Embed custom token weights when specified in engine.token-weights if data.EngineConfig != nil && data.EngineConfig.TokenWeights != nil { if tokenWeightsJSON, err := json.Marshal(data.EngineConfig.TokenWeights); err == nil { - fmt.Fprintf(yaml, " GH_AW_INFO_TOKEN_WEIGHTS: '%s'\n", string(tokenWeightsJSON)) + // Escape single quotes for YAML single-quoted scalar safety + escapedTokenWeightsJSON := strings.ReplaceAll(string(tokenWeightsJSON), "'", "''") + fmt.Fprintf(yaml, " GH_AW_INFO_TOKEN_WEIGHTS: '%s'\n", escapedTokenWeightsJSON) } } fmt.Fprintf(yaml, " uses: %s\n", GetActionPin("actions/github-script")) diff --git a/pkg/workflow/engine_test.go b/pkg/workflow/engine_test.go index ce4771f2370..a44491822ff 100644 --- a/pkg/workflow/engine_test.go +++ b/pkg/workflow/engine_test.go @@ -3,6 +3,7 @@ package workflow import ( + "strings" "testing" ) @@ -487,3 +488,39 @@ func TestExtractEngineConfigTokenWeights(t *testing.T) { t.Errorf("Expected output weight 6.0, got %v", config.TokenWeights.TokenClassWeights.Output) } } + +func TestTokenWeightsSingleQuoteEscapingInYAML(t *testing.T) { + compiler := NewCompiler() + registry := GetGlobalEngineRegistry() + engine, err := registry.GetEngine("claude") + if err != nil { + t.Fatalf("Failed to get claude engine: %v", err) + } + + // Model name containing a single quote — must not break YAML single-quoted scalar + workflowData := &WorkflowData{ + Name: "Test Workflow", + EngineConfig: &EngineConfig{ + ID: "claude", + TokenWeights: 
&EngineTokenWeights{ + Multipliers: map[string]float64{ + "bob's-model": 2.0, // Single quote in key + }, + }, + }, + } + + var out strings.Builder + compiler.generateCreateAwInfo(&out, workflowData, engine) + output := out.String() + + // The generated YAML must not contain an un-escaped single quote inside a single-quoted value. + // In YAML, a single quote inside a single-quoted scalar is represented as two consecutive single quotes (''). + if !strings.Contains(output, "bob''s-model") { + t.Errorf("Expected single quote to be escaped as '' in YAML output, got:\n%s", output) + } + // There must be no dangling unescaped single quote inside the GH_AW_INFO_TOKEN_WEIGHTS value + if strings.Contains(output, "bob's-model") { + t.Errorf("Unescaped single quote found in YAML output:\n%s", output) + } +} From 56a2b4ccfd64f40a001cba59c8626d7fd667909a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 2 Apr 2026 15:33:16 +0000 Subject: [PATCH 4/4] refactor: move TokenWeights/TokenClassWeights to pkg/types for reuse across packages - Create pkg/types/token_weights.go with shared TokenWeights and TokenClassWeights types - Remove duplicate EngineTokenWeights/EngineTokenClassWeights from pkg/workflow/engine.go - Remove duplicate CustomTokenWeights/customTokenClassWeights from pkg/cli/effective_tokens.go - Update pkg/workflow/engine.go, compiler_yaml.go, engine_test.go to use types.TokenWeights - Update pkg/cli/effective_tokens.go, logs_models.go, token_usage.go, effective_tokens_test.go to use types.TokenWeights Follows the same pattern as pkg/types.BaseMCPServerConfig which is shared between pkg/parser and pkg/workflow.
Agent-Logs-Url: https://github.com/github/gh-aw/sessions/615a693d-9408-4a21-9496-66073a045646 Co-authored-by: pelikhan <4175913+pelikhan@users.noreply.github.com> --- pkg/cli/effective_tokens.go | 31 +++---------------------------- pkg/cli/effective_tokens_test.go | 9 +++++---- pkg/cli/logs_models.go | 3 ++- pkg/cli/token_usage.go | 5 +++-- pkg/types/token_weights.go | 26 ++++++++++++++++++++++++++ pkg/workflow/engine.go | 32 ++++++-------------------------- pkg/workflow/engine_test.go | 10 ++++++---- 7 files changed, 51 insertions(+), 65 deletions(-) create mode 100644 pkg/types/token_weights.go diff --git a/pkg/cli/effective_tokens.go b/pkg/cli/effective_tokens.go index 87f55e87714..51e27fb5a71 100644 --- a/pkg/cli/effective_tokens.go +++ b/pkg/cli/effective_tokens.go @@ -35,6 +35,7 @@ import ( "strings" "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/types" ) var effectiveTokensLog = logger.New("cli:effective_tokens") @@ -51,32 +52,6 @@ type tokenClassWeights struct { CacheWrite float64 `json:"cache_write"` } -// customTokenClassWeights holds per-token-class weight overrides from aw_info.json. -// The JSON keys use hyphens to match the frontmatter schema (engine.token-weights.token-class-weights). -// A zero value for any field means "not set — use the built-in default" rather than -// "weight is zero". This matches the convention in model_multipliers.json and means users -// cannot explicitly request a zero weight through this mechanism. -type customTokenClassWeights struct { - Input float64 `json:"input"` - CachedInput float64 `json:"cached-input"` - Output float64 `json:"output"` - Reasoning float64 `json:"reasoning"` - CacheWrite float64 `json:"cache-write"` -} - -// CustomTokenWeights provides per-workflow overrides for effective token computation. -// It is populated from the token_weights field in aw_info.json, which is written at -// compile time from the engine.token-weights frontmatter configuration. 
-type CustomTokenWeights struct { - // Multipliers maps model names to cost multipliers (relative to reference model). - // Keys are normalized to lowercase before being stored; matching is case-insensitive - // with longest-prefix fallback (e.g. "my-model-v2" matches "my-model" if present). - Multipliers map[string]float64 `json:"multipliers,omitempty"` - // TokenClassWeights overrides any or all per-token-class weights. - // A nil pointer means no overrides; individual zero fields mean "use default". - TokenClassWeights *customTokenClassWeights `json:"token-class-weights,omitempty"` -} - // modelMultipliersData is the top-level structure of model_multipliers.json. type modelMultipliersData struct { Version string `json:"version"` @@ -230,7 +205,7 @@ func populateEffectiveTokens(summary *TokenUsageSummary) { // merges custom into the built-in weights before computing effective tokens. // Custom weights take precedence over the defaults loaded from model_multipliers.json. // It is a no-op when summary is nil. -func populateEffectiveTokensWithCustomWeights(summary *TokenUsageSummary, custom *CustomTokenWeights) { +func populateEffectiveTokensWithCustomWeights(summary *TokenUsageSummary, custom *types.TokenWeights) { if summary == nil { return } @@ -256,7 +231,7 @@ func populateEffectiveTokensWithCustomWeights(summary *TokenUsageSummary, custom // resolveEffectiveWeights merges optional custom weights with the built-in defaults. // The returned multipliers map is a copy so callers may not modify loadedMultipliers. 
-func resolveEffectiveWeights(custom *CustomTokenWeights) (map[string]float64, tokenClassWeights) { +func resolveEffectiveWeights(custom *types.TokenWeights) (map[string]float64, tokenClassWeights) { initMultipliers() // Copy the base multipliers to avoid mutating the shared global diff --git a/pkg/cli/effective_tokens_test.go b/pkg/cli/effective_tokens_test.go index a13a3bd6a9d..b1cf61c2059 100644 --- a/pkg/cli/effective_tokens_test.go +++ b/pkg/cli/effective_tokens_test.go @@ -5,6 +5,7 @@ package cli import ( "testing" + "github.com/github/gh-aw/pkg/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -238,7 +239,7 @@ func TestResolveEffectiveWeightsNoCustom(t *testing.T) { func TestResolveEffectiveWeightsCustomMultipliers(t *testing.T) { loadedMultipliers = nil - custom := &CustomTokenWeights{ + custom := &types.TokenWeights{ Multipliers: map[string]float64{ "my-custom-model": 2.5, "claude-sonnet-4.5": 1.5, // override existing @@ -257,8 +258,8 @@ func TestResolveEffectiveWeightsCustomMultipliers(t *testing.T) { func TestResolveEffectiveWeightsCustomClassWeights(t *testing.T) { loadedMultipliers = nil - custom := &CustomTokenWeights{ - TokenClassWeights: &customTokenClassWeights{ + custom := &types.TokenWeights{ + TokenClassWeights: &types.TokenClassWeights{ Output: 6.0, CachedInput: 0.05, }, @@ -288,7 +289,7 @@ func TestPopulateEffectiveTokensWithCustomWeights(t *testing.T) { }, } - custom := &CustomTokenWeights{ + custom := &types.TokenWeights{ Multipliers: map[string]float64{ "my-custom-model": 3.0, }, diff --git a/pkg/cli/logs_models.go b/pkg/cli/logs_models.go index 898b4b55cda..1c7ab54ab34 100644 --- a/pkg/cli/logs_models.go +++ b/pkg/cli/logs_models.go @@ -5,6 +5,7 @@ import ( "time" "github.com/github/gh-aw/pkg/logger" + "github.com/github/gh-aw/pkg/types" "github.com/github/gh-aw/pkg/workflow" ) @@ -279,7 +280,7 @@ type AwInfo struct { Steps AwInfoSteps `json:"steps,omitzero"` // Steps metadata CreatedAt string 
`json:"created_at"` Context *AwContext `json:"context,omitempty"` // aw_context data passed via workflow_dispatch inputs - TokenWeights *CustomTokenWeights `json:"token_weights,omitempty"` // Custom model cost data (from engine.token-weights) + TokenWeights *types.TokenWeights `json:"token_weights,omitempty"` // Custom model cost data (from engine.token-weights) // Additional fields that might be present RunID any `json:"run_id,omitempty"` RunNumber any `json:"run_number,omitempty"` diff --git a/pkg/cli/token_usage.go b/pkg/cli/token_usage.go index 39780491c77..f02ff9e7c77 100644 --- a/pkg/cli/token_usage.go +++ b/pkg/cli/token_usage.go @@ -11,6 +11,7 @@ import ( "github.com/github/gh-aw/pkg/logger" "github.com/github/gh-aw/pkg/timeutil" + "github.com/github/gh-aw/pkg/types" ) var tokenUsageLog = logger.New("cli:token_usage") @@ -78,7 +79,7 @@ const tokenUsageJSONLPath = "api-proxy-logs/token-usage.jsonl" // parseTokenUsageFile parses a token-usage.jsonl file and returns the aggregated summary. // Custom weights, when non-nil, override the built-in model multipliers and token class // weights for effective token computation. -func parseTokenUsageFile(filePath string, customWeights *CustomTokenWeights) (*TokenUsageSummary, error) { +func parseTokenUsageFile(filePath string, customWeights *types.TokenWeights) (*TokenUsageSummary, error) { tokenUsageLog.Printf("Parsing token usage file: %s", filePath) file, err := os.Open(filePath) @@ -238,7 +239,7 @@ func analyzeTokenUsage(runDir string, verbose bool) (*TokenUsageSummary, error) // extractCustomTokenWeightsFromDir reads aw_info.json from a run directory and returns // any custom token weights embedded there at compile time. Returns nil when not found. 
-func extractCustomTokenWeightsFromDir(runDir string) *CustomTokenWeights { +func extractCustomTokenWeightsFromDir(runDir string) *types.TokenWeights { awInfoPath := findAwInfoPath(runDir) if awInfoPath == "" { return nil diff --git a/pkg/types/token_weights.go b/pkg/types/token_weights.go new file mode 100644 index 00000000000..94957443309 --- /dev/null +++ b/pkg/types/token_weights.go @@ -0,0 +1,26 @@ +package types + +// TokenClassWeights holds per-token-class weights for effective token computation. +// Each field corresponds to one token class; a zero value means "use default". +// The JSON keys use hyphens to match the frontmatter schema +// (engine.token-weights.token-class-weights). +type TokenClassWeights struct { + Input float64 `json:"input,omitempty"` + CachedInput float64 `json:"cached-input,omitempty"` + Output float64 `json:"output,omitempty"` + Reasoning float64 `json:"reasoning,omitempty"` + CacheWrite float64 `json:"cache-write,omitempty"` +} + +// TokenWeights defines custom model cost information for effective token computation. +// It mirrors the structure of model_multipliers.json and allows per-workflow overrides. +// Specified under engine.token-weights in the workflow frontmatter and stored in +// aw_info.json at runtime. +type TokenWeights struct { + // Multipliers maps model names to cost multipliers relative to the reference model. + // Keys are matched case-insensitively with prefix matching as a fallback. + Multipliers map[string]float64 `json:"multipliers,omitempty"` + // TokenClassWeights overrides the per-token-class weights used before the model multiplier. + // A nil pointer means no overrides; individual zero fields mean "use default". 
+ TokenClassWeights *TokenClassWeights `json:"token-class-weights,omitempty"` +} diff --git a/pkg/workflow/engine.go b/pkg/workflow/engine.go index 58fcaef68b3..3e8f024e757 100644 --- a/pkg/workflow/engine.go +++ b/pkg/workflow/engine.go @@ -7,31 +7,11 @@ import ( "github.com/github/gh-aw/pkg/logger" "github.com/github/gh-aw/pkg/stringutil" + "github.com/github/gh-aw/pkg/types" ) var engineLog = logger.New("workflow:engine") -// EngineTokenClassWeights holds per-token-class weights for effective token computation. -// Each field corresponds to one token class; a zero value means "use default". -type EngineTokenClassWeights struct { - Input float64 `json:"input,omitempty"` - CachedInput float64 `json:"cached-input,omitempty"` - Output float64 `json:"output,omitempty"` - Reasoning float64 `json:"reasoning,omitempty"` - CacheWrite float64 `json:"cache-write,omitempty"` -} - -// EngineTokenWeights defines custom model cost information for effective token computation. -// It mirrors the structure of model_multipliers.json and allows per-workflow overrides. -// Specified under engine.token-weights in the workflow frontmatter. -type EngineTokenWeights struct { - // Multipliers maps model names to cost multipliers relative to the reference model. - // Keys are matched case-insensitively with prefix matching as a fallback. - Multipliers map[string]float64 `json:"multipliers,omitempty"` - // TokenClassWeights overrides the per-token-class weights used before the model multiplier. - TokenClassWeights *EngineTokenClassWeights `json:"token-class-weights,omitempty"` -} - // EngineConfig represents the parsed engine configuration type EngineConfig struct { ID string @@ -49,7 +29,7 @@ type EngineConfig struct { APITarget string // Custom API endpoint hostname (e.g., "api.acme.ghe.com" or "api.enterprise.githubcopilot.com") // TokenWeights provides custom model cost data for effective token computation. // When set, overrides or extends the built-in model_multipliers.json values. 
- TokenWeights *EngineTokenWeights + TokenWeights *types.TokenWeights // Inline definition fields (populated when engine.runtime is specified in frontmatter) IsInlineDefinition bool // true when the engine is defined inline via engine.runtime + optional engine.provider @@ -432,17 +412,17 @@ func parseRequestShape(requestObj map[string]any) *RequestShape { } // parseEngineTokenWeights converts a raw token-weights config value (from engine.token-weights) -// into an EngineTokenWeights. Returns nil when the input is not a usable map or contains +// into a types.TokenWeights. Returns nil when the input is not a usable map or contains // no recognisable data. Multiplier values of unexpected numeric types (anything other than // float64, int, or uint64) are silently ignored — this matches the behaviour of the YAML // parser which produces float64 for JSON-number literals and integers for integer literals. -func parseEngineTokenWeights(raw any) *EngineTokenWeights { +func parseEngineTokenWeights(raw any) *types.TokenWeights { obj, ok := raw.(map[string]any) if !ok { return nil } - tw := &EngineTokenWeights{} + tw := &types.TokenWeights{} // Parse multipliers: map of model name → float64 if multipliersRaw, ok := obj["multipliers"]; ok { @@ -464,7 +444,7 @@ func parseEngineTokenWeights(raw any) *EngineTokenWeights { // Parse token-class-weights if tcwRaw, ok := obj["token-class-weights"]; ok { if tcwMap, ok := tcwRaw.(map[string]any); ok { - tcw := &EngineTokenClassWeights{} + tcw := &types.TokenClassWeights{} setFloat := func(dst *float64, key string) { if v, ok := tcwMap[key]; ok { switch f := v.(type) { diff --git a/pkg/workflow/engine_test.go b/pkg/workflow/engine_test.go index a44491822ff..2a4d401538a 100644 --- a/pkg/workflow/engine_test.go +++ b/pkg/workflow/engine_test.go @@ -5,6 +5,8 @@ package workflow import ( "strings" "testing" + + "github.com/github/gh-aw/pkg/types" ) // TestEngineVersionTypeHandling tests that engine.version correctly handles @@ -334,7 +336,7 
@@ func TestParseEngineTokenWeights(t *testing.T) { raw any wantNil bool wantMultipliers map[string]float64 - wantClassWeights *EngineTokenClassWeights + wantClassWeights *types.TokenClassWeights }{ { name: "nil input returns nil", @@ -371,7 +373,7 @@ func TestParseEngineTokenWeights(t *testing.T) { "output": float64(6.0), }, }, - wantClassWeights: &EngineTokenClassWeights{ + wantClassWeights: &types.TokenClassWeights{ Output: 6.0, }, }, @@ -390,7 +392,7 @@ func TestParseEngineTokenWeights(t *testing.T) { }, }, wantMultipliers: map[string]float64{"custom-model": 1.5}, - wantClassWeights: &EngineTokenClassWeights{ + wantClassWeights: &types.TokenClassWeights{ Input: 1.0, CachedInput: 0.05, Output: 5.0, @@ -502,7 +504,7 @@ func TestTokenWeightsSingleQuoteEscapingInYAML(t *testing.T) { Name: "Test Workflow", EngineConfig: &EngineConfig{ ID: "claude", - TokenWeights: &EngineTokenWeights{ + TokenWeights: &types.TokenWeights{ Multipliers: map[string]float64{ "bob's-model": 2.0, // Single quote in key },