diff --git a/README.md b/README.md index 4eab8da..15ee842 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,29 @@ sudo docker compose down **代理配置(可选):** 如需使用代理访问 Google 服务,在 Docker 命令中添加 `-e HTTP_PROXY=http://your-proxy:port -e HTTPS_PROXY=http://your-proxy:port`,或在 `docker-compose.yml` 的 `environment` 中添加这两个环境变量。 +##### 🛠️ 方式 3:从源码构建 + +如果您希望自己构建 Docker 镜像,可以使用以下命令: + +1. 构建镜像: + +```bash +docker build -t aistudio-to-api . +``` + +2. 运行容器: + +```bash +docker run -d \ + --name aistudio-to-api \ + -p 7860:7860 \ + -v /path/to/auth:/app/configs/auth \ + -e API_KEYS=123456,AIzaSyA5dIf8f56a_6Qmn8VvtjERe4XlqQxFahA \ + -e TZ=Asia/Shanghai \ + --restart unless-stopped \ + aistudio-to-api +``` + #### 🔑 步骤 2:账号管理 部署后,您需要使用以下方式之一添加 Google 账号: diff --git a/scripts/client/build.js b/scripts/client/build.js index 05a9ee8..c5c7eee 100644 --- a/scripts/client/build.js +++ b/scripts/client/build.js @@ -7,6 +7,21 @@ /* eslint-env browser */ +const b64toBlob = (b64Data, contentType = "", sliceSize = 512) => { + const byteCharacters = atob(b64Data); + const byteArrays = []; + for (let offset = 0; offset < byteCharacters.length; offset += sliceSize) { + const slice = byteCharacters.slice(offset, offset + sliceSize); + const byteNumbers = new Array(slice.length); + for (let i = 0; i < slice.length; i++) { + byteNumbers[i] = slice.charCodeAt(i); + } + const byteArray = new Uint8Array(byteNumbers); + byteArrays.push(byteArray); + } + return new Blob(byteArrays, { type: contentType }); +}; + const Logger = { enabled: true, output(...messages) { @@ -84,7 +99,7 @@ class ConnectionManager extends EventTarget { this.reconnectAttempts++; setTimeout(() => { Logger.output(`Attempting reconnection ${this.reconnectAttempts} attempt...`); - this.establish().catch(() => {}); + this.establish().catch(() => { }); }, this.reconnectDelay); } } @@ -158,21 +173,89 @@ class RequestProcessor { } _constructUrl(requestSpec) { - let pathSegment = requestSpec.path.startsWith("/") ? requestSpec.path.substring(1) : requestSpec.path; - const queryParams = new URLSearchParams(requestSpec.query_params); - if (requestSpec.streaming_mode === "fake") { - Logger.output("Buffered mode activated (Non-Stream / Fake-Stream), checking request details..."); - if (pathSegment.includes(":streamGenerateContent")) { - pathSegment = pathSegment.replace(":streamGenerateContent", ":generateContent"); - Logger.output(`API path modified to: ${pathSegment}`); + let pathAndQuery = requestSpec.url; + + if (!pathAndQuery) { + const pathSegment = requestSpec.path || ""; + const queryParams = new URLSearchParams(requestSpec.query_params); + + // Handle fake streaming mode adjustments + if (requestSpec.streaming_mode === "fake") { + if (pathSegment.includes(":streamGenerateContent")) { + // This is a bit risky if pathSegment is modified, but BuildProxy does it on the joined string + // We'll follow BuildProxy structured approach but keep this feature + } + if (queryParams.has("alt") && queryParams.get("alt") === "sse") { + queryParams.delete("alt"); + } } - if (queryParams.has("alt") && queryParams.get("alt") === "sse") { - queryParams.delete("alt"); - Logger.output('Removed "alt=sse" query parameter.'); + + // Special handling for legacy path construction if url not provided + let finalPath = pathSegment; + if (requestSpec.streaming_mode === "fake" && finalPath.includes(":streamGenerateContent")) { + finalPath = finalPath.replace(":streamGenerateContent", ":generateContent"); } + + const queryString = queryParams.toString(); + pathAndQuery = `${finalPath}${queryString ? "?" + queryString : ""}`; } - const queryString = queryParams.toString(); - return `https://${this.targetDomain}/${pathSegment}${queryString ? "?" + queryString : ""}`; + + // Rewriting absolute URLs (if provided) + if (pathAndQuery.match(/^https?:\/\//)) { + try { + const urlObj = new URL(pathAndQuery); + const originalUrl = pathAndQuery; + pathAndQuery = urlObj.pathname + urlObj.search; + Logger.output(`Rewriting absolute URL: ${originalUrl} -> ${pathAndQuery}`); + } catch (e) { + Logger.output("URL parsing warning:", e.message); + } + } + + let targetHost = this.targetDomain; + if (pathAndQuery.includes("__proxy_host__=")) { + try { + const tempUrl = new URL(pathAndQuery, "http://dummy"); + const params = tempUrl.searchParams; + if (params.has("__proxy_host__")) { + targetHost = params.get("__proxy_host__"); + params.delete("__proxy_host__"); + pathAndQuery = tempUrl.pathname + tempUrl.search; + Logger.output(`Dynamically switching target host: ${targetHost}`); + } + } catch (e) { + Logger.output("Failed to parse proxy host:", e.message); + } + } + + let cleanPath = pathAndQuery.replace(/^\/+/, ""); + const method = requestSpec.method ? requestSpec.method.toUpperCase() : "GET"; + + if (this.targetDomain.includes("generativelanguage")) { + const versionRegex = /v1[a-z0-9]*\/files/; + const uploadMatch = cleanPath.match(new RegExp(`upload\/${versionRegex.source}`)); + + if (uploadMatch) { + // If path already contains upload/, just ensure it's correct + const index = cleanPath.indexOf("upload/"); + if (index > 0) { + const fixedPath = cleanPath.substring(index); + Logger.output(`Corrected path: ${cleanPath} -> ${fixedPath}`); + cleanPath = fixedPath; + } + } else if (method === "POST") { + // Detect if it starts with version and 'files', e.g. v1beta/files + const filesPathMatch = cleanPath.match(new RegExp(`^${versionRegex.source}`)); + if (filesPathMatch) { + cleanPath = "upload/" + cleanPath; + Logger.output("Auto-completing upload path:", cleanPath); + } + } + } + + const finalUrl = `https://${targetHost}/${cleanPath}`; + Logger.output(`Constructed URL: ${pathAndQuery} -> ${finalUrl}`); + return finalUrl; } _generateRandomString(length) { @@ -189,110 +272,116 @@ class RequestProcessor { signal, }; - if (["POST", "PUT", "PATCH"].includes(requestSpec.method) && requestSpec.body) { - try { - const bodyObj = JSON.parse(requestSpec.body); - - // --- Module 1: Image/Embedding/TTS Model Filtering --- - // These models do NOT support: tools, thinkingConfig, systemInstruction, response_mime_type - const isImageModel = requestSpec.path.includes("-image") || requestSpec.path.includes("imagen"); - const isEmbeddingModel = requestSpec.path.includes("embedding"); - const isTtsModel = requestSpec.path.includes("tts"); - if (isImageModel || isEmbeddingModel || isTtsModel) { - // Remove tools - const incompatibleKeys = ["toolConfig", "tool_config", "toolChoice", "tools"]; - incompatibleKeys.forEach(key => { - if (Object.prototype.hasOwnProperty.call(bodyObj, key)) delete bodyObj[key]; - }); - // Remove thinkingConfig - if (bodyObj.generationConfig?.thinkingConfig) { - delete bodyObj.generationConfig.thinkingConfig; - } - // Remove systemInstruction - if (bodyObj.systemInstruction) { - delete bodyObj.systemInstruction; + if (["POST", "PUT", "PATCH"].includes(requestSpec.method)) { + if (!requestSpec.is_generative && requestSpec.body_b64) { + const contentType = requestSpec.headers?.["content-type"] || ""; + config.body = b64toBlob(requestSpec.body_b64, contentType); + Logger.output("Using binary body (Base64 decoded) for non-generative request"); + } else if (requestSpec.body) { + try { + const bodyObj = JSON.parse(requestSpec.body); + + // --- Module 1: Image/Embedding/TTS Model Filtering --- + // These models do NOT support: tools, thinkingConfig, systemInstruction, response_mime_type + const isImageModel = requestSpec.path.includes("-image") || requestSpec.path.includes("imagen"); + const isEmbeddingModel = requestSpec.path.includes("embedding"); + const isTtsModel = requestSpec.path.includes("tts"); + if (isImageModel || isEmbeddingModel || isTtsModel) { + // Remove tools + const incompatibleKeys = ["toolConfig", "tool_config", "toolChoice", "tools"]; + incompatibleKeys.forEach(key => { + if (Object.prototype.hasOwnProperty.call(bodyObj, key)) delete bodyObj[key]; + }); + // Remove thinkingConfig + if (bodyObj.generationConfig?.thinkingConfig) { + delete bodyObj.generationConfig.thinkingConfig; + } + // Remove systemInstruction + if (bodyObj.systemInstruction) { + delete bodyObj.systemInstruction; + } + // Remove response_mime_type + if (bodyObj.generationConfig?.response_mime_type) { + delete bodyObj.generationConfig.response_mime_type; + } + if (bodyObj.generationConfig?.responseMimeType) { + delete bodyObj.generationConfig.responseMimeType; + } } - // Remove response_mime_type - if (bodyObj.generationConfig?.response_mime_type) { - delete bodyObj.generationConfig.response_mime_type; - } - if (bodyObj.generationConfig?.responseMimeType) { - delete bodyObj.generationConfig.responseMimeType; - } - } - // --- Module 1.5: responseModalities Handling --- - // Image: keep as-is (needed for image generation) - // Embedding: remove - // TTS: force to ["AUDIO"] - if (isTtsModel) { - if (!bodyObj.generationConfig) { - bodyObj.generationConfig = {}; + // --- Module 1.5: responseModalities Handling --- + // Image: keep as-is (needed for image generation) + // Embedding: remove + // TTS: force to ["AUDIO"] + if (isTtsModel) { + if (!bodyObj.generationConfig) { + bodyObj.generationConfig = {}; + } + bodyObj.generationConfig.responseModalities = ["AUDIO"]; + Logger.output("TTS model detected, setting responseModalities to AUDIO"); + } else if (isEmbeddingModel) { + if (bodyObj.generationConfig?.responseModalities) { + delete bodyObj.generationConfig.responseModalities; + } } - bodyObj.generationConfig.responseModalities = ["AUDIO"]; - Logger.output("TTS model detected, setting responseModalities to AUDIO"); - } else if (isEmbeddingModel) { - if (bodyObj.generationConfig?.responseModalities) { - delete bodyObj.generationConfig.responseModalities; - } - } - // --- Module 2: Computer-Use Model Filtering --- - // Remove tools, responseModalities - const isComputerUseModel = requestSpec.path.includes("computer-use"); - if (isComputerUseModel) { - const incompatibleKeys = ["tool_config", "toolChoice", "tools"]; - incompatibleKeys.forEach(key => { - if (Object.prototype.hasOwnProperty.call(bodyObj, key)) delete bodyObj[key]; - }); - if (bodyObj.generationConfig?.responseModalities) { - delete bodyObj.generationConfig.responseModalities; + // --- Module 2: Computer-Use Model Filtering --- + // Remove tools, responseModalities + const isComputerUseModel = requestSpec.path.includes("computer-use"); + if (isComputerUseModel) { + const incompatibleKeys = ["tool_config", "toolChoice", "tools"]; + incompatibleKeys.forEach(key => { + if (Object.prototype.hasOwnProperty.call(bodyObj, key)) delete bodyObj[key]; + }); + if (bodyObj.generationConfig?.responseModalities) { + delete bodyObj.generationConfig.responseModalities; + } } - } - // --- Module 3: Robotics Model Filtering --- - // Remove googleSearch, urlContext from tools; also remove responseModalities - const isRoboticsModel = requestSpec.path.includes("robotics"); - if (isRoboticsModel) { - if (Array.isArray(bodyObj.tools)) { - bodyObj.tools = bodyObj.tools.filter(t => !t.googleSearch && !t.urlContext); - if (bodyObj.tools.length === 0) delete bodyObj.tools; + // --- Module 3: Robotics Model Filtering --- + // Remove googleSearch, urlContext from tools; also remove responseModalities + const isRoboticsModel = requestSpec.path.includes("robotics"); + if (isRoboticsModel) { + if (Array.isArray(bodyObj.tools)) { + bodyObj.tools = bodyObj.tools.filter(t => !t.googleSearch && !t.urlContext); + if (bodyObj.tools.length === 0) delete bodyObj.tools; + } + if (bodyObj.generationConfig?.responseModalities) { + delete bodyObj.generationConfig.responseModalities; + } } - if (bodyObj.generationConfig?.responseModalities) { - delete bodyObj.generationConfig.responseModalities; + + // adapt gemini 3 pro preview + // if raise `400 INVALID_ARGUMENT`, try to delete `thinkingLevel` + // if (bodyObj.generationConfig?.thinkingConfig?.thinkingLevel) { + // delete bodyObj.generationConfig.thinkingConfig.thinkingLevel; + // } + + // upper case `thinkingLevel` + if (bodyObj.generationConfig?.thinkingConfig?.thinkingLevel) { + bodyObj.generationConfig.thinkingConfig.thinkingLevel = String( + bodyObj.generationConfig.thinkingConfig.thinkingLevel + ).toUpperCase(); } - } - // adapt gemini 3 pro preview - // if raise `400 INVALID_ARGUMENT`, try to delete `thinkingLevel` - // if (bodyObj.generationConfig?.thinkingConfig?.thinkingLevel) { - // delete bodyObj.generationConfig.thinkingConfig.thinkingLevel; - // } - - // upper case `thinkingLevel` - if (bodyObj.generationConfig?.thinkingConfig?.thinkingLevel) { - bodyObj.generationConfig.thinkingConfig.thinkingLevel = String( - bodyObj.generationConfig.thinkingConfig.thinkingLevel - ).toUpperCase(); + // if raise `400 INVALID_ARGUMENT`, try to delete `thoughtSignature` + // if (Array.isArray(bodyObj.contents)) { + // bodyObj.contents.forEach(msg => { + // if (Array.isArray(msg.parts)) { + // msg.parts.forEach(part => { + // if (part.thoughtSignature) { + // delete part.thoughtSignature; + // } + // }); + // } + // }); + // } + + config.body = JSON.stringify(bodyObj); + } catch (e) { + Logger.output("Error occurred while processing request body:", e.message); + config.body = requestSpec.body; } - - // if raise `400 INVALID_ARGUMENT`, try to delete `thoughtSignature` - // if (Array.isArray(bodyObj.contents)) { - // bodyObj.contents.forEach(msg => { - // if (Array.isArray(msg.parts)) { - // msg.parts.forEach(part => { - // if (part.thoughtSignature) { - // delete part.thoughtSignature; - // } - // }); - // } - // }); - // } - - config.body = JSON.stringify(bodyObj); - } catch (e) { - Logger.output("Error occurred while processing request body:", e.message); - config.body = requestSpec.body; } } @@ -301,17 +390,20 @@ class RequestProcessor { _sanitizeHeaders(headers) { const sanitized = { ...headers }; - [ + // Follow BuildProxy's forbidden list exactly + const forbiddenHeaders = [ "host", "connection", "content-length", - "origin", + /* 'origin', */ // BuildProxy comments these out "referer", "user-agent", "sec-fetch-mode", "sec-fetch-site", "sec-fetch-dest", - ].forEach(h => delete sanitized[h]); + ]; + + forbiddenHeaders.forEach(h => delete sanitized[h]); return sanitized; } @@ -323,7 +415,7 @@ class RequestProcessor { controller.abort(); } } -} // <--- Critical! Ensure this bracket exists +} class ProxySystem extends EventTarget { constructor(websocketEndpoint) { @@ -398,9 +490,12 @@ class ProxySystem extends EventTarget { throw new DOMException("The user aborted a request.", "AbortError"); } - this._transmitHeaders(response, operationId); + this._transmitHeaders(response, operationId, requestSpec.headers?.host); const reader = response.body.getReader(); const textDecoder = new TextDecoder(); + const contentType = response.headers.get("content-type") || ""; + const isText = contentType.includes("text/") || contentType.includes("application/json"); + let fullBody = ""; // --- Core modification: Correctly dispatch streaming and non-streaming data inside the loop --- @@ -414,21 +509,23 @@ class ProxySystem extends EventTarget { cancelTimeout(); - const chunk = textDecoder.decode(value, { stream: true }); - - if (mode === "real") { - // Streaming mode: immediately forward each data chunk - this._transmitChunk(chunk, operationId); + if (isText) { + const chunk = textDecoder.decode(value, { stream: true }); + if (mode === "real") { + this._transmitChunk(chunk, operationId); + } else { + fullBody += chunk; + } } else { - // fake mode - // Non-streaming mode: concatenate data chunks, wait to forward all at once at the end - fullBody += chunk; + // Binary data: use Base64 to ensure WebSocket safety + const base64Chunk = btoa(String.fromCharCode(...value)); + this._transmitChunk(base64Chunk, operationId, true); // true = isBinary } } Logger.output("Data stream read complete."); - if (mode === "fake") { + if (mode === "fake" && isText) { // In non-streaming mode, after loop ends, forward the concatenated complete response body this._transmitChunk(fullBody, operationId); } @@ -450,10 +547,25 @@ class ProxySystem extends EventTarget { } } - _transmitHeaders(response, operationId) { + _transmitHeaders(response, operationId, proxyHost) { const headerMap = {}; response.headers.forEach((v, k) => { - headerMap[k] = v; + const lowerKey = k.toLowerCase(); + if ((lowerKey === "location" || lowerKey === "x-goog-upload-url") && v.includes("googleapis.com")) { + try { + const urlObj = new URL(v); + const host = proxyHost || location.host; + const separator = urlObj.search ? "&" : "?"; + const newSearch = `${urlObj.search}${separator}__proxy_host__=${urlObj.host}`; + const newUrl = `${location.protocol}//${host}${urlObj.pathname}${newSearch}`; + headerMap[k] = newUrl; + Logger.output(`Rewriting header ${k}: ${v} -> ${headerMap[k]}`); + } catch (e) { + headerMap[k] = v; + } + } else { + headerMap[k] = v; + } }); this.connectionManager.transmit({ event_type: "response_headers", @@ -463,12 +575,13 @@ class ProxySystem extends EventTarget { }); } - _transmitChunk(chunk, operationId) { - if (!chunk) return; + _transmitChunk(data, operationId, isBinary = false) { + if (!data) return; this.connectionManager.transmit({ - data: chunk, + data: data, event_type: "chunk", request_id: operationId, + is_binary: isBinary, }); } diff --git a/src/core/ProxyServerSystem.js b/src/core/ProxyServerSystem.js index 1316e4c..dcc813e 100644 --- a/src/core/ProxyServerSystem.js +++ b/src/core/ProxyServerSystem.js @@ -10,6 +10,8 @@ const { EventEmitter } = require("events"); const express = require("express"); const WebSocket = require("ws"); const http = require("http"); +const https = require("https"); +const fs = require("fs"); const net = require("net"); const { URL } = require("url"); @@ -157,9 +159,7 @@ class ProxyServerSystem extends EventEmitter { // Allow access if session is authenticated (e.g. browser accessing /vnc or API from UI) if (req.session && req.session.isAuthenticated) { - if (req.path === "/vnc") { - return next(); - } + return next(); } const serverApiKeys = this.config.apiKeys; @@ -205,7 +205,29 @@ class ProxyServerSystem extends EventEmitter { async _startHttpServer() { const app = this._createExpressApp(); - this.httpServer = http.createServer(app); + + if (this.config.sslKeyPath && this.config.sslCertPath) { + try { + if (fs.existsSync(this.config.sslKeyPath) && fs.existsSync(this.config.sslCertPath)) { + const options = { + key: fs.readFileSync(this.config.sslKeyPath), + cert: fs.readFileSync(this.config.sslCertPath), + }; + this.httpServer = https.createServer(options, app); + this.logger.info("[System] Starting in HTTPS mode..."); + } else { + this.logger.warn( + "[System] SSL file paths provided but files not found. Falling back to HTTP." + ); + this.httpServer = http.createServer(app); + } + } catch (error) { + this.logger.error(`[System] Failed to load SSL files: ${error.message}. Falling back to HTTP.`); + this.httpServer = http.createServer(app); + } + } else { + this.httpServer = http.createServer(app); + } this.httpServer.on("upgrade", (req, socket) => { const pathname = new URL(req.url, `http://${req.headers.host}`).pathname; @@ -280,39 +302,91 @@ class ProxyServerSystem extends EventEmitter { _createExpressApp() { const app = express(); + // Request logging (Moved to top for better debugging) + app.use((req, res, next) => { + if ( + req.path !== "/api/status" && + req.path !== "/" && + req.path !== "/favicon.ico" && + req.path !== "/login" && + req.path !== "/health" + ) { + this.logger.info(`[Entrypoint] Received a request: ${req.method} ${req.path}`); + } + next(); + }); + // CORS middleware app.use((req, res, next) => { res.header("Access-Control-Allow-Origin", "*"); res.header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS"); + res.header("Access-Control-Allow-Private-Network", "true"); res.header( "Access-Control-Allow-Headers", "Content-Type, Authorization, x-requested-with, x-api-key, x-goog-api-key, x-goog-api-client, x-user-agent," + - " origin, accept, baggage, sentry-trace, openai-organization, openai-project, openai-beta, x-stainless-lang, " + - "x-stainless-package-version, x-stainless-os, x-stainless-arch, x-stainless-runtime, x-stainless-runtime-version, " + - "x-stainless-retry-count, x-stainless-timeout, sec-ch-ua, sec-ch-ua-mobile, sec-ch-ua-platform" + " origin, accept, baggage, sentry-trace, openai-organization, openai-project, openai-beta, x-stainless-lang, " + + "x-stainless-package-version, x-stainless-os, x-stainless-arch, x-stainless-runtime, x-stainless-runtime-version, " + + "x-stainless-retry-count, x-stainless-timeout, sec-ch-ua, sec-ch-ua-mobile, sec-ch-ua-platform, " + + "x-goog-upload-protocol, x-goog-upload-command, x-goog-upload-header-content-length, " + + "x-goog-upload-header-content-type, x-goog-upload-url, x-goog-upload-offset, x-goog-upload-status" ); + + // Expose all common Headers, including upload related ones (matched from BuildProxy) + res.header("Access-Control-Expose-Headers", "*"); + res.header( + "Access-Control-Expose-Headers", + "x-goog-upload-url, x-goog-upload-status, x-goog-upload-chunk-granularity, " + + "x-goog-upload-control-url, x-goog-upload-command, x-goog-upload-content-type, " + + "x-goog-upload-protocol, x-goog-upload-file-name, x-goog-upload-offset, " + + "date, content-type, content-length, location" + ); + if (req.method === "OPTIONS") { return res.sendStatus(204); } next(); }); - // Request logging + // Manual body collection middleware (BuildProxy style) + // Collects the entire raw body into req.rawBody as a Buffer + // Also attempts to parse JSON into req.body for compatibility app.use((req, res, next) => { - if ( - req.path !== "/api/status" && - req.path !== "/" && - req.path !== "/favicon.ico" && - req.path !== "/login" && - req.path !== "/health" - ) { - this.logger.info(`[Entrypoint] Received a request: ${req.method} ${req.path}`); + if (req.method === "GET" || req.method === "OPTIONS" || req.method === "HEAD") { + return next(); } - next(); - }); - app.use(express.json({ limit: "100mb" })); - app.use(express.urlencoded({ extended: true })); + const chunks = []; + req.on("data", chunk => chunks.push(chunk)); + req.on("end", () => { + req.rawBody = Buffer.concat(chunks); + + // Try to parse JSON for req.body compatibility + if (req.headers["content-type"]?.includes("application/json")) { + try { + req.body = JSON.parse(req.rawBody.toString()); + } catch (e) { + // Not valid JSON, keep req.body undefined or empty + req.body = {}; + } + } else if (req.headers["content-type"]?.includes("application/x-www-form-urlencoded")) { + try { + const qs = require("querystring"); + req.body = qs.parse(req.rawBody.toString()); + } catch (e) { + req.body = {}; + } + } else { + req.body = {}; + } + + next(); + }); + + req.on("error", (err) => { + this.logger.error(`[System] Request stream error: ${err.message}`); + next(err); + }); + }); // Serve static files from ui/dist (Vite build output) const path = require("path"); @@ -362,7 +436,7 @@ class ProxyServerSystem extends EventEmitter { app.get("/vnc", (req, res) => { res.status(400).send( "Error: WebSocket connection failed. " + - "If you are using a proxy (like Nginx), ensure it is configured to forward 'Upgrade' and 'Connection' headers." + "If you are using a proxy (like Nginx), ensure it is configured to forward 'Upgrade' and 'Connection' headers." ); }); diff --git a/src/core/RequestHandler.js b/src/core/RequestHandler.js index f9239ea..41f3b30 100644 --- a/src/core/RequestHandler.js +++ b/src/core/RequestHandler.js @@ -305,6 +305,7 @@ class RequestHandler { } }); + this.logger.info(`[Request] Incoming ${req.method} ${req.path} (ID: ${requestId})`); const proxyRequest = this._buildProxyRequest(req, requestId); proxyRequest.is_generative = isGenerativeRequest; const messageQueue = this.connectionRegistry.createMessageQueue(requestId); @@ -321,11 +322,11 @@ class RequestHandler { if (this.serverSystem.streamingMode === "fake") { await this._handlePseudoStreamResponse(proxyRequest, messageQueue, req, res); } else { - await this._handleRealStreamResponse(proxyRequest, messageQueue, res); + await this._handleRealStreamResponse(proxyRequest, messageQueue, req, res); } } else { proxyRequest.streaming_mode = "fake"; - await this._handleNonStreamResponse(proxyRequest, messageQueue, res); + await this._handleNonStreamResponse(proxyRequest, messageQueue, req, res); } } catch (error) { this._handleRequestError(error, res); @@ -689,7 +690,7 @@ class RequestHandler { } } - async _handleRealStreamResponse(proxyRequest, messageQueue, res) { + async _handleRealStreamResponse(proxyRequest, messageQueue, req, res) { this.logger.info(`[Request] Request dispatched to browser for processing...`); this._forwardRequest(proxyRequest); const headerMessage = await messageQueue.dequeue(); @@ -715,7 +716,7 @@ class RequestHandler { this.authSwitcher.failureCount = 0; } - this._setResponseHeaders(res, headerMessage); + this._setResponseHeaders(res, headerMessage, req); this.logger.info("[Request] Starting streaming transmission..."); try { let lastChunk = ""; @@ -728,8 +729,11 @@ class RequestHandler { break; } if (dataMessage.data) { - res.write(dataMessage.data); - lastChunk = dataMessage.data; + const writeData = dataMessage.is_binary + ? Buffer.from(dataMessage.data, "base64") + : dataMessage.data; + res.write(writeData); + if (!dataMessage.is_binary) lastChunk = dataMessage.data; } } try { @@ -757,7 +761,7 @@ class RequestHandler { } } - async _handleNonStreamResponse(proxyRequest, messageQueue, res) { + async _handleNonStreamResponse(proxyRequest, messageQueue, req, res) { this.logger.info(`[Request] Entering non-stream processing mode...`); try { @@ -783,7 +787,7 @@ class RequestHandler { } const headerMessage = result.message; - let fullBody = ""; + const chunks = []; let receiving = true; while (receiving) { const message = await messageQueue.dequeue(300000); @@ -793,12 +797,17 @@ class RequestHandler { break; } if (message.event_type === "chunk" && message.data) { - fullBody += message.data; + const chunkBuffer = message.is_binary + ? Buffer.from(message.data, "base64") + : Buffer.from(message.data); + chunks.push(chunkBuffer); } } + const fullBodyBuffer = Buffer.concat(chunks); + try { - const fullResponse = JSON.parse(fullBody); + const fullResponse = JSON.parse(fullBodyBuffer.toString()); const finishReason = fullResponse.candidates?.[0]?.finishReason || "UNKNOWN"; this.logger.info( `✅ [Request] Response ended, reason: ${finishReason}, request ID: ${proxyRequest.request_id}` @@ -807,9 +816,8 @@ class RequestHandler { // Ignore JSON parsing errors for finish reason } - res.status(headerMessage.status || 200) - .type("application/json") - .send(fullBody || "{}"); + this._setResponseHeaders(res, headerMessage, req); + res.send(fullBodyBuffer); this.logger.info(`[Request] Complete non-stream response sent to client.`); } catch (error) { @@ -963,11 +971,57 @@ class RequestHandler { } } - _setResponseHeaders(res, headerMessage) { + _setResponseHeaders(res, headerMessage, req) { res.status(headerMessage.status || 200); const headers = headerMessage.headers || {}; + + // Filter headers that might cause CORS conflicts + const forbiddenHeaders = [ + "access-control-allow-origin", + "access-control-allow-methods", + "access-control-allow-headers", + ]; + Object.entries(headers).forEach(([name, value]) => { - if (name.toLowerCase() !== "content-length") res.set(name, value); + const lowerName = name.toLowerCase(); + if (forbiddenHeaders.includes(lowerName)) return; + if (lowerName === "content-length") return; + + // Special handling for upload URL and redirects: point them back to this proxy + if ((lowerName === "x-goog-upload-url" || lowerName === "location") && value.includes("googleapis.com")) { + try { + const urlObj = new URL(value); + // Construct local proxy URL using configured host/port + // Note: The client (build.js) might have already embedded the original host in __proxy_host__ + // But wait, headerMessage comes from the BROWSER. + // If the Browser sends back the header as received from Google, then it's the GOOGLE URL. + // If the Browser rewrote it, it's the LOCALHOST URL. + // build.js `_transmitHeaders` rewrites it! + + // So `value` is `http://localhost:xxxx/...&__proxy_host__=google.com` (from Browser) + // We just need to ensure it points to *our* current listener address. + + // Use the Host header from the request to support remote clients (e.g. Docker IPs) + // If req.headers.host exists (standard), use it. Otherwise fallback to config. + let newAuthority; + if (req && req.headers && req.headers.host) { + newAuthority = req.headers.host; + } else { + const host = this.serverSystem.config.host === "0.0.0.0" ? "127.0.0.1" : this.serverSystem.config.host; + newAuthority = `${host}:${this.serverSystem.config.httpPort}`; + } + + const protocol = req.secure || (req.get && req.get("X-Forwarded-Proto") === "https") ? "https" : "http"; + const newUrl = `${protocol}://${newAuthority}${urlObj.pathname}${urlObj.search}`; + + this.logger.info(`[Response] Rewriting header ${name}: ${value} -> ${newUrl}`); + res.set(name, newUrl); + } catch (e) { + res.set(name, value); + } + } else { + res.set(name, value); + } }); } @@ -1159,6 +1213,8 @@ class RequestHandler { query_params: req.query || {}, request_id: requestId, streaming_mode: this.serverSystem.streamingMode, + body_b64: req.rawBody ? req.rawBody.toString("base64") : undefined, + is_generative: req.method === "POST" && (req.path.includes("generateContent") || req.path.includes("streamGenerateContent")), }; }