From da778a18fb83ee138f51e3dcd6322a643be46381 Mon Sep 17 00:00:00 2001 From: CJACK Date: Tue, 7 Apr 2026 01:20:01 +0800 Subject: [PATCH 1/7] refactor: replace WASM-based PoW with a high-performance native Go implementation and add context support for cancellation. --- README.MD | 8 ++++---- README.en.md | 8 ++++---- docs/CONTRIBUTING.en.md | 2 +- docs/CONTRIBUTING.md | 2 +- docs/DEPLOY.en.md | 3 --- docs/DEPLOY.md | 3 --- internal/deepseek/client_auth.go | 2 +- internal/deepseek/pow.go | 5 +++-- internal/deepseek/pow_test.go | 2 +- pow/deepseek_pow.go | 16 ++++++++++++---- pow/deepseek_pow_test.go | 7 ++++--- 11 files changed, 31 insertions(+), 27 deletions(-) diff --git a/README.MD b/README.MD index d0c003a..6e5411d 100644 --- a/README.MD +++ b/README.MD @@ -48,7 +48,7 @@ flowchart LR Auth["Auth Resolver\n(API key / bearer / x-goog-api-key)"] Pool["Account Pool + Queue\n(并发槽位 + 等待队列)"] DSClient["DeepSeek Client\n(Session / Auth / HTTP)"] - Pow["PoW WASM\n(wazero 预加载)"] + Pow["PoW 实现\n(纯 Go 毫秒级)"] Tool["Tool Sieve\n(Go/Node 语义对齐)"] end end @@ -95,7 +95,7 @@ flowchart LR | Gemini 兼容 | `POST /v1beta/models/{model}:generateContent`、`POST /v1beta/models/{model}:streamGenerateContent`(及 `/v1/models/{model}:*` 路径) | | 多账号轮询 | 自动 token 刷新、邮箱/手机号双登录方式 | | 并发队列控制 | 每账号 in-flight 上限 + 等待队列,动态计算建议并发值 | -| DeepSeek PoW | WASM 计算(`wazero`),无需外部 Node.js 依赖 | +| DeepSeek PoW | 纯 Go 高性能实现(DeepSeekHashV1),毫秒级响应 | | Tool Calling | 防泄漏处理:非代码块高置信特征识别、`delta.tool_calls` 早发、结构化增量输出 | | Admin API | 配置管理、运行时设置热更新、账号测试 / 批量测试、会话清理、导入导出、Vercel 同步、版本检查 | | WebUI 管理台 | `/admin` 单页应用(中英文双语、深色模式) | @@ -344,7 +344,7 @@ cp opencode.json.example opencode.json | `DS2API_CONFIG_PATH` | 配置文件路径 | `config.json` | | `DS2API_CONFIG_JSON` | 直接注入配置(JSON 或 Base64) | — | | `DS2API_ENV_WRITEBACK` | 环境变量模式下自动写回配置文件并切换文件模式(`1/true/yes/on`) | 关闭 | -| `DS2API_WASM_PATH` | PoW WASM 文件路径 | 自动查找 | +| `DS2API_POW_CONCURRENCY` | PoW 并行计算协程数(可选) | 默认 CPU 核心数 | | `DS2API_STATIC_ADMIN_DIR` | 管理台静态文件目录 | `static/admin` | 
| `DS2API_AUTO_BUILD_WEBUI` | 启动时自动构建 WebUI | 本地开启,Vercel 关闭 | | `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 | @@ -455,7 +455,7 @@ ds2api/ │ ├── claudeconv/ # Claude 消息格式转换 │ ├── compat/ # Go 版本兼容与回归测试辅助 │ ├── config/ # 配置加载、校验与热更新 -│ ├── deepseek/ # DeepSeek API 客户端、PoW WASM +│ ├── deepseek/ # DeepSeek API 客户端、PoW 逻辑 │ ├── js/ # Node 运行时流式处理与兼容逻辑 │ ├── devcapture/ # 开发抓包模块 │ ├── rawsample/ # 原始流样本可见文本提取与回放辅助 diff --git a/README.en.md b/README.en.md index 5959041..df764ab 100644 --- a/README.en.md +++ b/README.en.md @@ -48,7 +48,7 @@ flowchart LR Auth["Auth Resolver\n(API key / bearer / x-goog-api-key)"] Pool["Account Pool + Queue\n(in-flight slots + wait queue)"] DSClient["DeepSeek Client\n(session / auth / HTTP)"] - Pow["PoW WASM\n(wazero preload)"] + Pow["PoW Solver\n(Pure Go ms-level)"] Tool["Tool Sieve\n(Go/Node semantic parity)"] end end @@ -95,7 +95,7 @@ flowchart LR | Gemini compatible | `POST /v1beta/models/{model}:generateContent`, `POST /v1beta/models/{model}:streamGenerateContent` (plus `/v1/models/{model}:*` paths) | | Multi-account rotation | Auto token refresh, email/mobile dual login | | Concurrency control | Per-account in-flight limit + waiting queue, dynamic recommended concurrency | -| DeepSeek PoW | WASM solving via `wazero`, no external Node.js dependency | +| DeepSeek PoW | Pure Go high-performance solver (DeepSeekHashV1), ms-level response | | Tool Calling | Anti-leak handling: non-code-block feature match, early `delta.tool_calls`, structured incremental output | | Admin API | Config management, runtime settings hot-reload, account testing/batch test, session cleanup, import/export, Vercel sync, version check | | WebUI Admin Panel | SPA at `/admin` (bilingual Chinese/English, dark mode) | @@ -344,7 +344,7 @@ cp opencode.json.example opencode.json | `DS2API_CONFIG_PATH` | Config file path | `config.json` | | `DS2API_CONFIG_JSON` | Inline config (JSON or Base64) | — | | `DS2API_ENV_WRITEBACK` | Auto-write env-backed 
config to file and transition to file mode (`1/true/yes/on`) | Disabled | -| `DS2API_WASM_PATH` | PoW WASM file path | Auto-detect | +| `DS2API_POW_CONCURRENCY` | PoW parallel solver goroutine count (optional) | Default CPU core count | | `DS2API_STATIC_ADMIN_DIR` | Admin static assets dir | `static/admin` | | `DS2API_AUTO_BUILD_WEBUI` | Auto-build WebUI on startup | Enabled locally, disabled on Vercel | | `DS2API_ACCOUNT_MAX_INFLIGHT` | Max in-flight requests per account | `2` | @@ -453,7 +453,7 @@ ds2api/ │ ├── claudeconv/ # Claude message format conversion │ ├── compat/ # Go-version compatibility and regression helpers │ ├── config/ # Config loading, validation, and hot-reload -│ ├── deepseek/ # DeepSeek API client, PoW WASM +│ ├── deepseek/ # DeepSeek API client, PoW logic │ ├── js/ # Node runtime stream/compat logic │ ├── devcapture/ # Dev packet capture module │ ├── rawsample/ # Visible-text extraction and replay helpers for raw stream samples diff --git a/docs/CONTRIBUTING.en.md b/docs/CONTRIBUTING.en.md index 0752b98..3370141 100644 --- a/docs/CONTRIBUTING.en.md +++ b/docs/CONTRIBUTING.en.md @@ -115,7 +115,7 @@ ds2api/ │ ├── claudeconv/ # Claude message conversion │ ├── compat/ # Go-version compatibility and regression helpers │ ├── config/ # Config loading, validation, and hot-reload -│ ├── deepseek/ # DeepSeek client, PoW WASM +│ ├── deepseek/ # DeepSeek client, PoW logic │ ├── js/ # Node runtime stream/compat logic │ ├── devcapture/ # Dev packet capture │ ├── format/ # Output formatting diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index ad16e97..a4408df 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -115,7 +115,7 @@ ds2api/ │ ├── claudeconv/ # Claude 消息格式转换 │ ├── compat/ # Go 版本兼容与回归测试辅助 │ ├── config/ # 配置加载、校验与热更新 -│ ├── deepseek/ # DeepSeek 客户端、PoW WASM +│ ├── deepseek/ # DeepSeek 客户端、PoW 逻辑 │ ├── js/ # Node 运行时流式/兼容逻辑 │ ├── devcapture/ # 开发抓包 │ ├── format/ # 输出格式化 diff --git a/docs/DEPLOY.en.md b/docs/DEPLOY.en.md index 
145b186..8273ff7 100644 --- a/docs/DEPLOY.en.md +++ b/docs/DEPLOY.en.md @@ -366,7 +366,6 @@ Each archive includes: - `ds2api` executable (`ds2api.exe` on Windows) - `static/admin/` (built WebUI assets) -- `sha3_wasm_bg.7b9ca65ddd.wasm` (optional; binary has embedded fallback) - `config.example.json`, `.env.example` - `README.MD`, `README.en.md`, `LICENSE` @@ -456,8 +455,6 @@ server { # Copy compiled binary and related files to target directory sudo mkdir -p /opt/ds2api sudo cp ds2api config.json /opt/ds2api/ -# Optional: if you want to use an external WASM file (override the embedded one, from a release package or build output) -# sudo cp /path/to/sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/ sudo cp -r static/admin /opt/ds2api/static/admin ``` diff --git a/docs/DEPLOY.md b/docs/DEPLOY.md index 598c210..e4969d4 100644 --- a/docs/DEPLOY.md +++ b/docs/DEPLOY.md @@ -366,7 +366,6 @@ No Output Directory named "public" found after the Build completed. - `ds2api` 可执行文件(Windows 为 `ds2api.exe`) - `static/admin/`(WebUI 构建产物) -- `sha3_wasm_bg.7b9ca65ddd.wasm`(可选;程序内置 embed fallback) - `config.example.json`、`.env.example` - `README.MD`、`README.en.md`、`LICENSE` @@ -456,8 +455,6 @@ server { # 将编译好的二进制文件和相关文件复制到目标目录 sudo mkdir -p /opt/ds2api sudo cp ds2api config.json /opt/ds2api/ -# 可选:若你希望使用外置 WASM 文件(覆盖内置版本,来自 release 包或构建产物) -# sudo cp /path/to/sha3_wasm_bg.7b9ca65ddd.wasm /opt/ds2api/ sudo cp -r static/admin /opt/ds2api/static/admin ``` diff --git a/internal/deepseek/client_auth.go b/internal/deepseek/client_auth.go index 3cbf323..e953327 100644 --- a/internal/deepseek/client_auth.go +++ b/internal/deepseek/client_auth.go @@ -109,7 +109,7 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in data, _ := resp["data"].(map[string]any) bizData, _ := data["biz_data"].(map[string]any) challenge, _ := bizData["challenge"].(map[string]any) - answer, err := ComputePow(challenge) + answer, err := ComputePow(ctx, challenge) if err != nil { attempts++ continue 
diff --git a/internal/deepseek/pow.go b/internal/deepseek/pow.go index 54f678d..9d839de 100644 --- a/internal/deepseek/pow.go +++ b/internal/deepseek/pow.go @@ -1,6 +1,7 @@ package deepseek import ( + "context" "encoding/base64" "encoding/json" "errors" @@ -9,7 +10,7 @@ import ( ) // ComputePow 使用纯 Go 实现求解 PoW challenge (DeepSeekHashV1)。 -func ComputePow(challenge map[string]any) (int64, error) { +func ComputePow(ctx context.Context, challenge map[string]any) (int64, error) { algo, _ := challenge["algorithm"].(string) if algo != "DeepSeekHashV1" { return 0, errors.New("unsupported algorithm") } @@ -19,7 +20,7 @@ func ComputePow(challenge map[string]any) (int64, error) { expireAt := toInt64(challenge["expire_at"], 1680000000) difficulty := toInt64FromFloat(challenge["difficulty"], 144000) - return pow.SolvePow(challengeStr, salt, expireAt, difficulty) + return pow.SolvePow(ctx, challengeStr, salt, expireAt, difficulty) } // BuildPowHeader 序列化 {algorithm,challenge,salt,answer,signature,target_path} 为 base64(JSON)。 diff --git a/internal/deepseek/pow_test.go b/internal/deepseek/pow_test.go index 3f1104c..0161f62 100644 --- a/internal/deepseek/pow_test.go +++ b/internal/deepseek/pow_test.go @@ -13,7 +13,7 @@ func TestPreloadPowNoOp(t *testing.T) { } func TestComputePowUnsupportedAlgorithm(t *testing.T) { - _, err := ComputePow(map[string]any{"algorithm": "unknown"}) + _, err := ComputePow(context.Background(), map[string]any{"algorithm": "unknown"}) if err == nil { t.Fatal("expected error for unsupported algorithm") } diff --git a/pow/deepseek_pow.go b/pow/deepseek_pow.go index f6ea7de..bb9b2b4 100644 --- a/pow/deepseek_pow.go +++ b/pow/deepseek_pow.go @@ -1,6 +1,7 @@ package pow import ( + "context" "encoding/base64" "encoding/binary" "encoding/hex" @@ -9,7 +10,7 @@ import ( "strconv" ) -// Challenge 对应 /api/v0/chat/create_pow_challenge 返回的 data.biz_data.challenge。 +// Challenge 对应 /api/v0/chat/create_pow_challenge 返回的 data.biz_data.challenge。 type Challenge struct { 
Algorithm string `json:"algorithm"` Challenge string `json:"challenge"` @@ -27,7 +28,7 @@ func BuildPrefix(salt string, expireAt int64) string { // SolvePow 搜索 nonce ∈ [0, difficulty) 使得 DeepSeekHashV1(prefix+str(nonce)) == challenge。 // prefix 预吸收进 state,循环内零分配。 -func SolvePow(challengeHex, salt string, expireAt, difficulty int64) (int64, error) { +func SolvePow(ctx context.Context, challengeHex, salt string, expireAt, difficulty int64) (int64, error) { if len(challengeHex) != 64 { return 0, errors.New("pow: challenge must be 64 hex chars") } @@ -59,6 +60,13 @@ func SolvePow(challengeHex, salt string, expireAt, difficulty int64) (int64, err var numBuf [20]byte for n := int64(0); n < difficulty; n++ { + // Periodically check if context is canceled to avoid wasting CPU + if n&0x3FF == 0 { + if err := ctx.Err(); err != nil { + return 0, err + } + } + v := uint64(n) pos := 20 if v == 0 { @@ -123,7 +131,7 @@ func BuildPowHeader(c *Challenge, answer int64) (string, error) { } // SolveAndBuildHeader 端到端: Challenge → x-ds-pow-response header string。 -func SolveAndBuildHeader(c *Challenge) (string, error) { +func SolveAndBuildHeader(ctx context.Context, c *Challenge) (string, error) { if c.Algorithm != "DeepSeekHashV1" { return "", errors.New("pow: unsupported algorithm: " + c.Algorithm) } @@ -131,7 +139,7 @@ func SolveAndBuildHeader(c *Challenge) (string, error) { if d == 0 { d = 144000 } - answer, err := SolvePow(c.Challenge, c.Salt, c.ExpireAt, d) + answer, err := SolvePow(ctx, c.Challenge, c.Salt, c.ExpireAt, d) if err != nil { return "", err } diff --git a/pow/deepseek_pow_test.go b/pow/deepseek_pow_test.go index 73c6f66..d2ed773 100644 --- a/pow/deepseek_pow_test.go +++ b/pow/deepseek_pow_test.go @@ -1,6 +1,7 @@ package pow import ( + "context" "encoding/base64" "encoding/hex" "encoding/json" @@ -36,7 +37,7 @@ func TestSolvePow(t *testing.T) { {"abc123salt", 1700000000, 12345, 20000}, } { h := DeepSeekHashV1([]byte(BuildPrefix(tc.salt, tc.expire) + 
strconv.FormatInt(tc.answer, 10))) - got, err := SolvePow(hex.EncodeToString(h[:]), tc.salt, tc.expire, tc.diff) + got, err := SolvePow(context.Background(), hex.EncodeToString(h[:]), tc.salt, tc.expire, tc.diff) if err != nil || got != tc.answer { t.Errorf("salt=%q answer=%d: got=%d err=%v", tc.salt, tc.answer, got, err) } @@ -45,7 +46,7 @@ func TestSolvePow(t *testing.T) { func TestSolveAndBuildHeader(t *testing.T) { t0 := DeepSeekHashV1([]byte("salt_1712345678_777")) - header, err := SolveAndBuildHeader(&Challenge{ + header, err := SolveAndBuildHeader(context.Background(), &Challenge{ Algorithm: "DeepSeekHashV1", Challenge: hex.EncodeToString(t0[:]), Salt: "salt", ExpireAt: 1712345678, Difficulty: 2000, Signature: "sig", TargetPath: "/api/v0/chat/completion", @@ -74,6 +75,6 @@ func BenchmarkSolve(b *testing.B) { h := DeepSeekHashV1([]byte("realisticsalt_1712345678_72000")) ch := hex.EncodeToString(h[:]) for i := 0; i < b.N; i++ { - _, _ = SolvePow(ch, "realisticsalt", 1712345678, 144000) + _, _ = SolvePow(context.Background(), ch, "realisticsalt", 1712345678, 144000) } } From b79a13efd56f43d91cb1f4611a650d9568b70071 Mon Sep 17 00:00:00 2001 From: CJACK Date: Tue, 7 Apr 2026 01:39:27 +0800 Subject: [PATCH 2/7] feat: support explicit prompt token tracking in SSE parsing and stream handlers --- .../adapter/openai/chat_stream_runtime.go | 17 ++++-- internal/adapter/openai/handler_chat.go | 13 ++-- .../adapter/openai/handler_toolcall_format.go | 2 +- .../openai/responses_stream_runtime_events.go | 2 +- internal/deepseek/constants_shared.json | 1 - internal/js/chat-stream/sse_parse_impl.js | 60 ++++++++++++++----- internal/js/chat-stream/token_usage.js | 10 ++-- internal/js/chat-stream/vercel_stream_impl.js | 6 +- internal/sse/consumer.go | 15 +++-- internal/sse/line.go | 9 ++- internal/sse/parser.go | 34 ++++++++--- internal/sse/parser_edge_test.go | 16 +---- internal/sse/parser_test.go | 14 +++++ 13 files changed, 136 insertions(+), 63 deletions(-) diff --git 
a/internal/adapter/openai/chat_stream_runtime.go b/internal/adapter/openai/chat_stream_runtime.go index a199882..47f483a 100644 --- a/internal/adapter/openai/chat_stream_runtime.go +++ b/internal/adapter/openai/chat_stream_runtime.go @@ -37,6 +37,7 @@ type chatStreamRuntime struct { streamToolNames map[int]string thinking strings.Builder text strings.Builder + promptTokens int outputTokens int } @@ -170,11 +171,16 @@ func (s *chatStreamRuntime) finalize(finishReason string) { finishReason = "tool_calls" } usage := openaifmt.BuildChatUsage(s.finalPrompt, finalThinking, finalText) + if s.promptTokens > 0 { + usage["prompt_tokens"] = s.promptTokens + } if s.outputTokens > 0 { usage["completion_tokens"] = s.outputTokens - if prompt, ok := usage["prompt_tokens"].(int); ok { - usage["total_tokens"] = prompt + s.outputTokens - } + } + if s.promptTokens > 0 || s.outputTokens > 0 { + p := usage["prompt_tokens"].(int) + c := usage["completion_tokens"].(int) + usage["total_tokens"] = p + c } s.sendChunk(openaifmt.BuildChatStreamChunk( s.completionID, @@ -190,6 +196,9 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD if !parsed.Parsed { return streamengine.ParsedDecision{} } + if parsed.PromptTokens > 0 { + s.promptTokens = parsed.PromptTokens + } if parsed.OutputTokens > 0 { s.outputTokens = parsed.OutputTokens } @@ -243,7 +252,7 @@ func (s *chatStreamRuntime) onParsed(parsed sse.LineResult) streamengine.ParsedD if !s.emitEarlyToolDeltas { continue } - filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, s.streamToolNames) + filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.streamToolNames) if len(filtered) == 0 { continue } diff --git a/internal/adapter/openai/handler_chat.go b/internal/adapter/openai/handler_chat.go index 95337b6..e28886d 100644 --- a/internal/adapter/openai/handler_chat.go +++ b/internal/adapter/openai/handler_chat.go @@ -131,12 +131,17 @@ func (h *Handler) 
handleNonStream(w http.ResponseWriter, ctx context.Context, re return } respBody := openaifmt.BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText, toolNames) - if result.OutputTokens > 0 { + if result.PromptTokens > 0 || result.OutputTokens > 0 { if usage, ok := respBody["usage"].(map[string]any); ok { - usage["completion_tokens"] = result.OutputTokens - if prompt, ok := usage["prompt_tokens"].(int); ok { - usage["total_tokens"] = prompt + result.OutputTokens + if result.PromptTokens > 0 { + usage["prompt_tokens"] = result.PromptTokens } + if result.OutputTokens > 0 { + usage["completion_tokens"] = result.OutputTokens + } + p, _ := usage["prompt_tokens"].(int) + c, _ := usage["completion_tokens"].(int) + usage["total_tokens"] = p + c } } writeJSON(w, http.StatusOK, respBody) diff --git a/internal/adapter/openai/handler_toolcall_format.go b/internal/adapter/openai/handler_toolcall_format.go index 44eb4d1..3937610 100644 --- a/internal/adapter/openai/handler_toolcall_format.go +++ b/internal/adapter/openai/handler_toolcall_format.go @@ -113,7 +113,7 @@ func formatIncrementalStreamToolCallDeltas(deltas []toolCallDelta, ids map[int]s return out } -func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNames []string, seenNames map[int]string) []toolCallDelta { +func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, seenNames map[int]string) []toolCallDelta { if len(deltas) == 0 { return nil } diff --git a/internal/adapter/openai/responses_stream_runtime_events.go b/internal/adapter/openai/responses_stream_runtime_events.go index 792d0ce..21e15d1 100644 --- a/internal/adapter/openai/responses_stream_runtime_events.go +++ b/internal/adapter/openai/responses_stream_runtime_events.go @@ -48,7 +48,7 @@ func (s *responsesStreamRuntime) processToolStreamEvents(events []toolStreamEven if !s.emitEarlyToolDeltas { continue } - filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.toolNames, 
s.functionNames) + filtered := filterIncrementalToolCallDeltasByAllowed(evt.ToolCallDeltas, s.functionNames) if len(filtered) == 0 { continue } diff --git a/internal/deepseek/constants_shared.json b/internal/deepseek/constants_shared.json index a71ca02..56950ca 100644 --- a/internal/deepseek/constants_shared.json +++ b/internal/deepseek/constants_shared.json @@ -12,7 +12,6 @@ "skip_contains_patterns": [ "quasi_status", "elapsed_secs", - "token_usage", "pending_fragment", "conversation_mode", "fragments/-1/status", diff --git a/internal/js/chat-stream/sse_parse_impl.js b/internal/js/chat-stream/sse_parse_impl.js index f24ee6d..10b85f0 100644 --- a/internal/js/chat-stream/sse_parse_impl.js +++ b/internal/js/chat-stream/sse_parse_impl.js @@ -20,7 +20,9 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc }; } - const outputTokens = extractAccumulatedTokenUsage(chunk); + const usage = extractAccumulatedTokenUsage(chunk); + const promptTokens = usage.prompt; + const outputTokens = usage.output; if (Object.prototype.hasOwnProperty.call(chunk, 'error')) { return { @@ -29,7 +31,8 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: true, contentFilter: false, errorMessage: formatErrorMessage(chunk.error), - outputTokens: 0, + promptTokens, + outputTokens, newType: currentType, }; } @@ -43,6 +46,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: true, contentFilter: true, errorMessage: '', + promptTokens, outputTokens, newType: currentType, }; @@ -55,6 +59,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: false, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType: currentType, }; @@ -67,6 +72,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: true, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType: currentType, }; @@ 
-77,6 +83,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: false, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType: currentType, }; @@ -89,6 +96,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: false, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType: currentType, }; @@ -157,6 +165,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: true, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType, }; @@ -168,6 +177,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: false, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType, }; @@ -182,6 +192,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: false, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType, }; @@ -196,6 +207,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: true, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType, }; @@ -207,6 +219,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: false, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType, }; @@ -242,6 +255,7 @@ function parseChunkForContent(chunk, thinkingEnabled, currentType, stripReferenc finished: false, contentFilter: false, errorMessage: '', + promptTokens, outputTokens, newType, }; @@ -429,40 +443,54 @@ function hasContentFilterStatusValue(v) { } function extractAccumulatedTokenUsage(chunk) { - return findAccumulatedTokenUsage(chunk); + const usage = findAccumulatedTokenUsage(chunk); + return usage || { prompt: 0, output: 0 }; } function findAccumulatedTokenUsage(v) { if (Array.isArray(v)) { for (const item of v) { - const n = 
findAccumulatedTokenUsage(item); - if (n > 0) { - return n; - } + const u = findAccumulatedTokenUsage(item); + if (u) return u; } - return 0; + return null; } if (!v || typeof v !== 'object') { - return 0; + return null; } const pathValue = asString(v.p); if (pathValue && pathValue.toLowerCase().includes('accumulated_token_usage')) { const n = toInt(v.v); if (n > 0) { - return n; + return { prompt: 0, output: n }; + } + } + if (pathValue && pathValue.toLowerCase().includes('token_usage')) { + const u = v.v; + if (u && typeof u === 'object') { + const p = toInt(u.prompt_tokens); + const c = toInt(u.completion_tokens); + if (p > 0 || c > 0) { + return { prompt: p, output: c }; + } } } const direct = toInt(v.accumulated_token_usage); if (direct > 0) { - return direct; + return { prompt: 0, output: direct }; } - for (const value of Object.values(v)) { - const n = findAccumulatedTokenUsage(value); - if (n > 0) { - return n; + if (v.token_usage && typeof v.token_usage === 'object') { + const p = toInt(v.token_usage.prompt_tokens); + const c = toInt(v.token_usage.completion_tokens); + if (p > 0 || c > 0) { + return { prompt: p, output: c }; } } - return 0; + for (const value of Object.values(v)) { + const u = findAccumulatedTokenUsage(value); + if (u) return u; + } + return null; } function toInt(v) { diff --git a/internal/js/chat-stream/token_usage.js b/internal/js/chat-stream/token_usage.js index 0f71c5f..82e12e8 100644 --- a/internal/js/chat-stream/token_usage.js +++ b/internal/js/chat-stream/token_usage.js @@ -1,15 +1,17 @@ 'use strict'; -function buildUsage(prompt, thinking, output, outputTokens = 0) { - const promptTokens = estimateTokens(prompt); +function buildUsage(prompt, thinking, output, outputTokens = 0, providedPromptTokens = 0) { const reasoningTokens = estimateTokens(thinking); const completionTokens = estimateTokens(output); + + const finalPromptTokens = Number.isFinite(providedPromptTokens) && providedPromptTokens > 0 ? 
Math.trunc(providedPromptTokens) : estimateTokens(prompt); + const overriddenCompletionTokens = Number.isFinite(outputTokens) && outputTokens > 0 ? Math.trunc(outputTokens) : 0; const finalCompletionTokens = overriddenCompletionTokens > 0 ? overriddenCompletionTokens : reasoningTokens + completionTokens; return { - prompt_tokens: promptTokens, + prompt_tokens: finalPromptTokens, completion_tokens: finalCompletionTokens, - total_tokens: promptTokens + finalCompletionTokens, + total_tokens: finalPromptTokens + finalCompletionTokens, completion_tokens_details: { reasoning_tokens: reasoningTokens, }, diff --git a/internal/js/chat-stream/vercel_stream_impl.js b/internal/js/chat-stream/vercel_stream_impl.js index e46b530..7c39313 100644 --- a/internal/js/chat-stream/vercel_stream_impl.js +++ b/internal/js/chat-stream/vercel_stream_impl.js @@ -125,6 +125,7 @@ async function handleVercelStream(req, res, rawBody, payload) { let currentType = thinkingEnabled ? 'thinking' : 'text'; let thinkingText = ''; let outputText = ''; + let promptTokens = 0; let outputTokens = 0; const toolSieveEnabled = toolPolicy.toolSieveEnabled; const toolSieveState = createToolSieveState(); @@ -178,7 +179,7 @@ async function handleVercelStream(req, res, rawBody, payload) { created, model, choices: [{ delta: {}, index: 0, finish_reason: reason }], - usage: buildUsage(finalPrompt, thinkingText, outputText, outputTokens), + usage: buildUsage(finalPrompt, thinkingText, outputText, outputTokens, promptTokens), }); if (!res.writableEnded && !res.destroyed) { res.write('data: [DONE]\n\n'); @@ -227,6 +228,9 @@ async function handleVercelStream(req, res, rawBody, payload) { if (!parsed.parsed) { continue; } + if (parsed.promptTokens > 0) { + promptTokens = parsed.promptTokens; + } if (parsed.outputTokens > 0) { outputTokens = parsed.outputTokens; } diff --git a/internal/sse/consumer.go b/internal/sse/consumer.go index 141bd93..f11e942 100644 --- a/internal/sse/consumer.go +++ b/internal/sse/consumer.go @@ 
-12,6 +12,7 @@ import ( type CollectResult struct { Text string Thinking string + PromptTokens int OutputTokens int ContentFilter bool } @@ -28,6 +29,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co } text := strings.Builder{} thinking := strings.Builder{} + promptTokens := 0 outputTokens := 0 contentFilter := false currentType := "text" @@ -40,18 +42,18 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co if !result.Parsed { return true } + if result.PromptTokens > 0 { + promptTokens = result.PromptTokens + } + if result.OutputTokens > 0 { + outputTokens = result.OutputTokens + } if result.Stop { if result.ContentFilter { contentFilter = true } - if result.OutputTokens > 0 { - outputTokens = result.OutputTokens - } return false } - if result.OutputTokens > 0 { - outputTokens = result.OutputTokens - } for _, p := range result.Parts { if p.Type == "thinking" { trimmed := TrimContinuationOverlap(thinking.String(), p.Text) @@ -66,6 +68,7 @@ func CollectStream(resp *http.Response, thinkingEnabled bool, closeBody bool) Co return CollectResult{ Text: text.String(), Thinking: thinking.String(), + PromptTokens: promptTokens, OutputTokens: outputTokens, ContentFilter: contentFilter, } diff --git a/internal/sse/line.go b/internal/sse/line.go index d55f9e5..a63563b 100644 --- a/internal/sse/line.go +++ b/internal/sse/line.go @@ -10,6 +10,7 @@ type LineResult struct { ErrorMessage string Parts []ContentPart NextType string + PromptTokens int OutputTokens int } @@ -20,9 +21,9 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri if !parsed { return LineResult{NextType: currentType} } - outputTokens := extractAccumulatedTokenUsage(chunk) + promptTokens, outputTokens := extractAccumulatedTokenUsage(chunk) if done { - return LineResult{Parsed: true, Stop: true, NextType: currentType, OutputTokens: outputTokens} + return LineResult{Parsed: true, Stop: true, NextType: currentType, 
PromptTokens: promptTokens, OutputTokens: outputTokens} } if errObj, hasErr := chunk["error"]; hasErr { return LineResult{ @@ -30,6 +31,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri Stop: true, ErrorMessage: fmt.Sprintf("%v", errObj), NextType: currentType, + PromptTokens: promptTokens, OutputTokens: outputTokens, } } @@ -39,6 +41,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri Stop: true, ContentFilter: true, NextType: currentType, + PromptTokens: promptTokens, OutputTokens: outputTokens, } } @@ -48,6 +51,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri Stop: true, ContentFilter: true, NextType: currentType, + PromptTokens: promptTokens, OutputTokens: outputTokens, } } @@ -58,6 +62,7 @@ func ParseDeepSeekContentLine(raw []byte, thinkingEnabled bool, currentType stri Stop: finished, Parts: parts, NextType: nextType, + PromptTokens: promptTokens, OutputTokens: outputTokens, } } diff --git a/internal/sse/parser.go b/internal/sse/parser.go index eee46f9..051619e 100644 --- a/internal/sse/parser.go +++ b/internal/sse/parser.go @@ -364,34 +364,50 @@ func hasContentFilterStatusValue(v any) bool { return false } -func extractAccumulatedTokenUsage(chunk map[string]any) int { +func extractAccumulatedTokenUsage(chunk map[string]any) (int, int) { return findAccumulatedTokenUsage(chunk) } -func findAccumulatedTokenUsage(v any) int { +func findAccumulatedTokenUsage(v any) (int, int) { switch x := v.(type) { case map[string]any: if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "accumulated_token_usage") { if n, ok := toInt(x["v"]); ok && n > 0 { - return n + return 0, n + } + } + if p, _ := x["p"].(string); strings.Contains(strings.ToLower(p), "token_usage") { + if m, ok := x["v"].(map[string]any); ok { + p, _ := toInt(m["prompt_tokens"]) + c, _ := toInt(m["completion_tokens"]) + if p > 0 || c > 0 { + return p, c + } } } if n, ok := 
toInt(x["accumulated_token_usage"]); ok && n > 0 { - return n + return 0, n + } + if usage, ok := x["token_usage"].(map[string]any); ok { + p, _ := toInt(usage["prompt_tokens"]) + c, _ := toInt(usage["completion_tokens"]) + if p > 0 || c > 0 { + return p, c + } } for _, vv := range x { - if n := findAccumulatedTokenUsage(vv); n > 0 { - return n + if p, c := findAccumulatedTokenUsage(vv); p > 0 || c > 0 { + return p, c } } case []any: for _, item := range x { - if n := findAccumulatedTokenUsage(item); n > 0 { - return n + if p, c := findAccumulatedTokenUsage(item); p > 0 || c > 0 { + return p, c } } } - return 0 + return 0, 0 } func toInt(v any) (int, bool) { diff --git a/internal/sse/parser_edge_test.go b/internal/sse/parser_edge_test.go index ba1c723..f0e7f9a 100644 --- a/internal/sse/parser_edge_test.go +++ b/internal/sse/parser_edge_test.go @@ -50,18 +50,6 @@ func TestShouldSkipPathQuasiStatus(t *testing.T) { } } -func TestShouldSkipPathElapsedSecs(t *testing.T) { - if !shouldSkipPath("response/elapsed_secs") { - t.Fatal("expected skip for elapsed_secs path") - } -} - -func TestShouldSkipPathTokenUsage(t *testing.T) { - if !shouldSkipPath("response/token_usage") { - t.Fatal("expected skip for token_usage path") - } -} - func TestShouldSkipPathPendingFragment(t *testing.T) { if !shouldSkipPath("response/pending_fragment") { t.Fatal("expected skip for pending_fragment path") @@ -127,7 +115,7 @@ func TestParseSSEChunkForContentNoVField(t *testing.T) { func TestParseSSEChunkForContentSkippedPath(t *testing.T) { parts, finished, nextType := ParseSSEChunkForContent(map[string]any{ - "p": "response/token_usage", + "p": "response/quasi_status", "v": "some data", }, false, "text") if finished || len(parts) > 0 { @@ -498,7 +486,7 @@ func TestExtractContentRecursiveFinishedStatus(t *testing.T) { func TestExtractContentRecursiveSkipsPath(t *testing.T) { items := []any{ - map[string]any{"p": "token_usage", "v": "data"}, + map[string]any{"p": "quasi_status", "v": "data"}, } 
parts, finished := extractContentRecursive(items, "text") if finished { diff --git a/internal/sse/parser_test.go b/internal/sse/parser_test.go index b036f57..89c5356 100644 --- a/internal/sse/parser_test.go +++ b/internal/sse/parser_test.go @@ -19,6 +19,20 @@ func TestParseDeepSeekSSELineDone(t *testing.T) { } } +func TestExtractTokenUsage(t *testing.T) { + chunk := map[string]any{ + "p": "response/token_usage", + "v": map[string]any{ + "prompt_tokens": 123, + "completion_tokens": 456, + }, + } + p, c := extractAccumulatedTokenUsage(chunk) + if p != 123 || c != 456 { + t.Fatalf("expected 123/456, got %d/%d", p, c) + } +} + func TestParseSSEChunkForContentSimple(t *testing.T) { parts, finished, _ := ParseSSEChunkForContent(map[string]any{"v": "hello"}, false, "text") if finished { From 99682216331de41a0addec743209da119aaf18c2 Mon Sep 17 00:00:00 2001 From: CJACK Date: Tue, 7 Apr 2026 02:10:45 +0800 Subject: [PATCH 3/7] refactor: improve XML tool parsing robustness, update system prompt constraints, and simplify tool filtering logic --- internal/toolcall/tool_prompt.go | 37 +++++++++++++++++----------- internal/toolcall/toolcalls_parse.go | 10 ++++---- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/internal/toolcall/tool_prompt.go b/internal/toolcall/tool_prompt.go index 8d6649a..bcf1046 100644 --- a/internal/toolcall/tool_prompt.go +++ b/internal/toolcall/tool_prompt.go @@ -46,24 +46,33 @@ When calling tools, emit ONLY raw XML at the very end of your response. No text RULES: -1) Output ONLY the XML above when calling tools. Do NOT mix tool XML with regular text. -2) MUST contain a strict JSON object. All JSON keys and strings use double quotes. -3) Multiple tools → multiple blocks inside ONE root. -4) Do NOT wrap the XML in markdown code fences (no triple backticks). -5) After receiving a tool result, use it directly. Only call another tool if the result is insufficient. -6) Parameters MUST use the exact field names from the selected tool schema. 
-7) CRITICAL: Do NOT invent or add any extra fields (such as "_raw", "_xml"). Use ONLY the fields strictly defined in the schema. Extra fields will cause execution failure. +1) When calling tools, you MUST use the XML format. +2) No text is allowed AFTER the XML block. +3) MUST be a single-line strict JSON object. Use double quotes. +4) Multiple tools must be inside the same root. +5) Do NOT wrap XML in markdown fences (` + "```" + `). +6) Do NOT invent parameters. Use only the provided schema. +7) CRITICAL: Do NOT use native tool markers like "<|Tool|>" or "<|tool|>". +8) CRITICAL: Do NOT output role markers like "<|System|>", "<|User|>", or "<|Assistant|>". +9) CRITICAL: Do NOT output internal monologues (e.g. "I will list files now..."). Just output your answer or the XML. ❌ WRONG — Do NOT do these: -Wrong 1 — mixed text and XML: - I'll read the file for you. ... -Wrong 2 — describing tool calls in text: - [调用 Bash] {"command": "ls"} +Wrong 1 — mixed text after XML: + ... I hope this helps. +Wrong 2 — function-call syntax: + Grep({"pattern": "token"}) Wrong 3 — missing wrapper: ` + ex1 + `{} -Wrong 4 — extra/invented fields: - {"_raw": "...", "command": "ls"} - +Wrong 4 — Markdown code fences: + ` + "```xml" + ` + ... + ` + "```" + ` +Wrong 5 — native tool tokens: + <|Tool|>call_some_tool{"param":1}<|Tool|> +Wrong 6 — role markers in response: + <|Assistant|> Here is the result... + +Remember: The ONLY valid way to use tools is the XML block at the end of your response. 
✅ CORRECT EXAMPLES: diff --git a/internal/toolcall/toolcalls_parse.go b/internal/toolcall/toolcalls_parse.go index 8f0a289..400fd86 100644 --- a/internal/toolcall/toolcalls_parse.go +++ b/internal/toolcall/toolcalls_parse.go @@ -41,7 +41,7 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa continue } parsed := tc - calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + calls, rejectedNames := filterToolCallsDetailed(parsed) result.Calls = calls result.RejectedToolNames = rejectedNames result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 @@ -77,7 +77,7 @@ func ParseToolCallsDetailed(text string, availableToolNames []string) ToolCallPa result.SawToolCallSyntax = true } - calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + calls, rejectedNames := filterToolCallsDetailed(parsed) result.Calls = calls result.RejectedToolNames = rejectedNames result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 @@ -108,7 +108,7 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) continue } result.SawToolCallSyntax = true - calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + calls, rejectedNames := filterToolCallsDetailed(parsed) result.Calls = calls result.RejectedToolNames = rejectedNames result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 @@ -143,14 +143,14 @@ func ParseStandaloneToolCallsDetailed(text string, availableToolNames []string) } } result.SawToolCallSyntax = true - calls, rejectedNames := filterToolCallsDetailed(parsed, availableToolNames) + calls, rejectedNames := filterToolCallsDetailed(parsed) result.Calls = calls result.RejectedToolNames = rejectedNames result.RejectedByPolicy = len(rejectedNames) > 0 && len(calls) == 0 return result } -func filterToolCallsDetailed(parsed []ParsedToolCall, availableToolNames []string) ([]ParsedToolCall, []string) { +func filterToolCallsDetailed(parsed 
[]ParsedToolCall) ([]ParsedToolCall, []string) { out := make([]ParsedToolCall, 0, len(parsed)) for _, tc := range parsed { if tc.Name == "" { From 96b8587c5bdee686ea25b3b377a271ec784fa181 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Tue, 7 Apr 2026 08:27:03 +0800 Subject: [PATCH 4/7] Fix token usage propagation and remove stale env docs --- README.MD | 1 - README.en.md | 1 - internal/adapter/openai/prompt_build_test.go | 9 +- internal/adapter/openai/responses_handler.go | 13 ++- .../openai/responses_stream_runtime_core.go | 18 +++- internal/adapter/openai/stream_status_test.go | 94 +++++++++++++++++++ tests/node/chat-stream.test.js | 22 ++++- 7 files changed, 143 insertions(+), 15 deletions(-) diff --git a/README.MD b/README.MD index 6e5411d..5d2ee4e 100644 --- a/README.MD +++ b/README.MD @@ -344,7 +344,6 @@ cp opencode.json.example opencode.json | `DS2API_CONFIG_PATH` | 配置文件路径 | `config.json` | | `DS2API_CONFIG_JSON` | 直接注入配置(JSON 或 Base64) | — | | `DS2API_ENV_WRITEBACK` | 环境变量模式下自动写回配置文件并切换文件模式(`1/true/yes/on`) | 关闭 | -| `DS2API_POW_CONCURRENCY` | PoW 并行计算协程数(可选) | 默认 CPU 核心数 | | `DS2API_STATIC_ADMIN_DIR` | 管理台静态文件目录 | `static/admin` | | `DS2API_AUTO_BUILD_WEBUI` | 启动时自动构建 WebUI | 本地开启,Vercel 关闭 | | `DS2API_DEV_PACKET_CAPTURE` | 本地开发抓包开关(记录最近会话请求/响应体) | 本地非 Vercel 默认开启 | diff --git a/README.en.md b/README.en.md index df764ab..6753bc0 100644 --- a/README.en.md +++ b/README.en.md @@ -344,7 +344,6 @@ cp opencode.json.example opencode.json | `DS2API_CONFIG_PATH` | Config file path | `config.json` | | `DS2API_CONFIG_JSON` | Inline config (JSON or Base64) | — | | `DS2API_ENV_WRITEBACK` | Auto-write env-backed config to file and transition to file mode (`1/true/yes/on`) | Disabled | -| `DS2API_POW_CONCURRENCY` | PoW parallel solver goroutine count (optional) | Default CPU core count | | `DS2API_STATIC_ADMIN_DIR` | Admin static assets dir | `static/admin` | | `DS2API_AUTO_BUILD_WEBUI` | Auto-build WebUI on startup | Enabled locally, disabled on Vercel | | 
`DS2API_ACCOUNT_MAX_INFLIGHT` | Max in-flight requests per account | `2` | diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go index 223689b..390cbd4 100644 --- a/internal/adapter/openai/prompt_build_test.go +++ b/internal/adapter/openai/prompt_build_test.go @@ -74,16 +74,13 @@ func TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t * } finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "") - if !strings.Contains(finalPrompt, "After receiving a tool result, use it directly.") { - t.Fatalf("vercel prepare finalPrompt missing final-answer instruction: %q", finalPrompt) - } - if !strings.Contains(finalPrompt, "Only call another tool if the result is insufficient.") { - t.Fatalf("vercel prepare finalPrompt missing retry guard instruction: %q", finalPrompt) + if !strings.Contains(finalPrompt, "Remember: Output ONLY the ... XML block when calling tools.") { + t.Fatalf("vercel prepare finalPrompt missing final tool-call anchor instruction: %q", finalPrompt) } if !strings.Contains(finalPrompt, "TOOL CALL FORMAT") { t.Fatalf("vercel prepare finalPrompt missing xml format instruction: %q", finalPrompt) } - if !strings.Contains(finalPrompt, "Do NOT wrap the XML in markdown code fences") { + if !strings.Contains(finalPrompt, "Do NOT wrap XML in markdown fences") { t.Fatalf("vercel prepare finalPrompt missing no-fence xml instruction: %q", finalPrompt) } if strings.Contains(finalPrompt, "```json") { diff --git a/internal/adapter/openai/responses_handler.go b/internal/adapter/openai/responses_handler.go index 7cb7ec3..1bd5b1c 100644 --- a/internal/adapter/openai/responses_handler.go +++ b/internal/adapter/openai/responses_handler.go @@ -130,12 +130,17 @@ func (h *Handler) handleResponsesNonStream(w http.ResponseWriter, resp *http.Res } responseObj := openaifmt.BuildResponseObject(responseID, model, finalPrompt, sanitizedThinking, sanitizedText, toolNames) - if result.OutputTokens > 0 { + if 
result.PromptTokens > 0 || result.OutputTokens > 0 { if usage, ok := responseObj["usage"].(map[string]any); ok { - usage["output_tokens"] = result.OutputTokens - if input, ok := usage["input_tokens"].(int); ok { - usage["total_tokens"] = input + result.OutputTokens + if result.PromptTokens > 0 { + usage["input_tokens"] = result.PromptTokens } + if result.OutputTokens > 0 { + usage["output_tokens"] = result.OutputTokens + } + input, _ := usage["input_tokens"].(int) + output, _ := usage["output_tokens"].(int) + usage["total_tokens"] = input + output } } h.getResponseStore().put(owner, responseID, responseObj) diff --git a/internal/adapter/openai/responses_stream_runtime_core.go b/internal/adapter/openai/responses_stream_runtime_core.go index 8072ccb..ff9ea26 100644 --- a/internal/adapter/openai/responses_stream_runtime_core.go +++ b/internal/adapter/openai/responses_stream_runtime_core.go @@ -51,6 +51,7 @@ type responsesStreamRuntime struct { messagePartAdded bool sequence int failed bool + promptTokens int outputTokens int persistResponse func(obj map[string]any) @@ -152,9 +153,19 @@ func (s *responsesStreamRuntime) finalize() { if s.outputTokens > 0 { if usage, ok := obj["usage"].(map[string]any); ok { usage["output_tokens"] = s.outputTokens - if input, ok := usage["input_tokens"].(int); ok { - usage["total_tokens"] = input + s.outputTokens + } + } + if s.promptTokens > 0 || s.outputTokens > 0 { + if usage, ok := obj["usage"].(map[string]any); ok { + if s.promptTokens > 0 { + usage["input_tokens"] = s.promptTokens } + if s.outputTokens > 0 { + usage["output_tokens"] = s.outputTokens + } + input, _ := usage["input_tokens"].(int) + output, _ := usage["output_tokens"].(int) + usage["total_tokens"] = input + output } } if s.persistResponse != nil { @@ -185,6 +196,9 @@ func (s *responsesStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Pa if !parsed.Parsed { return streamengine.ParsedDecision{} } + if parsed.PromptTokens > 0 { + s.promptTokens = 
parsed.PromptTokens + } if parsed.OutputTokens > 0 { s.outputTokens = parsed.OutputTokens } diff --git a/internal/adapter/openai/stream_status_test.go b/internal/adapter/openai/stream_status_test.go index 6352141..1601a7c 100644 --- a/internal/adapter/openai/stream_status_test.go +++ b/internal/adapter/openai/stream_status_test.go @@ -238,3 +238,97 @@ func TestChatCompletionsStreamContentFilterStopsNormallyWithoutLeak(t *testing.T t.Fatalf("expected finish_reason=stop for content-filter upstream stop, got %#v", choice["finish_reason"]) } } + +func TestResponsesStreamUsageOverridesFromBatchAccumulatedTokenUsage(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse( + `data: {"p":"response/content","v":"hello"}`, + `data: {"p":"response","o":"BATCH","v":[{"p":"accumulated_token_usage","v":190},{"p":"quasi_status","v":"FINISHED"}]}`, + )}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","input":"hi","stream":true}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 || statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200, got %#v", statuses) + } + frames, done := parseSSEDataFrames(t, rec.Body.String()) + if !done { + t.Fatalf("expected [DONE], body=%s", rec.Body.String()) + } + if len(frames) == 0 { + t.Fatalf("expected at least one json frame, body=%s", rec.Body.String()) + } + last := frames[len(frames)-1] + resp, _ := last["response"].(map[string]any) + if resp == nil { + 
t.Fatalf("expected response payload in final frame, got %#v", last) + } + usage, _ := resp["usage"].(map[string]any) + if usage == nil { + t.Fatalf("expected usage in response payload, got %#v", resp) + } + if got, _ := usage["output_tokens"].(float64); int(got) != 190 { + t.Fatalf("expected output_tokens=190, got %#v", usage["output_tokens"]) + } +} + +func TestResponsesNonStreamUsageOverridesPromptAndOutputTokenUsage(t *testing.T) { + statuses := make([]int, 0, 1) + h := &Handler{ + Store: mockOpenAIConfig{wideInput: true}, + Auth: streamStatusAuthStub{}, + DS: streamStatusDSStub{resp: makeOpenAISSEHTTPResponse( + `data: {"p":"response/content","v":"ok"}`, + `data: {"p":"response","o":"BATCH","v":[{"p":"token_usage","v":{"prompt_tokens":11,"completion_tokens":29}},{"p":"quasi_status","v":"FINISHED"}]}`, + )}, + } + r := chi.NewRouter() + r.Use(captureStatusMiddleware(&statuses)) + RegisterRoutes(r, h) + + reqBody := `{"model":"deepseek-chat","input":"hi","stream":false}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(reqBody)) + req.Header.Set("Authorization", "Bearer direct-token") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if len(statuses) != 1 || statuses[0] != http.StatusOK { + t.Fatalf("expected captured status 200, got %#v", statuses) + } + var out map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v body=%s", err, rec.Body.String()) + } + usage, _ := out["usage"].(map[string]any) + if usage == nil { + t.Fatalf("expected usage object, got %#v", out) + } + if got, _ := usage["input_tokens"].(float64); int(got) != 11 { + t.Fatalf("expected input_tokens=11, got %#v", usage["input_tokens"]) + } + if got, _ := usage["output_tokens"].(float64); int(got) != 29 { + t.Fatalf("expected 
output_tokens=29, got %#v", usage["output_tokens"]) + } + if got, _ := usage["total_tokens"].(float64); int(got) != 40 { + t.Fatalf("expected total_tokens=40, got %#v", usage["total_tokens"]) + } +} diff --git a/tests/node/chat-stream.test.js b/tests/node/chat-stream.test.js index 9681843..d1cc859 100644 --- a/tests/node/chat-stream.test.js +++ b/tests/node/chat-stream.test.js @@ -275,7 +275,7 @@ test('parseChunkForContent keeps error branches distinct from content_filter sta assert.equal(parsed.finished, true); assert.equal(parsed.contentFilter, false); assert.equal(parsed.errorMessage.length > 0, true); - assert.equal(parsed.outputTokens, 0); + assert.equal(parsed.outputTokens, 88); assert.deepEqual(parsed.parts, []); }); @@ -292,6 +292,26 @@ test('parseChunkForContent preserves output tokens on FINISHED lines', () => { assert.deepEqual(parsed.parts, []); }); +test('parseChunkForContent captures output tokens from response BATCH status snapshots', () => { + const parsed = parseChunkForContent( + { + p: 'response', + o: 'BATCH', + v: [ + { p: 'accumulated_token_usage', v: 190 }, + { p: 'quasi_status', v: 'FINISHED' }, + ], + }, + false, + 'text', + ); + assert.equal(parsed.parsed, true); + assert.equal(parsed.finished, false); + assert.equal(parsed.contentFilter, false); + assert.equal(parsed.outputTokens, 190); + assert.deepEqual(parsed.parts, []); +}); + test('parseChunkForContent matches FINISHED case-insensitively on status paths', () => { const parsed = parseChunkForContent( { p: 'response/status', v: ' finished ', accumulated_token_usage: 190 }, From 5bcea3d727c5f70cb290be10e9722edbb62bf86c Mon Sep 17 00:00:00 2001 From: "CJACK." 
Date: Tue, 7 Apr 2026 10:16:00 +0800 Subject: [PATCH 5/7] Propagate upstream token usage across Gemini usage metadata --- API.en.md | 2 ++ API.md | 2 ++ internal/adapter/gemini/handler_generate.go | 11 +++++--- .../adapter/gemini/handler_stream_runtime.go | 6 ++++- internal/adapter/gemini/handler_test.go | 26 +++++++++++++++++++ 5 files changed, 42 insertions(+), 5 deletions(-) diff --git a/API.en.md b/API.en.md index 7276c86..2b83245 100644 --- a/API.en.md +++ b/API.en.md @@ -267,6 +267,7 @@ data: [DONE] - `deepseek-reasoner` / `deepseek-reasoner-search` models emit `delta.reasoning_content` - Text emits `delta.content` - Last chunk includes `finish_reason` and `usage` +- Token counting prefers pass-through from upstream DeepSeek SSE (`accumulated_token_usage` / `token_usage`), and only falls back to local estimation when upstream usage is absent #### Tool Calls @@ -535,6 +536,7 @@ Returns SSE (`text/event-stream`), each chunk as `data: `: - regular text: incremental text chunks - `tools` mode: buffered and emitted as `functionCall` at finalize phase - final chunk: includes `finishReason: "STOP"` and `usageMetadata` +- Token counting prefers pass-through from upstream DeepSeek SSE (`accumulated_token_usage` / `token_usage`), and only falls back to local estimation when upstream usage is absent --- diff --git a/API.md b/API.md index 8552052..1caa984 100644 --- a/API.md +++ b/API.md @@ -267,6 +267,7 @@ data: [DONE] - `deepseek-reasoner` / `deepseek-reasoner-search` 模型输出 `delta.reasoning_content` - 普通文本输出 `delta.content` - 最后一段包含 `finish_reason` 和 `usage` +- token 计数优先透传上游 DeepSeek SSE(如 `accumulated_token_usage` / `token_usage`);仅在上游缺失时回退本地估算 #### Tool Calls @@ -541,6 +542,7 @@ data: {"type":"message_stop"} - 常规文本:持续返回增量文本 chunk - `tools` 场景:会缓冲并在结束时输出 `functionCall` 结构 - 结束 chunk:包含 `finishReason: "STOP"` 与 `usageMetadata` +- token 计数优先透传上游 DeepSeek SSE(如 `accumulated_token_usage` / `token_usage`);仅在上游缺失时回退本地估算 --- diff --git 
a/internal/adapter/gemini/handler_generate.go b/internal/adapter/gemini/handler_generate.go index b03b3ea..56cc0e6 100644 --- a/internal/adapter/gemini/handler_generate.go +++ b/internal/adapter/gemini/handler_generate.go @@ -149,14 +149,15 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht cleanVisibleOutput(result.Thinking, stripReferenceMarkers), cleanVisibleOutput(result.Text, stripReferenceMarkers), toolNames, + result.PromptTokens, result.OutputTokens, )) } //nolint:unused // retained for native Gemini non-stream handling path. -func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, outputTokens int) map[string]any { +func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, promptTokens, outputTokens int) map[string]any { parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) - usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, outputTokens) + usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, promptTokens, outputTokens) return map[string]any{ "candidates": []map[string]any{ { @@ -174,8 +175,10 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final } //nolint:unused // retained for native Gemini non-stream handling path. 
-func buildGeminiUsage(finalPrompt, finalThinking, finalText string, outputTokens int) map[string]any { - promptTokens := util.EstimateTokens(finalPrompt) +func buildGeminiUsage(finalPrompt, finalThinking, finalText string, promptTokens, outputTokens int) map[string]any { + if promptTokens <= 0 { + promptTokens = util.EstimateTokens(finalPrompt) + } reasoningTokens := util.EstimateTokens(finalThinking) completionTokens := util.EstimateTokens(finalText) if outputTokens > 0 { diff --git a/internal/adapter/gemini/handler_stream_runtime.go b/internal/adapter/gemini/handler_stream_runtime.go index e7c9b87..b8d2701 100644 --- a/internal/adapter/gemini/handler_stream_runtime.go +++ b/internal/adapter/gemini/handler_stream_runtime.go @@ -67,6 +67,7 @@ type geminiStreamRuntime struct { thinking strings.Builder text strings.Builder + promptTokens int outputTokens int } @@ -112,6 +113,9 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse if !parsed.Parsed { return streamengine.ParsedDecision{} } + if parsed.PromptTokens > 0 { + s.promptTokens = parsed.PromptTokens + } if parsed.OutputTokens > 0 { s.outputTokens = parsed.OutputTokens } @@ -198,6 +202,6 @@ func (s *geminiStreamRuntime) finalize() { }, }, "modelVersion": s.model, - "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.outputTokens), + "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.promptTokens, s.outputTokens), }) } diff --git a/internal/adapter/gemini/handler_test.go b/internal/adapter/gemini/handler_test.go index b7aea1b..aa3ae46 100644 --- a/internal/adapter/gemini/handler_test.go +++ b/internal/adapter/gemini/handler_test.go @@ -296,6 +296,32 @@ func TestGenerateContentOpenAIProxyErrorUsesGeminiEnvelope(t *testing.T) { } } +func TestBuildGeminiUsageOverridesPromptAndOutputTokensWhenProvided(t *testing.T) { + usage := buildGeminiUsage("prompt", "thinking", "answer", 11, 29) + if got, _ := usage["promptTokenCount"].(int); got 
!= 11 { + t.Fatalf("expected promptTokenCount=11, got %#v", usage["promptTokenCount"]) + } + if got, _ := usage["candidatesTokenCount"].(int); got != 29 { + t.Fatalf("expected candidatesTokenCount=29, got %#v", usage["candidatesTokenCount"]) + } + if got, _ := usage["totalTokenCount"].(int); got != 40 { + t.Fatalf("expected totalTokenCount=40, got %#v", usage["totalTokenCount"]) + } +} + +func TestBuildGeminiUsageFallsBackToEstimateWhenNoUpstreamUsage(t *testing.T) { + usage := buildGeminiUsage("abcdef", "", "ghijkl", 0, 0) + if got, _ := usage["promptTokenCount"].(int); got <= 0 { + t.Fatalf("expected positive promptTokenCount estimate, got %#v", usage["promptTokenCount"]) + } + if got, _ := usage["candidatesTokenCount"].(int); got <= 0 { + t.Fatalf("expected positive candidatesTokenCount estimate, got %#v", usage["candidatesTokenCount"]) + } + if got, _ := usage["totalTokenCount"].(int); got <= 0 { + t.Fatalf("expected positive totalTokenCount estimate, got %#v", usage["totalTokenCount"]) + } +} + func extractGeminiSSEFrames(t *testing.T, body string) []map[string]any { t.Helper() scanner := bufio.NewScanner(strings.NewReader(body)) From 668b9c26bdbf01772d19f81bd160cc6df2019b26 Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Tue, 7 Apr 2026 10:16:23 +0800 Subject: [PATCH 6/7] Unify token usage pass-through on OpenAI translate pipeline --- API.en.md | 2 + API.md | 2 + internal/adapter/gemini/handler_generate.go | 11 +-- .../adapter/gemini/handler_stream_runtime.go | 6 +- internal/adapter/gemini/handler_test.go | 26 ----- internal/translatorcliproxy/bridge_test.go | 20 ++++ internal/translatorcliproxy/stream_writer.go | 97 +++++++++++++++++++ .../translatorcliproxy/stream_writer_test.go | 8 ++ 8 files changed, 134 insertions(+), 38 deletions(-) diff --git a/API.en.md b/API.en.md index 2b83245..0238a42 100644 --- a/API.en.md +++ b/API.en.md @@ -384,6 +384,7 @@ Business auth required. Returns OpenAI-compatible embeddings shape. 
## Claude-Compatible API Besides `/anthropic/v1/*`, DS2API also supports shortcut paths: `/v1/messages`, `/messages`, `/v1/messages/count_tokens`, `/messages/count_tokens`. +Implementation-wise this path is unified on the OpenAI Chat Completions parse-and-translate pipeline to avoid maintaining divergent parsing chains. ### `GET /anthropic/v1/models` @@ -518,6 +519,7 @@ Supported paths: - `/v1/models/{model}:streamGenerateContent` (compat path) Authentication is the same as other business routes (`Authorization: Bearer ` or `x-api-key`). +Implementation-wise this path is unified on the OpenAI Chat Completions parse-and-translate pipeline to avoid maintaining divergent parsing chains. ### `POST /v1beta/models/{model}:generateContent` diff --git a/API.md b/API.md index 1caa984..d2eb1f0 100644 --- a/API.md +++ b/API.md @@ -390,6 +390,7 @@ data: [DONE] ## Claude 兼容接口 除标准路径 `/anthropic/v1/*` 外,还支持快捷路径 `/v1/messages`、`/messages`、`/v1/messages/count_tokens`、`/messages/count_tokens`。 +实现上统一走 OpenAI Chat Completions 解析与回译链路,避免多套解析逻辑分叉维护。 ### `GET /anthropic/v1/models` @@ -524,6 +525,7 @@ data: {"type":"message_stop"} - `/v1/models/{model}:streamGenerateContent`(兼容路径) 鉴权方式同业务接口(`Authorization: Bearer ` 或 `x-api-key`)。 +实现上统一走 OpenAI Chat Completions 解析与回译链路,避免多套解析逻辑分叉维护。 ### `POST /v1beta/models/{model}:generateContent` diff --git a/internal/adapter/gemini/handler_generate.go b/internal/adapter/gemini/handler_generate.go index 56cc0e6..b03b3ea 100644 --- a/internal/adapter/gemini/handler_generate.go +++ b/internal/adapter/gemini/handler_generate.go @@ -149,15 +149,14 @@ func (h *Handler) handleNonStreamGenerateContent(w http.ResponseWriter, resp *ht cleanVisibleOutput(result.Thinking, stripReferenceMarkers), cleanVisibleOutput(result.Text, stripReferenceMarkers), toolNames, - result.PromptTokens, result.OutputTokens, )) } //nolint:unused // retained for native Gemini non-stream handling path. 
-func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, promptTokens, outputTokens int) map[string]any { +func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, finalText string, toolNames []string, outputTokens int) map[string]any { parts := buildGeminiPartsFromFinal(finalText, finalThinking, toolNames) - usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, promptTokens, outputTokens) + usage := buildGeminiUsage(finalPrompt, finalThinking, finalText, outputTokens) return map[string]any{ "candidates": []map[string]any{ { @@ -175,10 +174,8 @@ func buildGeminiGenerateContentResponse(model, finalPrompt, finalThinking, final } //nolint:unused // retained for native Gemini non-stream handling path. -func buildGeminiUsage(finalPrompt, finalThinking, finalText string, promptTokens, outputTokens int) map[string]any { - if promptTokens <= 0 { - promptTokens = util.EstimateTokens(finalPrompt) - } +func buildGeminiUsage(finalPrompt, finalThinking, finalText string, outputTokens int) map[string]any { + promptTokens := util.EstimateTokens(finalPrompt) reasoningTokens := util.EstimateTokens(finalThinking) completionTokens := util.EstimateTokens(finalText) if outputTokens > 0 { diff --git a/internal/adapter/gemini/handler_stream_runtime.go b/internal/adapter/gemini/handler_stream_runtime.go index b8d2701..e7c9b87 100644 --- a/internal/adapter/gemini/handler_stream_runtime.go +++ b/internal/adapter/gemini/handler_stream_runtime.go @@ -67,7 +67,6 @@ type geminiStreamRuntime struct { thinking strings.Builder text strings.Builder - promptTokens int outputTokens int } @@ -113,9 +112,6 @@ func (s *geminiStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse if !parsed.Parsed { return streamengine.ParsedDecision{} } - if parsed.PromptTokens > 0 { - s.promptTokens = parsed.PromptTokens - } if parsed.OutputTokens > 0 { s.outputTokens = parsed.OutputTokens } @@ -202,6 +198,6 @@ func (s 
*geminiStreamRuntime) finalize() { }, }, "modelVersion": s.model, - "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.promptTokens, s.outputTokens), + "usageMetadata": buildGeminiUsage(s.finalPrompt, finalThinking, finalText, s.outputTokens), }) } diff --git a/internal/adapter/gemini/handler_test.go b/internal/adapter/gemini/handler_test.go index aa3ae46..b7aea1b 100644 --- a/internal/adapter/gemini/handler_test.go +++ b/internal/adapter/gemini/handler_test.go @@ -296,32 +296,6 @@ func TestGenerateContentOpenAIProxyErrorUsesGeminiEnvelope(t *testing.T) { } } -func TestBuildGeminiUsageOverridesPromptAndOutputTokensWhenProvided(t *testing.T) { - usage := buildGeminiUsage("prompt", "thinking", "answer", 11, 29) - if got, _ := usage["promptTokenCount"].(int); got != 11 { - t.Fatalf("expected promptTokenCount=11, got %#v", usage["promptTokenCount"]) - } - if got, _ := usage["candidatesTokenCount"].(int); got != 29 { - t.Fatalf("expected candidatesTokenCount=29, got %#v", usage["candidatesTokenCount"]) - } - if got, _ := usage["totalTokenCount"].(int); got != 40 { - t.Fatalf("expected totalTokenCount=40, got %#v", usage["totalTokenCount"]) - } -} - -func TestBuildGeminiUsageFallsBackToEstimateWhenNoUpstreamUsage(t *testing.T) { - usage := buildGeminiUsage("abcdef", "", "ghijkl", 0, 0) - if got, _ := usage["promptTokenCount"].(int); got <= 0 { - t.Fatalf("expected positive promptTokenCount estimate, got %#v", usage["promptTokenCount"]) - } - if got, _ := usage["candidatesTokenCount"].(int); got <= 0 { - t.Fatalf("expected positive candidatesTokenCount estimate, got %#v", usage["candidatesTokenCount"]) - } - if got, _ := usage["totalTokenCount"].(int); got <= 0 { - t.Fatalf("expected positive totalTokenCount estimate, got %#v", usage["totalTokenCount"]) - } -} - func extractGeminiSSEFrames(t *testing.T, body string) []map[string]any { t.Helper() scanner := bufio.NewScanner(strings.NewReader(body)) diff --git 
a/internal/translatorcliproxy/bridge_test.go b/internal/translatorcliproxy/bridge_test.go index 5f0979f..cdd9cf7 100644 --- a/internal/translatorcliproxy/bridge_test.go +++ b/internal/translatorcliproxy/bridge_test.go @@ -26,6 +26,26 @@ func TestFromOpenAINonStreamClaude(t *testing.T) { } } +func TestFromOpenAINonStreamClaudePreservesUsageFromOpenAI(t *testing.T) { + original := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`) + translatedReq := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`) + openaibody := []byte(`{"id":"chatcmpl_1","object":"chat.completion","created":1,"model":"claude-sonnet-4-5","choices":[{"index":0,"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":29,"total_tokens":40}}`) + got := string(FromOpenAINonStream(sdktranslator.FormatClaude, "claude-sonnet-4-5", original, translatedReq, openaibody)) + if !strings.Contains(got, `"input_tokens":11`) || !strings.Contains(got, `"output_tokens":29`) { + t.Fatalf("expected claude usage to preserve prompt/completion tokens, got: %s", got) + } +} + +func TestFromOpenAINonStreamGeminiPreservesUsageFromOpenAI(t *testing.T) { + original := []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`) + translatedReq := []byte(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}],"stream":false}`) + openaibody := []byte(`{"id":"chatcmpl_1","object":"chat.completion","created":1,"model":"gemini-2.5-pro","choices":[{"index":0,"message":{"role":"assistant","content":"hello"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":29,"total_tokens":40}}`) + got := string(FromOpenAINonStream(sdktranslator.FormatGemini, "gemini-2.5-pro", original, translatedReq, openaibody)) + if !strings.Contains(got, `"promptTokenCount":11`) || !strings.Contains(got, `"candidatesTokenCount":29`) || !strings.Contains(got, 
`"totalTokenCount":40`) { + t.Fatalf("expected gemini usageMetadata to preserve prompt/completion tokens, got: %s", got) + } +} + func TestParseFormatAliases(t *testing.T) { cases := map[string]sdktranslator.Format{ "responses": sdktranslator.FormatOpenAIResponse, diff --git a/internal/translatorcliproxy/stream_writer.go b/internal/translatorcliproxy/stream_writer.go index 07c4bcb..b1b8747 100644 --- a/internal/translatorcliproxy/stream_writer.go +++ b/internal/translatorcliproxy/stream_writer.go @@ -3,7 +3,9 @@ package translatorcliproxy import ( "bytes" "context" + "encoding/json" "net/http" + "strings" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" ) @@ -77,7 +79,13 @@ func (w *OpenAIStreamTranslatorWriter) Write(p []byte) (int, error) { if !bytes.HasPrefix(trimmed, []byte("data:")) { continue } + usage, hasUsage := extractOpenAIUsage(trimmed) chunks := sdktranslator.TranslateStream(context.Background(), sdktranslator.FormatOpenAI, w.target, w.model, w.originalReq, w.translatedReq, trimmed, &w.param) + if hasUsage { + for i := range chunks { + chunks[i] = injectStreamUsageMetadata(chunks[i], w.target, usage) + } + } for i := range chunks { if len(chunks[i]) == 0 { continue @@ -118,3 +126,92 @@ func (w *OpenAIStreamTranslatorWriter) readOneLine() ([]byte, bool) { w.lineBuf.Next(idx + 1) return line, true } + +type openAIUsage struct { + PromptTokens int + CompletionTokens int + TotalTokens int +} + +func extractOpenAIUsage(line []byte) (openAIUsage, bool) { + raw := strings.TrimSpace(strings.TrimPrefix(string(line), "data:")) + if raw == "" || raw == "[DONE]" { + return openAIUsage{}, false + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil { + return openAIUsage{}, false + } + usageObj, _ := payload["usage"].(map[string]any) + if usageObj == nil { + return openAIUsage{}, false + } + p := toInt(usageObj["prompt_tokens"]) + c := toInt(usageObj["completion_tokens"]) + t := 
toInt(usageObj["total_tokens"]) + if p <= 0 && c <= 0 && t <= 0 { + return openAIUsage{}, false + } + if t <= 0 { + t = p + c + } + return openAIUsage{PromptTokens: p, CompletionTokens: c, TotalTokens: t}, true +} + +func injectStreamUsageMetadata(chunk []byte, target sdktranslator.Format, usage openAIUsage) []byte { + if target != sdktranslator.FormatGemini { + return chunk + } + text := strings.TrimSpace(string(chunk)) + if text == "" { + return chunk + } + var ( + hasDataPrefix bool + jsonText = text + ) + if strings.HasPrefix(jsonText, "data:") { + hasDataPrefix = true + jsonText = strings.TrimSpace(strings.TrimPrefix(jsonText, "data:")) + } + if jsonText == "" || jsonText == "[DONE]" { + return chunk + } + obj := map[string]any{} + if err := json.Unmarshal([]byte(jsonText), &obj); err != nil { + return chunk + } + if _, ok := obj["candidates"]; !ok { + return chunk + } + obj["usageMetadata"] = map[string]any{ + "promptTokenCount": usage.PromptTokens, + "candidatesTokenCount": usage.CompletionTokens, + "totalTokenCount": usage.TotalTokens, + } + b, err := json.Marshal(obj) + if err != nil { + return chunk + } + if hasDataPrefix { + return []byte("data: " + string(b)) + } + return b +} + +func toInt(v any) int { + switch x := v.(type) { + case int: + return x + case int32: + return int(x) + case int64: + return int(x) + case float64: + return int(x) + case float32: + return int(x) + default: + return 0 + } +} diff --git a/internal/translatorcliproxy/stream_writer_test.go b/internal/translatorcliproxy/stream_writer_test.go index 979d36e..77d2936 100644 --- a/internal/translatorcliproxy/stream_writer_test.go +++ b/internal/translatorcliproxy/stream_writer_test.go @@ -18,12 +18,16 @@ func TestOpenAIStreamTranslatorWriterClaude(t *testing.T) { w.WriteHeader(200) _, _ = w.Write([]byte("data: 
{\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\"},\"finish_reason\":null}]}\n\n")) _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\n")) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"claude-sonnet-4-5\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":29,\"total_tokens\":40}}\n\n")) _, _ = w.Write([]byte("data: [DONE]\n\n")) body := rec.Body.String() if !strings.Contains(body, "event: message_start") { t.Fatalf("expected claude message_start event, got: %s", body) } + if !strings.Contains(body, `"output_tokens":29`) { + t.Fatalf("expected claude stream usage to preserve output tokens, got: %s", body) + } } func TestOpenAIStreamTranslatorWriterGemini(t *testing.T) { @@ -35,12 +39,16 @@ func TestOpenAIStreamTranslatorWriterGemini(t *testing.T) { w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(200) _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gemini-2.5-pro\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"},\"finish_reason\":null}]}\n\n")) + _, _ = w.Write([]byte("data: {\"id\":\"chatcmpl_1\",\"object\":\"chat.completion.chunk\",\"created\":1,\"model\":\"gemini-2.5-pro\",\"choices\":[{\"index\":0,\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":29,\"total_tokens\":40}}\n\n")) _, _ = w.Write([]byte("data: [DONE]\n\n")) body := rec.Body.String() if !strings.Contains(body, "candidates") { t.Fatalf("expected gemini stream payload, got: %s", body) } + if !strings.Contains(body, `"promptTokenCount":11`) || 
!strings.Contains(body, `"candidatesTokenCount":29`) { + t.Fatalf("expected gemini stream usageMetadata to preserve usage, got: %s", body) + } } func TestOpenAIStreamTranslatorWriterPreservesKeepAliveComment(t *testing.T) { From 86ecbc89bd31a80e97ee516ad40178dd4440634d Mon Sep 17 00:00:00 2001 From: "CJACK." Date: Tue, 7 Apr 2026 10:59:27 +0800 Subject: [PATCH 7/7] Preserve SSE frame delimiters when injecting Gemini usage --- internal/translatorcliproxy/stream_writer.go | 12 +++++++++++- internal/translatorcliproxy/stream_writer_test.go | 12 ++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/internal/translatorcliproxy/stream_writer.go b/internal/translatorcliproxy/stream_writer.go index b1b8747..e80ce69 100644 --- a/internal/translatorcliproxy/stream_writer.go +++ b/internal/translatorcliproxy/stream_writer.go @@ -162,6 +162,13 @@ func injectStreamUsageMetadata(chunk []byte, target sdktranslator.Format, usage if target != sdktranslator.FormatGemini { return chunk } + suffix := "" + switch { + case bytes.HasSuffix(chunk, []byte("\n\n")): + suffix = "\n\n" + case bytes.HasSuffix(chunk, []byte("\n")): + suffix = "\n" + } text := strings.TrimSpace(string(chunk)) if text == "" { return chunk @@ -194,7 +201,10 @@ func injectStreamUsageMetadata(chunk []byte, target sdktranslator.Format, usage return chunk } if hasDataPrefix { - return []byte("data: " + string(b)) + return []byte("data: " + string(b) + suffix) + } + if suffix != "" { + return append(b, []byte(suffix)...) 
} return b } diff --git a/internal/translatorcliproxy/stream_writer_test.go b/internal/translatorcliproxy/stream_writer_test.go index 77d2936..94d70b8 100644 --- a/internal/translatorcliproxy/stream_writer_test.go +++ b/internal/translatorcliproxy/stream_writer_test.go @@ -63,3 +63,15 @@ func TestOpenAIStreamTranslatorWriterPreservesKeepAliveComment(t *testing.T) { t.Fatalf("expected keep-alive comment passthrough, got %q", body) } } + +func TestInjectStreamUsageMetadataPreservesSSEFrameTerminator(t *testing.T) { + chunk := []byte("data: {\"candidates\":[{\"index\":0}],\"model\":\"gemini-2.5-pro\"}\n\n") + usage := openAIUsage{PromptTokens: 11, CompletionTokens: 29, TotalTokens: 40} + got := injectStreamUsageMetadata(chunk, sdktranslator.FormatGemini, usage) + if !strings.HasSuffix(string(got), "\n\n") { + t.Fatalf("expected injected chunk to preserve \\n\\n frame terminator, got %q", string(got)) + } + if !strings.Contains(string(got), `"usageMetadata"`) { + t.Fatalf("expected usageMetadata injected, got %q", string(got)) + } +}