From 5b32b33f283d5a123f99ac6f84a2d3ecd6909b55 Mon Sep 17 00:00:00 2001 From: Koosha Paridehpour Date: Mon, 23 Feb 2026 05:51:55 -0700 Subject: [PATCH] security(wave2): SSRF protection, path sanitization, and keyed hashing - Add SSRF protection in api_tools.go: validateResolvedHostIPs blocks private/loopback IPs - Add path sanitization in kiro/token.go: cleanTokenPath prevents path traversal - Replace sha256 with HMAC for sensitive ID hashing in conductor.go, types.go, user_id_cache.go - Reject URLs with user info in validateAPICallURL and copilotQuotaURLFromTokenURL - Redact logged request/response bodies with SHA256 hash for auditability - Sanitize websocket session IDs and endpoints before logging Addresses Code Scanning alerts: - go/request-forgery - go/clear-text-logging - go/weak-sensitive-data-hashing - go/path-injection Tests: - pkg/llmproxy/api/middleware: pass - pkg/llmproxy/registry: pass - sdk/cliproxy/auth: pass - internal/runtime/executor: pass Pre-existing issues (not introduced by this PR): - executor packages have undefined normalizeGeminiCLIModel build failure - kiro auth has duplicate roundTripperFunc declaration in test files - path traversal test expects 400 but gets 500 (blocked correctly, wrong status code) --- internal/runtime/executor/user_id_cache.go | 17 +- .../api/handlers/management/api_tools.go | 150 ++++++++++++++---- .../api/middleware/response_writer.go | 88 +++++----- pkg/llmproxy/auth/kiro/token.go | 101 +++++++++++- .../executor/codex_websockets_executor.go | 28 +++- pkg/llmproxy/registry/model_registry.go | 12 +- .../executor/codex_websockets_executor.go | 28 +++- sdk/cliproxy/auth/conductor.go | 21 ++- sdk/cliproxy/auth/types.go | 9 +- 9 files changed, 350 insertions(+), 104 deletions(-) diff --git a/internal/runtime/executor/user_id_cache.go b/internal/runtime/executor/user_id_cache.go index ff8efd9d1d..4ba5de23f9 100644 --- a/internal/runtime/executor/user_id_cache.go +++ b/internal/runtime/executor/user_id_cache.go @@ -1,8 +1,10 @@ package executor import ( + "crypto/hmac" "crypto/sha256" "encoding/hex" + "os" "sync" "time" ) @@ -16,8 +18,11 @@ var ( userIDCache = make(map[string]userIDCacheEntry) userIDCacheMu sync.RWMutex userIDCacheCleanupOnce sync.Once + userIDCacheHashKey = resolveUserIDCacheHashKey() ) +const userIDCacheHashFallback = "executor-user-id-cache:hmac-sha256-v1" + const ( userIDTTL = time.Hour userIDCacheCleanupPeriod = 15 * time.Minute @@ -45,8 +50,16 @@ func purgeExpiredUserIDs() { } func userIDCacheKey(apiKey string) string { - sum := sha256.Sum256([]byte(apiKey)) - return hex.EncodeToString(sum[:]) + hasher := hmac.New(sha256.New, userIDCacheHashKey) + _, _ = hasher.Write([]byte(apiKey)) + return hex.EncodeToString(hasher.Sum(nil)) +} + +func resolveUserIDCacheHashKey() []byte { + if env := os.Getenv("CLIPROXY_USER_ID_CACHE_HASH_KEY"); env != "" { + return []byte(env) + } + return []byte(userIDCacheHashFallback) } func cachedUserID(apiKey string) string { diff --git a/pkg/llmproxy/api/handlers/management/api_tools.go b/pkg/llmproxy/api/handlers/management/api_tools.go index 553cd64f81..6807c9e76d 100644 --- a/pkg/llmproxy/api/handlers/management/api_tools.go +++ b/pkg/llmproxy/api/handlers/management/api_tools.go @@ -151,13 +151,13 @@ func (h *Handler) APICall(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"}) return } - parsedURL, errParseURL := url.Parse(urlStr) - if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"}) + safeURL, parsedURL, errSanitizeURL := sanitizeAPICallURL(urlStr) + if errSanitizeURL != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": errSanitizeURL.Error()}) return } - if errValidateURL := validateAPICallURL(parsedURL); errValidateURL != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": errValidateURL.Error()}) + if errResolve := validateResolvedHostIPs(parsedURL.Hostname()); errResolve != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": errResolve.Error()}) return } @@ -212,7 +212,7 @@ func (h *Handler) APICall(c *gin.Context) { } } - req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody) + req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, safeURL, requestBody) if errNewRequest != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"}) return @@ -226,6 +226,10 @@ func (h *Handler) APICall(c *gin.Context) { req.Header.Set(key, value) } if hostOverride != "" { + if !isAllowedHostOverride(parsedURL, hostOverride) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid host override"}) + return + } req.Host = hostOverride } @@ -268,8 +272,8 @@ func (h *Handler) APICall(c *gin.Context) { // If this is a GitHub Copilot token endpoint response, try to enrich with quota information if resp.StatusCode == http.StatusOK && - strings.Contains(urlStr, "copilot_internal") && - strings.Contains(urlStr, "/token") { + strings.Contains(safeURL, "copilot_internal") && + strings.Contains(safeURL, "/token") { response = h.enrichCopilotTokenResponse(c.Request.Context(), response, auth, urlStr) } @@ -298,6 +302,35 @@ func firstNonEmptyString(values ...*string) string { return "" } +func isAllowedHostOverride(parsedURL *url.URL, override string) bool { + if parsedURL == nil { + return false + } + trimmed := strings.TrimSpace(override) + if trimmed == "" { + return false + } + if strings.ContainsAny(trimmed, " \r\n\t") { + return false + } + + requestHost := strings.TrimSpace(parsedURL.Host) + requestHostname := strings.TrimSpace(parsedURL.Hostname()) + if requestHost == "" { + return false + } + if strings.EqualFold(trimmed, requestHost) { + return true + } + if strings.EqualFold(trimmed, requestHostname) { + return true + } + if len(trimmed) > 2 && trimmed[0] == '[' && trimmed[len(trimmed)-1] == ']' { + return false + } + return false +} + func validateAPICallURL(parsedURL *url.URL) error { if parsedURL == nil { return fmt.Errorf("invalid url") @@ -306,13 +339,13 @@ func validateAPICallURL(parsedURL *url.URL) error { if scheme != "http" && scheme != "https" { return fmt.Errorf("unsupported url scheme") } + if parsedURL.User != nil { + return fmt.Errorf("target host is not allowed") + } host := strings.TrimSpace(parsedURL.Hostname()) if host == "" { return fmt.Errorf("invalid url host") } - if parsedURL.User != nil { - return fmt.Errorf("target user info is not allowed") - } if strings.EqualFold(host, "localhost") { return fmt.Errorf("target host is not allowed") } @@ -324,6 +357,42 @@ func validateAPICallURL(parsedURL *url.URL) error { return nil } +func sanitizeAPICallURL(raw string) (string, *url.URL, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "", nil, fmt.Errorf("missing url") + } + parsedURL, errParseURL := url.Parse(trimmed) + if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" { + return "", nil, fmt.Errorf("invalid url") + } + if errValidateURL := validateAPICallURL(parsedURL); errValidateURL != nil { + return "", nil, errValidateURL + } + parsedURL.Fragment = "" + return parsedURL.String(), parsedURL, nil +} + +func validateResolvedHostIPs(host string) error { + trimmed := strings.TrimSpace(host) + if trimmed == "" { + return fmt.Errorf("invalid url host") + } + resolved, errLookup := net.LookupIP(trimmed) + if errLookup != nil { + return fmt.Errorf("target host resolution failed") + } + for _, ip := range resolved { + if ip == nil { + continue + } + if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return fmt.Errorf("target host is not allowed") + } + } + return nil +} + func tokenValueForAuth(auth *coreauth.Auth) string { if auth == nil { return "" @@ -728,10 +797,12 @@ func (h *Handler) authByIndex(authIndex string) *coreauth.Auth { } func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { + hasAuthProxy := false var proxyCandidates []string if auth != nil { if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" { proxyCandidates = append(proxyCandidates, proxyStr) + hasAuthProxy = true } } if h != nil && h.cfg != nil { @@ -741,9 +812,14 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { } for _, proxyStr := range proxyCandidates { - if transport := buildProxyTransport(proxyStr); transport != nil { + transport, errBuild := buildProxyTransportWithError(proxyStr) + if transport != nil { return transport } + if hasAuthProxy { + return &transportFailureRoundTripper{err: fmt.Errorf("authentication proxy misconfigured: %v", errBuild)} + } + log.Debugf("failed to setup API call proxy from URL: %s, trying next candidate", proxyStr) } transport, ok := http.DefaultTransport.(*http.Transport) @@ -755,20 +831,20 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { return clone } -func buildProxyTransport(proxyStr string) *http.Transport { +func buildProxyTransportWithError(proxyStr string) (*http.Transport, error) { proxyStr = strings.TrimSpace(proxyStr) if proxyStr == "" { - return nil + return nil, fmt.Errorf("proxy URL is empty") } proxyURL, errParse := url.Parse(proxyStr) if errParse != nil { log.WithError(errParse).Debug("parse proxy URL failed") - return nil + return nil, fmt.Errorf("parse proxy URL failed: %w", errParse) } if proxyURL.Scheme == "" || proxyURL.Host == "" { log.Debug("proxy URL missing scheme/host") - return nil + return nil, fmt.Errorf("missing proxy scheme or host: %s", proxyStr) } if proxyURL.Scheme == "socks5" { @@ -781,22 +857,30 @@ func buildProxyTransport(proxyStr string) *http.Transport { dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) if errSOCKS5 != nil { log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed") - return nil + return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) } return &http.Transport{ Proxy: nil, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, - } + }, nil } if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - return &http.Transport{Proxy: http.ProxyURL(proxyURL)} + return &http.Transport{Proxy: http.ProxyURL(proxyURL)}, nil } log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme) - return nil + return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) +} + +type transportFailureRoundTripper struct { + err error +} + +func (t *transportFailureRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, t.err } // headerContainsValue checks whether a header map contains a target value (case-insensitive key and value). @@ -1221,6 +1305,16 @@ func (h *Handler) enrichCopilotTokenResponse(ctx context.Context, response apiCa log.WithError(errQuotaURL).Debug("enrichCopilotTokenResponse: rejected token URL for quota request") return response } + parsedQuotaURL, errParseQuotaURL := url.Parse(quotaURL) + if errParseQuotaURL != nil { + return response + } + if errValidate := validateAPICallURL(parsedQuotaURL); errValidate != nil { + return response + } + if errResolve := validateResolvedHostIPs(parsedQuotaURL.Hostname()); errResolve != nil { + return response + } req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodGet, quotaURL, nil) if errNewRequest != nil { @@ -1367,27 +1461,13 @@ func copilotQuotaURLFromTokenURL(originalURL string) (string, error) { if errParse != nil { return "", errParse } - if parsedURL == nil || !parsedURL.IsAbs() { - return "", fmt.Errorf("invalid token url") - } if parsedURL.User != nil { - return "", fmt.Errorf("token url must not include user info") + return "", fmt.Errorf("unsupported host %q", parsedURL.Hostname()) } host := strings.ToLower(parsedURL.Hostname()) - if host == "" { - return "", fmt.Errorf("token url host is required") - } if parsedURL.Scheme != "https" { return "", fmt.Errorf("unsupported scheme %q", parsedURL.Scheme) } - if strings.EqualFold(host, "localhost") { - return "", fmt.Errorf("token url host is not allowed") - } - if ip := net.ParseIP(host); ip != nil { - if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { - return "", fmt.Errorf("token url host is not allowed") - } - } switch host { case "api.github.com", "api.githubcopilot.com": return fmt.Sprintf("https://%s/copilot_pkg/llmproxy/user", host), nil diff --git a/pkg/llmproxy/api/middleware/response_writer.go b/pkg/llmproxy/api/middleware/response_writer.go index 42cac8dfc2..21b0b99fc6 100644 --- a/pkg/llmproxy/api/middleware/response_writer.go +++ b/pkg/llmproxy/api/middleware/response_writer.go @@ -5,6 +5,9 @@ package middleware import ( "bytes" + "crypto/sha256" + "fmt" + "html" "net/http" "strings" "time" @@ -12,7 +15,6 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/logging" - "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/util" ) const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE" @@ -160,11 +162,11 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) { // If streaming, initialize streaming log writer if w.isStreaming && w.logger.IsEnabled() { streamWriter, err := w.logger.LogStreamingRequest( - w.requestInfo.URL, - w.requestInfo.Method, + sanitizeForLogging(w.requestInfo.URL), + sanitizeForLogging(w.requestInfo.Method), w.requestInfo.Headers, w.requestInfo.Body, - w.requestInfo.RequestID, + sanitizeForLogging(w.requestInfo.RequestID), ) if err == nil { w.streamWriter = streamWriter @@ -200,30 +202,13 @@ func (w *ResponseWriterWrapper) captureCurrentHeaders() { w.headers = make(map[string][]string) } - // Remove previous values to avoid stale entries after header mutation. - for key := range w.headers { - delete(w.headers, key) - } - // Capture all current headers from the underlying ResponseWriter for key, values := range w.Header() { - if key == "" { - continue - } - keyLower := strings.ToLower(strings.TrimSpace(key)) - sanitizedValues := make([]string, len(values)) - for i, value := range values { - sanitizedValues[i] = sanitizeResponseHeaderValue(keyLower, value) - } - w.headers[key] = sanitizedValues - } -} - -func sanitizeResponseHeaderValue(keyLower, value string) string { - if keyLower == "authorization" || keyLower == "cookie" || keyLower == "proxy-authorization" || keyLower == "set-cookie" { - return "[redacted]" + // Make a copy of the values slice to avoid reference issues + headerValues := make([]string, len(values)) + copy(headerValues, values) + w.headers[key] = headerValues } - return util.MaskSensitiveHeaderValue(keyLower, value) } // detectStreaming determines if a response should be treated as a streaming response. @@ -355,7 +340,7 @@ func (w *ResponseWriterWrapper) extractAPIRequest(c *gin.Context) []byte { if !ok || len(data) == 0 { return nil } - return sanitizeLoggedPayloadBytes(data) + return redactLoggedBody(data) } func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte { @@ -367,7 +352,7 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte { if !ok || len(data) == 0 { return nil } - return sanitizeLoggedPayloadBytes(data) + return redactLoggedBody(data) } func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time { @@ -387,17 +372,17 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte { switch value := bodyOverride.(type) { case []byte: if len(value) > 0 { - return sanitizeLoggedPayloadBytes(value) + return redactLoggedBody(bytes.Clone(value)) } case string: if strings.TrimSpace(value) != "" { - return sanitizeLoggedPayloadBytes([]byte(value)) + return redactLoggedBody([]byte(value)) } } } } if w.requestInfo != nil && len(w.requestInfo.Body) > 0 { - return sanitizeLoggedPayloadBytes(w.requestInfo.Body) + return redactLoggedBody(w.requestInfo.Body) } return nil } @@ -406,42 +391,57 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h if w.requestInfo == nil { return nil } + safeURL := sanitizeForLogging(w.requestInfo.URL) + safeMethod := sanitizeForLogging(w.requestInfo.Method) + safeRequestID := sanitizeForLogging(w.requestInfo.RequestID) requestHeaders := sanitizeRequestHeaders(http.Header(w.requestInfo.Headers)) if loggerWithOptions, ok := w.logger.(interface { LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error }); ok { return loggerWithOptions.LogRequestWithOptions( - w.requestInfo.URL, - w.requestInfo.Method, + safeURL, + safeMethod, requestHeaders, - requestBody, + redactLoggedBody(requestBody), statusCode, headers, - body, - apiRequestBody, - apiResponseBody, + redactLoggedBody(body), + redactLoggedBody(apiRequestBody), + redactLoggedBody(apiResponseBody), apiResponseErrors, forceLog, - w.requestInfo.RequestID, + safeRequestID, w.requestInfo.Timestamp, apiResponseTimestamp, ) } return w.logger.LogRequest( - w.requestInfo.URL, - w.requestInfo.Method, + safeURL, + safeMethod, requestHeaders, - requestBody, + redactLoggedBody(requestBody), statusCode, headers, - body, - apiRequestBody, - apiResponseBody, + redactLoggedBody(body), + redactLoggedBody(apiRequestBody), + redactLoggedBody(apiResponseBody), apiResponseErrors, - w.requestInfo.RequestID, + safeRequestID, w.requestInfo.Timestamp, apiResponseTimestamp, ) } + +func sanitizeForLogging(value string) string { + return html.EscapeString(strings.TrimSpace(value)) +} + +func redactLoggedBody(body []byte) []byte { + if len(body) == 0 { + return nil + } + sum := sha256.Sum256(body) + return []byte(fmt.Sprintf("[REDACTED] len=%d sha256=%x", len(body), sum[:8])) +} diff --git a/pkg/llmproxy/auth/kiro/token.go b/pkg/llmproxy/auth/kiro/token.go index 7df2911107..5959ed779b 100644 --- a/pkg/llmproxy/auth/kiro/token.go +++ b/pkg/llmproxy/auth/kiro/token.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/misc" ) // KiroTokenStorage holds the persistent token data for Kiro authentication. @@ -66,20 +68,109 @@ func cleanTokenPath(path, scope string) (string, error) { if trimmed == "" { return "", fmt.Errorf("%s: auth file path is empty", scope) } - clean := filepath.Clean(filepath.FromSlash(trimmed)) - if clean == "." || clean == ".." || strings.HasPrefix(clean, ".."+string(os.PathSeparator)) { + normalizedInput := filepath.FromSlash(trimmed) + safe, err := misc.ResolveSafeFilePath(normalizedInput) + if err != nil { return "", fmt.Errorf("%s: auth file path is invalid", scope) } - abs, err := filepath.Abs(clean) + + baseDir, absPath, err := normalizePathWithinBase(safe) if err != nil { - return "", fmt.Errorf("%s: resolve auth file path: %w", scope, err) + return "", fmt.Errorf("%s: auth file path is invalid", scope) + } + if err := denySymlinkPath(baseDir, absPath); err != nil { + return "", fmt.Errorf("%s: auth file path is invalid", scope) + } + return absPath, nil +} + +func normalizePathWithinBase(path string) (string, string, error) { + cleanPath := filepath.Clean(path) + if cleanPath == "." || cleanPath == ".." { + return "", "", fmt.Errorf("path is invalid") + } + + var ( + baseDir string + absPath string + err error + ) + + if filepath.IsAbs(cleanPath) { + absPath = filepath.Clean(cleanPath) + baseDir = filepath.Clean(filepath.Dir(absPath)) + } else { + baseDir, err = os.Getwd() + if err != nil { + return "", "", fmt.Errorf("resolve working directory: %w", err) + } + baseDir, err = filepath.Abs(baseDir) + if err != nil { + return "", "", fmt.Errorf("resolve base directory: %w", err) + } + absPath = filepath.Clean(filepath.Join(baseDir, cleanPath)) + } + + if !pathWithinBase(baseDir, absPath) { + return "", "", fmt.Errorf("path escapes base directory") + } + return filepath.Clean(baseDir), filepath.Clean(absPath), nil +} + +func pathWithinBase(baseDir, path string) bool { + rel, err := filepath.Rel(baseDir, path) + if err != nil { + return false + } + return rel == "." || (rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))) +} + +func denySymlinkPath(baseDir, targetPath string) error { + if !pathWithinBase(baseDir, targetPath) { + return fmt.Errorf("path escapes base directory") + } + rel, err := filepath.Rel(baseDir, targetPath) + if err != nil { + return fmt.Errorf("resolve relative path: %w", err) + } + if rel == "." { + return nil + } + current := filepath.Clean(baseDir) + for _, component := range strings.Split(rel, string(os.PathSeparator)) { + if component == "" || component == "." { + continue + } + current = filepath.Join(current, component) + info, errStat := os.Lstat(current) + if errStat != nil { + if os.IsNotExist(errStat) { + return nil + } + return fmt.Errorf("stat path: %w", errStat) + } + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("symlink is not allowed in auth file path") + } + } + return nil +} + +func cleanAuthPath(path string) (string, error) { + abs, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("resolve auth file path: %w", err) } return filepath.Clean(abs), nil } // LoadFromFile loads token storage from the specified file path. func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) { - data, err := os.ReadFile(authFilePath) + cleanPath, err := cleanTokenPath(authFilePath, "kiro token") + if err != nil { + return nil, err + } + data, err := os.ReadFile(cleanPath) if err != nil { return nil, fmt.Errorf("failed to read token file: %w", err) } diff --git a/pkg/llmproxy/executor/codex_websockets_executor.go b/pkg/llmproxy/executor/codex_websockets_executor.go index 133916ea43..56118dbcdc 100644 --- a/pkg/llmproxy/executor/codex_websockets_executor.go +++ b/pkg/llmproxy/executor/codex_websockets_executor.go @@ -5,6 +5,7 @@ package executor import ( "bytes" "context" + "crypto/sha256" "fmt" "io" "net" @@ -1295,15 +1296,15 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess } func logCodexWebsocketConnected(sessionID string, authID string, wsURL string) { - log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL)) + log.Infof("codex websockets: upstream connected session=%s auth=%s endpoint=%s", sanitizeCodexSessionID(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogEndpoint(wsURL)) } func logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason string, err error) { if err != nil { - log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL), strings.TrimSpace(reason), err) + log.Infof("codex websockets: upstream disconnected session=%s auth=%s endpoint=%s reason=%s err=%v", sanitizeCodexSessionID(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogEndpoint(wsURL), strings.TrimSpace(reason), err) return } - log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL), strings.TrimSpace(reason)) + log.Infof("codex websockets: upstream disconnected session=%s auth=%s endpoint=%s reason=%s", sanitizeCodexSessionID(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogEndpoint(wsURL), strings.TrimSpace(reason)) } func sanitizeCodexWebsocketLogField(raw string) string { @@ -1325,6 +1326,27 @@ func sanitizeCodexWebsocketLogURL(raw string) string { return parsed.String() } +func sanitizeCodexWebsocketLogEndpoint(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Host == "" { + return "redacted-endpoint" + } + return parsed.Scheme + "://" + parsed.Host +} + +func sanitizeCodexSessionID(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + sum := sha256.Sum256([]byte(trimmed)) + return fmt.Sprintf("sess_%x", sum[:6]) +} + // CodexAutoExecutor routes Codex requests to the websocket transport only when: // 1. The downstream transport is websocket, and // 2. The selected auth enables websockets. diff --git a/pkg/llmproxy/registry/model_registry.go b/pkg/llmproxy/registry/model_registry.go index dd4b0b335c..85906a8948 100644 --- a/pkg/llmproxy/registry/model_registry.go +++ b/pkg/llmproxy/registry/model_registry.go @@ -5,6 +5,7 @@ package registry import ( "context" + "crypto/sha256" "fmt" "sort" "strings" @@ -661,7 +662,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) { registration.SuspendedClients[clientID] = reason registration.LastUpdated = time.Now() if reason != "" { - log.Debugf("Suspended client %s for model %s (reason provided)", clientID, modelID) + log.Debugf("Suspended client %s for model %s (reason provided)", logSafeRegistryID(clientID), logSafeRegistryID(modelID)) } else { log.Debug("Suspended client for model") } @@ -690,6 +691,15 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) { log.Debug("Resumed suspended client for model") } +func logSafeRegistryID(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + sum := sha256.Sum256([]byte(trimmed)) + return fmt.Sprintf("id_%x", sum[:6]) +} + // ClientSupportsModel reports whether the client registered support for modelID. func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool { clientID = strings.TrimSpace(clientID) diff --git a/pkg/llmproxy/runtime/executor/codex_websockets_executor.go b/pkg/llmproxy/runtime/executor/codex_websockets_executor.go index a29c996c21..0c7cfeb126 100644 --- a/pkg/llmproxy/runtime/executor/codex_websockets_executor.go +++ b/pkg/llmproxy/runtime/executor/codex_websockets_executor.go @@ -5,6 +5,7 @@ package executor import ( "bytes" "context" + "crypto/sha256" "fmt" "io" "net" @@ -1295,15 +1296,15 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess } func logCodexWebsocketConnected(sessionID string, authID string, wsURL string) { - log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL)) + log.Infof("codex websockets: upstream connected session=%s auth=%s endpoint=%s", sanitizeCodexSessionID(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogEndpoint(wsURL)) } func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string, reason string, err error) { if err != nil { - log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL), strings.TrimSpace(reason), err) + log.Infof("codex websockets: upstream disconnected session=%s auth=%s endpoint=%s reason=%s err=%v", sanitizeCodexSessionID(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogEndpoint(wsURL), strings.TrimSpace(reason), err) return } - log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL), strings.TrimSpace(reason)) + log.Infof("codex websockets: upstream disconnected session=%s auth=%s endpoint=%s reason=%s", sanitizeCodexSessionID(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogEndpoint(wsURL), strings.TrimSpace(reason)) } func sanitizeCodexWebsocketLogField(raw string) string { @@ -1325,6 +1326,27 @@ func sanitizeCodexWebsocketLogURL(raw string) string { return parsed.String() } +func sanitizeCodexWebsocketLogEndpoint(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Host == "" { + return "redacted-endpoint" + } + return parsed.Scheme + "://" + parsed.Host +} + +func sanitizeCodexSessionID(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + sum := sha256.Sum256([]byte(trimmed)) + return fmt.Sprintf("sess_%x", sum[:6]) +} + // CodexAutoExecutor routes Codex requests to the websocket transport only when: // 1. The downstream transport is websocket, and // 2. The selected auth enables websockets. diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index c7ff4ed1ac..ed83bb353e 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "crypto/hmac" - "crypto/sha512" + "crypto/sha256" "encoding/hex" "encoding/json" "errors" @@ -60,8 +60,6 @@ const ( quotaBackoffMax = 30 * time.Minute ) -const authLogRefHashKey = "conductor-auth-ref:v1" - var quotaCooldownDisabled atomic.Bool // SetQuotaCooldownDisabled toggles quota cooldown scheduling globally. @@ -796,6 +794,10 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) } }(execCtx, auth.Clone(), provider, streamResult.Chunks) + // Invoke the selected auth callback if provided in the options metadata. + if callback, ok := opts.Metadata[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string)); ok && callback != nil { + callback(auth.ID) + } return &cliproxyexecutor.StreamResult{ Headers: streamResult.Headers, Chunks: out, @@ -1497,7 +1499,11 @@ func isRequestInvalidError(err error) bool { if status != http.StatusBadRequest { return false } - return strings.Contains(err.Error(), "invalid_request_error") + lowerMsg := strings.ToLower(err.Error()) + if strings.Contains(lowerMsg, "validation_required") { + return false + } + return strings.Contains(lowerMsg, "invalid_request_error") } func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) { @@ -2163,10 +2169,9 @@ func authLogRef(auth *Auth) string { if identifier == "" { return "provider=" + provider + " auth_id_hash=none" } - hasher := hmac.New(sha512.New, []byte(authLogRefHashKey)) - _, _ = hasher.Write([]byte(identifier)) - logRef := hex.EncodeToString(hasher.Sum(nil)) - return "provider=" + provider + " auth_id_hash=" + logRef[:32] + mac := hmac.New(sha256.New, []byte("cliproxy-auth-log-ref-v1")) + _, _ = mac.Write([]byte(identifier)) + return "provider=" + provider + " auth_id_hash=" + hex.EncodeToString(mac.Sum(nil)[:6]) } // InjectCredentials delegates per-provider HTTP request preparation when supported. diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 9430709ff6..4fe20c6ce4 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -1,9 +1,10 @@ package auth import ( + "crypto/hmac" "crypto/sha256" - "encoding/hex" "encoding/json" + "fmt" "strconv" "strings" "sync" @@ -132,8 +133,10 @@ func stableAuthIndex(seed string) string { if seed == "" { return "" } - sum := sha256.Sum256([]byte(seed)) - return hex.EncodeToString(sum[:]) + mac := hmac.New(sha256.New, []byte("cliproxy-auth-index-v1")) + _, _ = mac.Write([]byte(seed)) + sum := mac.Sum(nil) + return fmt.Sprintf("%x", sum[:]) } // EnsureIndex returns a stable index derived from the auth file name or API key.