Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions actions/setup/js/generate_aw_info.cjs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,24 @@ async function main(core, ctx) {
awInfo.cli_version = cliVersion;
}

// Include custom token weights when set (engine.token-weights in workflow frontmatter).
// Deep structure validation is intentionally minimal here: the JSON schema and Go parser
// already validate the structure at compile time. We only verify the top-level type to
// guard against unexpected env-var values at runtime.
const tokenWeightsEnv = process.env.GH_AW_INFO_TOKEN_WEIGHTS;
if (tokenWeightsEnv) {
try {
const tokenWeights = JSON.parse(tokenWeightsEnv);
if (tokenWeights !== null && typeof tokenWeights === "object" && !Array.isArray(tokenWeights)) {
awInfo.token_weights = tokenWeights;
} else {
core.warning(`GH_AW_INFO_TOKEN_WEIGHTS must be a JSON object, ignoring`);
}
} catch {
core.warning(`Failed to parse GH_AW_INFO_TOKEN_WEIGHTS: ${tokenWeightsEnv}`);
}
}

// Include aw_context when the workflow was triggered via workflow_dispatch with
// the aw_context input set by a calling agentic workflow's dispatch_workflow handler.
// Validates JSON format and structure before populating the context key in aw_info.json.
Expand Down
95 changes: 87 additions & 8 deletions pkg/cli/effective_tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ package cli
import (
_ "embed"
"encoding/json"
"maps"
"math"
"strings"

"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/types"
)

