diff --git a/src/copilot_usage/models.py b/src/copilot_usage/models.py index 2a74f80..2e14ee8 100644 --- a/src/copilot_usage/models.py +++ b/src/copilot_usage/models.py @@ -70,6 +70,26 @@ class ModelMetrics(BaseModel): usage: TokenUsage = Field(default_factory=TokenUsage) +def merge_model_metrics( + base: dict[str, ModelMetrics], + additional: dict[str, ModelMetrics], +) -> dict[str, ModelMetrics]: + """Return a new dict merging *additional* into *base* without mutation.""" + result = {name: mm.model_copy(deep=True) for name, mm in base.items()} + for name, mm in additional.items(): + if name in result: + existing = result[name] + existing.requests.count += mm.requests.count + existing.requests.cost += mm.requests.cost + existing.usage.inputTokens += mm.usage.inputTokens + existing.usage.outputTokens += mm.usage.outputTokens + existing.usage.cacheReadTokens += mm.usage.cacheReadTokens + existing.usage.cacheWriteTokens += mm.usage.cacheWriteTokens + else: + result[name] = mm.model_copy(deep=True) + return result + + class CodeChanges(BaseModel): """Code‐change stats from a session.shutdown event.""" diff --git a/src/copilot_usage/parser.py b/src/copilot_usage/parser.py index cf855a5..7cd7072 100644 --- a/src/copilot_usage/parser.py +++ b/src/copilot_usage/parser.py @@ -17,13 +17,13 @@ CodeChanges, EventType, ModelMetrics, - RequestMetrics, SessionEvent, SessionShutdownData, SessionStartData, SessionSummary, TokenUsage, ToolExecutionData, + merge_model_metrics, ) _DEFAULT_BASE: Path = Path.home() / ".copilot" / "session-state" @@ -276,27 +276,7 @@ def build_session_summary( total_api_duration += sd.totalApiDurationMs if sd.codeChanges is not None: last_code_changes = sd.codeChanges - for model_name, metrics in sd.modelMetrics.items(): - if model_name in merged_metrics: - existing = merged_metrics[model_name] - merged_metrics[model_name] = ModelMetrics( - requests=RequestMetrics( - count=existing.requests.count + metrics.requests.count, - cost=existing.requests.cost + metrics.requests.cost, - ), - usage=TokenUsage( - inputTokens=existing.usage.inputTokens - + metrics.usage.inputTokens, - outputTokens=existing.usage.outputTokens - + metrics.usage.outputTokens, - cacheReadTokens=existing.usage.cacheReadTokens - + metrics.usage.cacheReadTokens, - cacheWriteTokens=existing.usage.cacheWriteTokens - + metrics.usage.cacheWriteTokens, - ), - ) - else: - merged_metrics[model_name] = metrics + merged_metrics = merge_model_metrics(merged_metrics, sd.modelMetrics) return SessionSummary( session_id=session_id, diff --git a/src/copilot_usage/report.py b/src/copilot_usage/report.py index 419686c..0d55787 100644 --- a/src/copilot_usage/report.py +++ b/src/copilot_usage/report.py @@ -23,6 +23,7 @@ SessionSummary, ToolExecutionData, UserMessageData, + merge_model_metrics, ) __all__ = [ @@ -547,20 +548,7 @@ def _aggregate_model_metrics( """Merge model metrics across all sessions into a single dict.""" merged: dict[str, ModelMetrics] = {} for s in sessions: - for model_name, mm in s.model_metrics.items(): - if model_name not in merged: - merged[model_name] = ModelMetrics( - requests=mm.requests.model_copy(), - usage=mm.usage.model_copy(), - ) - else: - existing = merged[model_name] - existing.requests.count += mm.requests.count - existing.requests.cost += mm.requests.cost - existing.usage.inputTokens += mm.usage.inputTokens - existing.usage.outputTokens += mm.usage.outputTokens - existing.usage.cacheReadTokens += mm.usage.cacheReadTokens - existing.usage.cacheWriteTokens += mm.usage.cacheWriteTokens + merged = merge_model_metrics(merged, s.model_metrics) return merged diff --git a/tests/copilot_usage/test_models.py b/tests/copilot_usage/test_models.py index 5036cf8..856b555 100644 --- a/tests/copilot_usage/test_models.py +++ b/tests/copilot_usage/test_models.py @@ -16,6 +16,7 @@ TokenUsage, ToolExecutionData, UserMessageData, + merge_model_metrics, ) # --------------------------------------------------------------------------- @@ -227,3 +228,102 @@ def test_session_summary_full() -> None: ) assert s.total_premium_requests == 24 assert s.model_metrics["claude-opus-4.6-1m"].usage.inputTokens == 1627935 + + +# --------------------------------------------------------------------------- +# merge_model_metrics +# --------------------------------------------------------------------------- + + +class TestMergeModelMetrics: + """Unit tests for the merge_model_metrics helper.""" + + def test_both_empty(self) -> None: + assert merge_model_metrics({}, {}) == {} + + def test_empty_base(self) -> None: + additional = { + "model-a": ModelMetrics( + requests=RequestMetrics(count=3, cost=2), + usage=TokenUsage(inputTokens=100, outputTokens=50), + ) + } + result = merge_model_metrics({}, additional) + assert "model-a" in result + assert result["model-a"].requests.count == 3 + assert result["model-a"].usage.inputTokens == 100 + + def test_empty_additional(self) -> None: + base = { + "model-a": ModelMetrics( + requests=RequestMetrics(count=5, cost=3), + usage=TokenUsage(outputTokens=200), + ) + } + result = merge_model_metrics(base, {}) + assert result["model-a"].requests.count == 5 + assert result["model-a"].usage.outputTokens == 200 + + def test_overlapping_keys_accumulate(self) -> None: + base = { + "claude-sonnet-4": ModelMetrics( + requests=RequestMetrics(count=3, cost=2), + usage=TokenUsage( + inputTokens=100, + outputTokens=50, + cacheReadTokens=10, + cacheWriteTokens=5, + ), + ) + } + additional = { + "claude-sonnet-4": ModelMetrics( + requests=RequestMetrics(count=7, cost=4), + usage=TokenUsage( + inputTokens=200, + outputTokens=80, + cacheReadTokens=20, + cacheWriteTokens=15, + ), + ) + } + result = merge_model_metrics(base, additional) + m = result["claude-sonnet-4"] + assert m.requests.count == 10 + assert m.requests.cost == 6 + assert m.usage.inputTokens == 300 + assert m.usage.outputTokens == 130 + assert m.usage.cacheReadTokens == 30 + assert m.usage.cacheWriteTokens == 20 + + def test_disjoint_keys_kept_separate(self) -> None: + base = {"model-a": ModelMetrics(usage=TokenUsage(outputTokens=100))} + additional = {"model-b": ModelMetrics(usage=TokenUsage(outputTokens=200))} + result = merge_model_metrics(base, additional) + assert "model-a" in result and "model-b" in result + assert result["model-a"].usage.outputTokens == 100 + assert result["model-b"].usage.outputTokens == 200 + + def test_does_not_mutate_base(self) -> None: + base = { + "m1": ModelMetrics( + requests=RequestMetrics(count=1, cost=1), + usage=TokenUsage(inputTokens=10), + ) + } + additional = { + "m1": ModelMetrics( + requests=RequestMetrics(count=2, cost=2), + usage=TokenUsage(inputTokens=20), + ) + } + merge_model_metrics(base, additional) + # base must be unchanged + assert base["m1"].requests.count == 1 + assert base["m1"].usage.inputTokens == 10 + + def test_does_not_mutate_additional(self) -> None: + base = {"m1": ModelMetrics(requests=RequestMetrics(count=1))} + additional = {"m1": ModelMetrics(requests=RequestMetrics(count=5))} + merge_model_metrics(base, additional) + assert additional["m1"].requests.count == 5