Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions crates/forge_app/src/dto/anthropic/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,12 +657,14 @@ mod tests {
assert_eq!(delta_domain.completion_tokens, TokenCount::Actual(75));
assert_eq!(delta_domain.cached_tokens, TokenCount::Actual(0));

// Accumulate usage (simulating how we'd combine them in practice)
let accumulated = initial_domain.accumulate(&delta_domain);
assert_eq!(accumulated.prompt_tokens, TokenCount::Actual(150));
assert_eq!(accumulated.completion_tokens, TokenCount::Actual(75));
assert_eq!(accumulated.cached_tokens, TokenCount::Actual(50));
assert_eq!(accumulated.total_tokens, TokenCount::Actual(225));
// Merge usage (simulating how we'd combine them in practice)
// Using merge (max) instead of accumulate (sum) since Anthropic
// usage values are cumulative, not incremental deltas.
let merged = initial_domain.merge(&delta_domain);
assert_eq!(merged.prompt_tokens, TokenCount::Actual(150));
assert_eq!(merged.completion_tokens, TokenCount::Actual(75));
assert_eq!(merged.cached_tokens, TokenCount::Actual(50));
assert_eq!(merged.total_tokens, TokenCount::Actual(150)); // max(150, 75)
}

#[test]
Expand Down
15 changes: 15 additions & 0 deletions crates/forge_domain/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,21 @@ impl Default for TokenCount {
}
}

impl TokenCount {
/// Returns the larger of two TokenCount values by their inner count.
/// If both are `Actual`, the result is `Actual`. If either is `Approx`,
/// the result is `Approx`.
pub fn max(self, other: TokenCount) -> TokenCount {
use TokenCount::*;
match (self, other) {
(Actual(a), Actual(b)) => Actual(a.max(b)),
(Actual(a), Approx(b)) => Approx(a.max(b)),
(Approx(a), Actual(b)) => Approx(a.max(b)),
(Approx(a), Approx(b)) => Approx(a.max(b)),
}
}
Comment thread
amitksingh1490 marked this conversation as resolved.
}

impl Deref for TokenCount {
type Target = usize;

Expand Down
98 changes: 96 additions & 2 deletions crates/forge_domain/src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ pub struct Usage {
}

impl Usage {
/// Accumulates usage from another Usage instance
/// Cost is summed, tokens are added using TokenCount's Add implementation
/// Accumulates usage from another Usage instance by summing all fields.
///
/// Use this for aggregating usage across **independent** requests (e.g.,
/// session-level totals where each message has its own final usage).
pub fn accumulate(mut self, other: &Usage) -> Self {
self.prompt_tokens = self.prompt_tokens + other.prompt_tokens;
self.completion_tokens = self.completion_tokens + other.completion_tokens;
Expand All @@ -46,6 +48,34 @@ impl Usage {
};
self
}

/// Merges usage from another Usage instance using a "last non-zero wins"
/// strategy.
///
/// Use this when combining **partial** usage events within a single
/// streaming response where values are **cumulative** (not incremental):
/// - `message_start`: `input_tokens=1000, output_tokens=1`
/// - `message_delta`: `input_tokens=0, output_tokens=75` (cumulative
/// total)
///
/// For each field, the larger of the two values is kept. This prevents
/// double-counting when providers report cumulative token counts across
/// multiple events.
///
/// Cost is summed since cost events are always additive.
pub fn merge(mut self, other: &Usage) -> Self {
self.prompt_tokens = self.prompt_tokens.max(other.prompt_tokens);
self.completion_tokens = self.completion_tokens.max(other.completion_tokens);
self.total_tokens = self.total_tokens.max(other.total_tokens);
self.cached_tokens = self.cached_tokens.max(other.cached_tokens);
self.cost = match (self.cost, other.cost) {
(Some(a), Some(b)) => Some(a + b),
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
};
self
}
}

/// Represents a message that was received from the LLM provider
Expand Down Expand Up @@ -374,4 +404,68 @@ mod tests {
FinishReason::Stop
);
}

#[test]
fn test_usage_merge_anthropic_cumulative() {
    // Simulates Anthropic's message_start + message_delta pattern where
    // output_tokens in message_delta is the CUMULATIVE total, not a delta.
    let start = Usage {
        prompt_tokens: TokenCount::Actual(1000),
        completion_tokens: TokenCount::Actual(1), // initial output token
        total_tokens: TokenCount::Actual(1001),
        cached_tokens: TokenCount::Actual(300),
        cost: None,
    };
    let delta = Usage {
        prompt_tokens: TokenCount::Actual(0),
        completion_tokens: TokenCount::Actual(75), // cumulative total, NOT a delta
        total_tokens: TokenCount::Actual(75),
        cached_tokens: TokenCount::Actual(0),
        cost: None,
    };

    let merged = start.merge(&delta);

    // Every token field keeps the larger value; nothing is summed.
    assert_eq!(merged.prompt_tokens, TokenCount::Actual(1000)); // max(1000, 0)
    assert_eq!(merged.completion_tokens, TokenCount::Actual(75)); // max(1, 75), NOT 1+75=76
    assert_eq!(merged.total_tokens, TokenCount::Actual(1001)); // max(1001, 75)
    assert_eq!(merged.cached_tokens, TokenCount::Actual(300)); // max(300, 0)
    assert_eq!(merged.cost, None);
}