var effectiveTokensLog = logger.New("cli:effective_tokens")
Expand Down Expand Up @@ -196,28 +198,105 @@ func computeModelEffectiveTokens(model string, inputTokens, outputTokens, cacheR
// entry and computes the TotalEffectiveTokens aggregate on the summary.
// It is a no-op when summary is nil.
func populateEffectiveTokens(summary *TokenUsageSummary) {
// Delegate to the custom-weights variant with nil overrides, so only the
// built-in weights are used.
populateEffectiveTokensWithCustomWeights(summary, nil)
}

// populateEffectiveTokensWithCustomWeights is like populateEffectiveTokens but
// merges custom into the built-in weights before computing effective tokens.
// Custom weights take precedence over the defaults loaded from model_multipliers.json.
//
// For every non-nil entry in summary.ByModel it stores the computed value in
// EffectiveTokens, and it stores the sum of those values in
// summary.TotalEffectiveTokens. It is a no-op when summary is nil.
func populateEffectiveTokensWithCustomWeights(summary *TokenUsageSummary, custom *types.TokenWeights) {
	if summary == nil {
		return
	}

	// Resolve the merged tables once; they are identical for every model.
	multipliers, classWeights := resolveEffectiveWeights(custom)

	total := 0
	for model, usage := range summary.ByModel {
		if usage == nil {
			continue
		}
		// Use the weight-aware variant only, so the resolved custom multipliers
		// and class weights actually apply.
		eff := computeModelEffectiveTokensWithWeights(model, usage.InputTokens, usage.OutputTokens,
			usage.CacheReadTokens, usage.CacheWriteTokens, multipliers, classWeights)
		usage.EffectiveTokens = eff
		total += eff
	}
	summary.TotalEffectiveTokens = total

	if effectiveTokensLog.Enabled() {
		effectiveTokensLog.Printf("Effective tokens: total=%d models=%d custom=%v", total, len(summary.ByModel), custom != nil)
	}
}

// resolveEffectiveWeights returns the per-model multiplier table and the token
// class weights to use, layering the optional custom weights over the built-in
// defaults.
//
// The returned map is a fresh copy of loadedMultipliers, so callers can change
// it without affecting the shared table. Custom multiplier keys are trimmed and
// lowercased before being stored. A custom token class weight replaces the
// default only when it is non-zero.
func resolveEffectiveWeights(custom *types.TokenWeights) (map[string]float64, tokenClassWeights) {
	initMultipliers()

	// Start from a private copy of the defaults so the shared global is never mutated.
	table := make(map[string]float64, len(loadedMultipliers))
	maps.Copy(table, loadedMultipliers)
	weights := loadedTokenWeights

	if custom == nil {
		return table, weights
	}

	// Per-model multipliers: custom entries override or extend the defaults.
	for name, factor := range custom.Multipliers {
		normalized := strings.ToLower(strings.TrimSpace(name))
		table[normalized] = factor
	}

	overrides := custom.TokenClassWeights
	if overrides == nil {
		return table, weights
	}

	// Per-token-class weights: a zero value is treated as "not provided".
	fields := []struct {
		dst *float64
		src float64
	}{
		{&weights.Input, overrides.Input},
		{&weights.CachedInput, overrides.CachedInput},
		{&weights.Output, overrides.Output},
		{&weights.Reasoning, overrides.Reasoning},
		{&weights.CacheWrite, overrides.CacheWrite},
	}
	for _, f := range fields {
		if f.src != 0 {
			*f.dst = f.src
		}
	}

	return table, weights
}

// computeModelEffectiveTokensWithWeights returns the effective token count for a
// single model using the supplied multiplier table and token class weights
// instead of the global defaults.
//
// The weighted base is
//
//	w.Input*inputTokens + w.CachedInput*cacheReadTokens +
//	w.Output*outputTokens + w.CacheWrite*cacheWriteTokens
//
// and a zero base yields 0. Otherwise the base is scaled by the model's
// multiplier and rounded to the nearest integer. The model name is trimmed and
// lowercased before lookup: an exact key wins, otherwise the longest key that is
// a prefix of the name is used, and 1.0 applies when nothing matches or the name
// is empty.
func computeModelEffectiveTokensWithWeights(model string, inputTokens, outputTokens, cacheReadTokens, cacheWriteTokens int, multipliers map[string]float64, w tokenClassWeights) int {
	weighted := w.Input*float64(inputTokens) +
		w.CachedInput*float64(cacheReadTokens) +
		w.Output*float64(outputTokens) +
		w.CacheWrite*float64(cacheWriteTokens)
	if weighted == 0 {
		return 0
	}

	name := strings.ToLower(strings.TrimSpace(model))
	factor, exact := multipliers[name]
	switch {
	case name == "":
		// An empty model name never matches anything, even an empty table key.
		factor = 1.0
	case !exact:
		// No exact entry: fall back to the longest key that prefixes the name.
		factor = 1.0
		longest := 0
		for prefix, candidate := range multipliers {
			if len(prefix) > longest && strings.HasPrefix(name, prefix) {
				longest = len(prefix)
				factor = candidate
			}
		}
	}

	return int(math.Round(weighted * factor))
}
92 changes: 92 additions & 0 deletions pkg/cli/effective_tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cli
import (
"testing"

"github.com/github/gh-aw/pkg/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -223,3 +224,94 @@ func TestModelMultipliersJSONEmbedded(t *testing.T) {
require.NotNil(t, loadedMultipliers, "multipliers should be loaded from embedded JSON")
assert.NotEmpty(t, loadedMultipliers, "should have at least one multiplier entry")
}

// TestResolveEffectiveWeightsNoCustom checks that resolving with nil custom
// weights yields a non-empty multiplier table and the default class weights.
func TestResolveEffectiveWeightsNoCustom(t *testing.T) {
	// Clear the cached table first (presumably so initMultipliers reloads it — confirm).
	loadedMultipliers = nil

	table, weights := resolveEffectiveWeights(nil)

	assert.NotEmpty(t, table, "should have built-in multipliers")

	defaults := []struct {
		want, got float64
		msg       string
	}{
		{1.0, weights.Input, "default input weight"},
		{0.1, weights.CachedInput, "default cached input weight"},
		{4.0, weights.Output, "default output weight"},
	}
	for _, d := range defaults {
		assert.InDelta(t, d.want, d.got, 1e-9, d.msg)
	}
}

// TestResolveEffectiveWeightsCustomMultipliers checks that custom per-model
// multipliers add to and override the built-in table while leaving other
// entries and the class weights untouched.
func TestResolveEffectiveWeightsCustomMultipliers(t *testing.T) {
	// Clear the cached table first (presumably so initMultipliers reloads it — confirm).
	loadedMultipliers = nil

	overrides := map[string]float64{
		"my-custom-model":   2.5,
		"claude-sonnet-4.5": 1.5, // override existing
	}
	table, weights := resolveEffectiveWeights(&types.TokenWeights{Multipliers: overrides})

	expectations := []struct {
		want, got float64
		msg       string
	}{
		{2.5, table["my-custom-model"], "custom model multiplier"},
		{1.5, table["claude-sonnet-4.5"], "overridden model multiplier"},
		// Built-in models not mentioned in custom should remain
		{0.1, table["claude-haiku-4.5"], "unmodified built-in multiplier"},
		// Class weights unchanged when not specified
		{4.0, weights.Output, "output weight unchanged"},
	}
	for _, e := range expectations {
		assert.InDelta(t, e.want, e.got, 1e-9, e.msg)
	}
}

// TestResolveEffectiveWeightsCustomClassWeights checks that non-zero custom
// token class weights replace the defaults and that unset fields keep theirs.
func TestResolveEffectiveWeightsCustomClassWeights(t *testing.T) {
	// Clear the cached table first (presumably so initMultipliers reloads it — confirm).
	loadedMultipliers = nil

	classOverrides := &types.TokenClassWeights{
		Output:      6.0,
		CachedInput: 0.05,
	}
	_, weights := resolveEffectiveWeights(&types.TokenWeights{TokenClassWeights: classOverrides})

	expectations := []struct {
		want, got float64
		msg       string
	}{
		{6.0, weights.Output, "custom output weight"},
		{0.05, weights.CachedInput, "custom cached input weight"},
		// Unset fields keep their defaults
		{1.0, weights.Input, "input weight unchanged"},
		{4.0, weights.Reasoning, "reasoning weight unchanged"},
	}
	for _, e := range expectations {
		assert.InDelta(t, e.want, e.got, 1e-9, e.msg)
	}
}

// TestPopulateEffectiveTokensWithCustomWeights checks per-model and total
// effective tokens when one model has a custom multiplier and the other relies
// on the built-in table.
func TestPopulateEffectiveTokensWithCustomWeights(t *testing.T) {
	// Clear the cached table first (presumably so initMultipliers reloads it — confirm).
	loadedMultipliers = nil

	const (
		customName = "my-custom-model"
		sonnetName = "claude-sonnet-4.5"
	)

	overrides := &types.TokenWeights{
		Multipliers: map[string]float64{customName: 3.0},
	}
	usage := &TokenUsageSummary{
		ByModel: map[string]*ModelTokenUsage{
			customName: {InputTokens: 1000, OutputTokens: 200},
			sonnetName: {InputTokens: 500, OutputTokens: 100},
		},
	}

	populateEffectiveTokensWithCustomWeights(usage, overrides)

	// Custom model: base = 1.0*1000 + 4.0*200 = 1800, scaled by 3.0 gives 5400.
	gotCustom := usage.ByModel[customName]
	require.NotNil(t, gotCustom, "custom model should be present")
	assert.Equal(t, 5400, gotCustom.EffectiveTokens, "custom model effective tokens at 3.0x")

	// Sonnet: base = 1.0*500 + 4.0*100 = 900, scaled by 1.0 gives 900.
	gotSonnet := usage.ByModel[sonnetName]
	require.NotNil(t, gotSonnet, "sonnet should be present")
	assert.Equal(t, 900, gotSonnet.EffectiveTokens, "sonnet effective tokens at 1x")

	assert.Equal(t, 6300, usage.TotalEffectiveTokens, "total = custom + sonnet")
}

func TestPopulateEffectiveTokensWithCustomWeightsNilSummary(t *testing.T) {
assert.NotPanics(t, func() {
populateEffectiveTokensWithCustomWeights(nil, nil)
})
}
26 changes: 14 additions & 12 deletions pkg/cli/logs_models.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/types"
"github.com/github/gh-aw/pkg/workflow"
)

Expand Down Expand Up @@ -267,18 +268,19 @@ type AwContext struct {

// AwInfo represents the structure of aw_info.json files
type AwInfo struct {
EngineID string `json:"engine_id"`
EngineName string `json:"engine_name"`
Model string `json:"model"`
Version string `json:"version"`
CLIVersion string `json:"cli_version,omitempty"` // gh-aw CLI version
WorkflowName string `json:"workflow_name"`
Staged bool `json:"staged"`
AwfVersion string `json:"awf_version,omitempty"` // AWF firewall version (new name)
FirewallVersion string `json:"firewall_version,omitempty"` // AWF firewall version (old name, for backward compatibility)
Steps AwInfoSteps `json:"steps,omitzero"` // Steps metadata
CreatedAt string `json:"created_at"`
Context *AwContext `json:"context,omitempty"` // aw_context data passed via workflow_dispatch inputs
EngineID string `json:"engine_id"`
EngineName string `json:"engine_name"`
Model string `json:"model"`
Version string `json:"version"`
CLIVersion string `json:"cli_version,omitempty"` // gh-aw CLI version
WorkflowName string `json:"workflow_name"`
Staged bool `json:"staged"`
AwfVersion string `json:"awf_version,omitempty"` // AWF firewall version (new name)
FirewallVersion string `json:"firewall_version,omitempty"` // AWF firewall version (old name, for backward compatibility)
Steps AwInfoSteps `json:"steps,omitzero"` // Steps metadata
CreatedAt string `json:"created_at"`
Context *AwContext `json:"context,omitempty"` // aw_context data passed via workflow_dispatch inputs
TokenWeights *types.TokenWeights `json:"token_weights,omitempty"` // Custom model cost data (from engine.token-weights)
// Additional fields that might be present
RunID any `json:"run_id,omitempty"`
RunNumber any `json:"run_number,omitempty"`
Expand Down
34 changes: 28 additions & 6 deletions pkg/cli/token_usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/timeutil"
"github.com/github/gh-aw/pkg/types"
)

var tokenUsageLog = logger.New("cli:token_usage")
Expand Down Expand Up @@ -75,8 +76,10 @@ type ModelTokenUsageRow struct {
// tokenUsageJSONLPath is the relative path within the firewall logs directory
const tokenUsageJSONLPath = "api-proxy-logs/token-usage.jsonl"

// parseTokenUsageFile parses a token-usage.jsonl file and returns the aggregated summary
func parseTokenUsageFile(filePath string) (*TokenUsageSummary, error) {
// parseTokenUsageFile parses a token-usage.jsonl file and returns the aggregated summary.
// Custom weights, when non-nil, override the built-in model multipliers and token class
// weights for effective token computation.
func parseTokenUsageFile(filePath string, customWeights *types.TokenWeights) (*TokenUsageSummary, error) {
tokenUsageLog.Printf("Parsing token usage file: %s", filePath)

file, err := os.Open(filePath)
Expand Down Expand Up @@ -155,8 +158,8 @@ func parseTokenUsageFile(filePath string) (*TokenUsageSummary, error) {
lineNum, summary.TotalInputTokens, summary.TotalOutputTokens,
summary.TotalCacheReadTokens, summary.TotalCacheWriteTokens, summary.TotalRequests)

// Compute effective tokens using per-model multipliers
populateEffectiveTokens(summary)
// Compute effective tokens using per-model multipliers (with optional custom overrides)
populateEffectiveTokensWithCustomWeights(summary, customWeights)

return summary, nil
}
Expand Down Expand Up @@ -210,7 +213,9 @@ func findTokenUsageFile(runDir string) string {
return ""
}

// analyzeTokenUsage finds and parses the token-usage.jsonl file from a run directory
// analyzeTokenUsage finds and parses the token-usage.jsonl file from a run directory.
// It automatically reads custom token weights from aw_info.json when present and
// applies them to the effective token computation.
func analyzeTokenUsage(runDir string, verbose bool) (*TokenUsageSummary, error) {
tokenUsageLog.Printf("Analyzing token usage in: %s", runDir)

Expand All @@ -226,7 +231,24 @@ func analyzeTokenUsage(runDir string, verbose bool) (*TokenUsageSummary, error)
}
}

return parseTokenUsageFile(filePath)
// Try to load custom token weights from aw_info.json for this run
customWeights := extractCustomTokenWeightsFromDir(runDir)

return parseTokenUsageFile(filePath, customWeights)
}

// extractCustomTokenWeightsFromDir looks up aw_info.json for the given run
// directory and returns the custom token weights recorded in it.
// This is best-effort: it returns nil when no aw_info path is found, when
// parsing fails, or when the parsed info carries no token weights.
func extractCustomTokenWeightsFromDir(runDir string) *types.TokenWeights {
	path := findAwInfoPath(runDir)
	if path == "" {
		return nil
	}
	if info, err := parseAwInfo(path, false); err == nil && info != nil {
		return info.TokenWeights
	}
	return nil
}

// TotalTokens returns the sum of all token types
Expand Down
Loading
Loading