#[test]
fn test_usage_merge_preserves_costs() {
    // Two partial usage events, each carrying its own cost component.
    let first = Usage {
        prompt_tokens: TokenCount::Actual(100),
        completion_tokens: TokenCount::Actual(0),
        total_tokens: TokenCount::Actual(100),
        cached_tokens: TokenCount::Actual(0),
        cost: Some(0.01),
    };
    let second = Usage {
        prompt_tokens: TokenCount::Actual(0),
        completion_tokens: TokenCount::Actual(50),
        total_tokens: TokenCount::Actual(50),
        cached_tokens: TokenCount::Actual(0),
        cost: Some(0.02),
    };

    let merged = first.merge(&second);

    assert_eq!(merged.prompt_tokens, TokenCount::Actual(100));
    assert_eq!(merged.completion_tokens, TokenCount::Actual(50));
    assert_eq!(merged.total_tokens, TokenCount::Actual(100));
    assert_eq!(merged.cached_tokens, TokenCount::Actual(0));
    // Costs are summed (additive), unlike token fields which take the max.
    assert_eq!(merged.cost, Some(0.03));
}
}
94 changes: 85 additions & 9 deletions crates/forge_domain/src/result_stream_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,17 @@ impl ResultStreamExt<anyhow::Error> for crate::BoxStream<ChatCompletionMessage,
anyhow::Ok(message?).with_context(|| "Failed to process message stream")?;
// Process usage information
// - For Anthropic-style streaming: input tokens in MessageStart, output tokens
// in MessageDelta
// in MessageDelta (values are CUMULATIVE, not incremental)
// ref: https://platform.claude.com/docs/en/build-with-claude/streaming#event-types
// - For OpenAI-style streaming: all tokens in the final chunk
// - For GLM-style: may send complete usage in every chunk (need to replace, not
// accumulate)
// - For Google-style: cumulative usage in every chunk
// - Cost-only events: have 0 tokens but a cost value
if let Some(current_usage) = message.usage.as_ref() {
// If current usage has both prompt and completion tokens, it's a "complete"
// usage In this case, replace instead of accumulate (handles
// GLM-style streaming)
// usage. In this case, replace instead of merge (handles GLM-style streaming
// where every chunk has full usage).
let is_complete_usage =
*current_usage.prompt_tokens > 0 && *current_usage.completion_tokens > 0;

Expand All @@ -95,10 +97,19 @@ impl ResultStreamExt<anyhow::Error> for crate::BoxStream<ChatCompletionMessage,
}
} else if is_cost_only {
// Accumulate only the cost to the existing usage
usage.cost = current_usage.cost;
usage.cost = match (usage.cost, current_usage.cost) {
(Some(a), Some(b)) => Some(a + b),
(Some(a), None) => Some(a),
(None, Some(b)) => Some(b),
(None, None) => None,
};
} else {
// Accumulate partial usage (for Anthropic-style streaming)
usage = usage.accumulate(current_usage);
// Merge partial usage using "max" strategy. This correctly handles
// providers like Anthropic where usage values are CUMULATIVE across
// events (message_start has input tokens, message_delta has the
// total output tokens). Using max instead of sum prevents
// double-counting when message_start includes output_tokens=1.
usage = usage.merge(current_usage);
}
}

Expand Down Expand Up @@ -485,8 +496,73 @@ mod tests {
}

#[tokio::test]
async fn test_into_full_anthropic_streaming_usage_accumulation() {
async fn test_into_full_anthropic_streaming_usage_merge() {
// Fixture: Simulate Anthropic streaming pattern where message_start has
// output_tokens=1 (the common case) and message_delta has the cumulative total.
// This tests that merge (max) is used instead of accumulate (sum) to prevent
// double-counting.
let messages = vec![
// MessageStart with input token usage AND output_tokens=1
Ok(ChatCompletionMessage::default().usage(Usage {
prompt_tokens: TokenCount::Actual(1000),
completion_tokens: TokenCount::Actual(1),
total_tokens: TokenCount::Actual(1001),
cached_tokens: TokenCount::Actual(300),
cost: None,
})),
// Content deltas
Ok(ChatCompletionMessage::default().content(Content::part("Hello "))),
Ok(ChatCompletionMessage::default().content(Content::part("world!"))),
// MessageDelta with cumulative output token usage
Ok(ChatCompletionMessage::default()
.usage(Usage {
prompt_tokens: TokenCount::Actual(0),
completion_tokens: TokenCount::Actual(50),
total_tokens: TokenCount::Actual(50),
cached_tokens: TokenCount::Actual(0),
cost: None,
})
.finish_reason(FinishReason::Stop)),
];

let result_stream: BoxStream<ChatCompletionMessage, anyhow::Error> =
Box::pin(tokio_stream::iter(messages));

// Actual: Convert stream to full message
let actual = result_stream.into_full(false).await.unwrap();

// Expected: Usage should use max (merge) not sum (accumulate).
// message_start has completion_tokens=1 and prompt_tokens=1000, so
// is_complete_usage=true -> replace: usage = {1000, 1, 1001, 300}
// message_delta has prompt=0, completion=50 -> is_complete_usage=false ->
// merge: prompt = max(1000, 0) = 1000
// completion = max(1, 50) = 50 (NOT 1+50=51)
// total = max(1001, 50) = 1001
// cached = max(300, 0) = 300
let expected = ChatCompletionMessageFull {
content: "Hello world!".to_string(),
tool_calls: vec![],
thought_signature: None,
usage: Usage {
prompt_tokens: TokenCount::Actual(1000),
completion_tokens: TokenCount::Actual(50), // max(1, 50) = 50, NOT 1+50=51
total_tokens: TokenCount::Actual(1001),
cached_tokens: TokenCount::Actual(300),
cost: None,
},
reasoning: None,
reasoning_details: None,
finish_reason: Some(FinishReason::Stop),
phase: None,
};

assert_eq!(actual, expected);
}

#[tokio::test]
async fn test_into_full_anthropic_streaming_usage_merge_zero_output() {
// Fixture: Simulate Anthropic/Vertex AI Anthropic streaming pattern
// where message_start has output_tokens=0 (Vertex AI pattern).
// MessageStart event has input tokens, MessageDelta has output tokens
let messages = vec![
// MessageStart with input token usage
Expand Down Expand Up @@ -518,15 +594,15 @@ mod tests {
// Actual: Convert stream to full message
let actual = result_stream.into_full(false).await.unwrap();

// Expected: Usage should be accumulated from both MessageStart and MessageDelta
// Expected: Usage should be merged from both MessageStart and MessageDelta
let expected = ChatCompletionMessageFull {
content: "Hello world!".to_string(),
tool_calls: vec![],
thought_signature: None,
usage: Usage {
prompt_tokens: TokenCount::Actual(1000), // From MessageStart
completion_tokens: TokenCount::Actual(50), // From MessageDelta
total_tokens: TokenCount::Actual(1050), // Sum of both
total_tokens: TokenCount::Actual(1000), // max(1000, 50) = 1000
cached_tokens: TokenCount::Actual(300), // From MessageStart
cost: None,
},
Expand Down
4 changes: 2 additions & 2 deletions crates/forge_repo/src/provider/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ impl IntoDomain for aws_sdk_bedrockruntime::types::ConverseStreamOutput {
.saturating_add(u.cache_write_input_tokens.unwrap_or(0));

forge_domain::Usage {
prompt_tokens: forge_domain::TokenCount::Actual(u.total_tokens as usize),
prompt_tokens: forge_domain::TokenCount::Actual(u.input_tokens as usize),
completion_tokens: forge_domain::TokenCount::Actual(
u.output_tokens as usize,
),
Expand Down Expand Up @@ -1418,7 +1418,7 @@ mod tests {
let actual = fixture.into_domain();
let expected =
ChatCompletionMessage::assistant(Content::part("")).usage(forge_domain::Usage {
prompt_tokens: TokenCount::Actual(1000),
prompt_tokens: TokenCount::Actual(800),
completion_tokens: TokenCount::Actual(200),
total_tokens: TokenCount::Actual(1000),
cached_tokens: TokenCount::Actual(80), // 50 + 30
Expand Down
4 changes: 4 additions & 0 deletions crates/forge_repo/src/provider/openai_responses/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ pub(super) enum StreamItem {
Message(Box<ChatCompletionMessage>),
}

/// Converts OpenAI Responses API usage into the domain Usage type.
/// Usage is sent once in the `response.completed` event (not split across
/// events).
/// ref: https://developers.openai.com/api/reference/resources/responses#(resource)%20responses%20%3E%20(model)%20response_usage%20%3E%20(schema)
impl IntoDomain for oai::ResponseUsage {
type Domain = Usage;

Expand Down
Loading