diff --git a/pkg/llmproxy/api/handlers/management/api_tools.go b/pkg/llmproxy/api/handlers/management/api_tools.go index b3419bd013..24268ba303 100644 --- a/pkg/llmproxy/api/handlers/management/api_tools.go +++ b/pkg/llmproxy/api/handlers/management/api_tools.go @@ -156,6 +156,10 @@ func (h *Handler) APICall(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"}) return } + if errValidateURL := validateAPICallURL(parsedURL); errValidateURL != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": errValidateURL.Error()}) + return + } authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal) auth := h.authByIndex(authIndex) @@ -294,6 +298,29 @@ func firstNonEmptyString(values ...*string) string { return "" } +func validateAPICallURL(parsedURL *url.URL) error { + if parsedURL == nil { + return fmt.Errorf("invalid url") + } + scheme := strings.ToLower(strings.TrimSpace(parsedURL.Scheme)) + if scheme != "http" && scheme != "https" { + return fmt.Errorf("unsupported url scheme") + } + host := strings.TrimSpace(parsedURL.Hostname()) + if host == "" { + return fmt.Errorf("invalid url host") + } + if strings.EqualFold(host, "localhost") { + return fmt.Errorf("target host is not allowed") + } + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return fmt.Errorf("target host is not allowed") + } + } + return nil +} + func tokenValueForAuth(auth *coreauth.Auth) string { if auth == nil { return "" @@ -1179,12 +1206,11 @@ func (h *Handler) enrichCopilotTokenResponse(ctx context.Context, response apiCa // Fetch quota information from /copilot_pkg/llmproxy/user // Derive the base URL from the original token request to support proxies and test servers - parsedURL, errParse := url.Parse(originalURL) - if errParse != nil { - log.WithError(errParse).Debug("enrichCopilotTokenResponse: failed to parse URL") + quotaURL, errQuotaURL := copilotQuotaURLFromTokenURL(originalURL) + if errQuotaURL != nil { + log.WithError(errQuotaURL).Debug("enrichCopilotTokenResponse: rejected token URL for quota request") return response } - quotaURL := fmt.Sprintf("%s://%s/copilot_pkg/llmproxy/user", parsedURL.Scheme, parsedURL.Host) req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodGet, quotaURL, nil) if errNewRequest != nil { @@ -1325,3 +1351,20 @@ func (h *Handler) enrichCopilotTokenResponse(ctx context.Context, response apiCa return response } + +func copilotQuotaURLFromTokenURL(originalURL string) (string, error) { + parsedURL, errParse := url.Parse(strings.TrimSpace(originalURL)) + if errParse != nil { + return "", errParse + } + host := strings.ToLower(parsedURL.Hostname()) + if parsedURL.Scheme != "https" { + return "", fmt.Errorf("unsupported scheme %q", parsedURL.Scheme) + } + switch host { + case "api.github.com", "api.githubcopilot.com": + return fmt.Sprintf("https://%s/copilot_pkg/llmproxy/user", host), nil + default: + return "", fmt.Errorf("unsupported host %q", parsedURL.Hostname()) + } +} diff --git a/pkg/llmproxy/api/handlers/management/api_tools_test.go b/pkg/llmproxy/api/handlers/management/api_tools_test.go index 0096ad0017..e2bf46657d 100644 --- a/pkg/llmproxy/api/handlers/management/api_tools_test.go +++ b/pkg/llmproxy/api/handlers/management/api_tools_test.go @@ -1,6 +1,7 @@ package management import ( + "bytes" "context" "encoding/json" "io" @@ -17,6 +18,26 @@ import ( coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) +func TestAPICall_RejectsUnsafeHost(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + body := []byte(`{"method":"GET","url":"http://127.0.0.1:8080/ping"}`) + req := httptest.NewRequest(http.MethodPost, "/v0/management/api-call", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + h := &Handler{} + h.APICall(c) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d, body=%s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } +} + type memoryAuthStore struct { mu sync.Mutex items map[string]*coreauth.Auth @@ -303,3 +324,56 @@ func TestGetKiroQuotaWithChecker_MissingProfileARN(t *testing.T) { t.Fatalf("unexpected response body: %s", rec.Body.String()) } } + +func TestCopilotQuotaURLFromTokenURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tokenURL string + wantURL string + expectErr bool + }{ + { + name: "github_api", + tokenURL: "https://api.github.com/copilot_internal/v2/token", + wantURL: "https://api.github.com/copilot_pkg/llmproxy/user", + expectErr: false, + }, + { + name: "copilot_api", + tokenURL: "https://api.githubcopilot.com/copilot_internal/v2/token", + wantURL: "https://api.githubcopilot.com/copilot_pkg/llmproxy/user", + expectErr: false, + }, + { + name: "reject_http", + tokenURL: "http://api.github.com/copilot_internal/v2/token", + expectErr: true, + }, + { + name: "reject_untrusted_host", + tokenURL: "https://127.0.0.1/copilot_internal/v2/token", + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := copilotQuotaURLFromTokenURL(tt.tokenURL) + if tt.expectErr { + if err == nil { + t.Fatalf("expected error, got url=%q", got) + } + return + } + if err != nil { + t.Fatalf("copilotQuotaURLFromTokenURL returned error: %v", err) + } + if got != tt.wantURL { + t.Fatalf("copilotQuotaURLFromTokenURL = %q, want %q", got, tt.wantURL) + } + }) + } +} diff --git a/pkg/llmproxy/api/handlers/management/auth_files.go b/pkg/llmproxy/api/handlers/management/auth_files.go index 4e62818901..d9be804d14 100644 --- a/pkg/llmproxy/api/handlers/management/auth_files.go +++ b/pkg/llmproxy/api/handlers/management/auth_files.go @@ -15,6 +15,7 @@ import ( "net/http" "net/url" "os" + "path" "path/filepath" "sort" "strconv" @@ -136,6 +137,11 @@ func isWebUIRequest(c *gin.Context) bool { } func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) { + targetURL, errTarget := validateCallbackForwarderTarget(targetBase) + if errTarget != nil { + return nil, fmt.Errorf("invalid callback target: %w", errTarget) + } + callbackForwardersMu.Lock() prev := callbackForwarders[port] if prev != nil { @@ -154,16 +160,16 @@ func startCallbackForwarder(port int, provider, targetBase string) (*callbackFor } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - target := targetBase + target := *targetURL if raw := r.URL.RawQuery; raw != "" { - if strings.Contains(target, "?") { - target = target + "&" + raw + if target.RawQuery != "" { + target.RawQuery = target.RawQuery + "&" + raw } else { - target = target + "?" + raw + target.RawQuery = raw } } w.Header().Set("Cache-Control", "no-store") - http.Redirect(w, r, target, http.StatusFound) + http.Redirect(w, r, target.String(), http.StatusFound) }) srv := &http.Server{ @@ -195,6 +201,38 @@ func startCallbackForwarder(port int, provider, targetBase string) (*callbackFor return forwarder, nil } +func validateCallbackForwarderTarget(targetBase string) (*url.URL, error) { + trimmed := strings.TrimSpace(targetBase) + if trimmed == "" { + return nil, fmt.Errorf("target cannot be empty") + } + parsed, err := url.Parse(trimmed) + if err != nil { + return nil, fmt.Errorf("parse target: %w", err) + } + if !parsed.IsAbs() { + return nil, fmt.Errorf("target must be absolute") + } + scheme := strings.ToLower(parsed.Scheme) + if scheme != "http" && scheme != "https" { + return nil, fmt.Errorf("target scheme %q is not allowed", parsed.Scheme) + } + host := strings.ToLower(strings.TrimSpace(parsed.Hostname())) + if host == "" { + return nil, fmt.Errorf("target host is required") + } + if ip := net.ParseIP(host); ip != nil { + if !ip.IsLoopback() { + return nil, fmt.Errorf("target host must be loopback") + } + return parsed, nil + } + if host != "localhost" { + return nil, fmt.Errorf("target host must be localhost or loopback") + } + return parsed, nil +} + func stopCallbackForwarder(port int) { callbackForwardersMu.Lock() forwarder := callbackForwarders[port] @@ -243,9 +281,7 @@ func (h *Handler) managementCallbackURL(path string) (string, error) { if h == nil || h.cfg == nil || h.cfg.Port <= 0 { return "", fmt.Errorf("server port is not configured") } - if !strings.HasPrefix(path, "/") { - path = "/" + path - } + path = normalizeManagementCallbackPath(path) scheme := "http" if h.cfg.TLS.Enable { scheme = "https" @@ -253,6 +289,28 @@ func (h *Handler) managementCallbackURL(path string) (string, error) { return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, h.cfg.Port, path), nil } +func normalizeManagementCallbackPath(rawPath string) string { + normalized := strings.TrimSpace(rawPath) + normalized = strings.ReplaceAll(normalized, "\\", "/") + if idx := strings.IndexAny(normalized, "?#"); idx >= 0 { + normalized = normalized[:idx] + } + if normalized == "" { + return "/" + } + if !strings.HasPrefix(normalized, "/") { + normalized = "/" + normalized + } + normalized = path.Clean(normalized) + if normalized == "." { + return "/" + } + if !strings.HasPrefix(normalized, "/") { + return "/" + normalized + } + return normalized +} + func (h *Handler) ListAuthFiles(c *gin.Context) { if h == nil { c.JSON(500, gin.H{"error": "handler not initialized"}) @@ -510,8 +568,8 @@ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool { // Download single auth file by name func (h *Handler) DownloadAuthFile(c *gin.Context) { - name := c.Query("name") - if name == "" || strings.Contains(name, string(os.PathSeparator)) { + name := strings.TrimSpace(c.Query("name")) + if name == "" { c.JSON(400, gin.H{"error": "invalid name"}) return } @@ -519,7 +577,11 @@ func (h *Handler) DownloadAuthFile(c *gin.Context) { c.JSON(400, gin.H{"error": "name must end with .json"}) return } - full := filepath.Join(h.cfg.AuthDir, name) + full, err := misc.ResolveSafeFilePathInDir(h.cfg.AuthDir, name) + if err != nil { + c.JSON(400, gin.H{"error": "invalid name"}) + return + } data, err := os.ReadFile(full) if err != nil { if os.IsNotExist(err) { @@ -569,7 +631,8 @@ func (h *Handler) UploadAuthFile(c *gin.Context) { return } name := c.Query("name") - if name == "" || strings.Contains(name, string(os.PathSeparator)) { + name = strings.TrimSpace(name) + if name == "" { c.JSON(400, gin.H{"error": "invalid name"}) return } @@ -582,11 +645,10 @@ func (h *Handler) UploadAuthFile(c *gin.Context) { c.JSON(400, gin.H{"error": "failed to read body"}) return } - dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) - if !filepath.IsAbs(dst) { - if abs, errAbs := filepath.Abs(dst); errAbs == nil { - dst = abs - } + dst, err := misc.ResolveSafeFilePathInDir(h.cfg.AuthDir, name) + if err != nil { + c.JSON(400, gin.H{"error": "invalid name"}) + return } if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil { c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)}) @@ -639,16 +701,15 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) { c.JSON(200, gin.H{"status": "ok", "deleted": deleted}) return } - name := c.Query("name") - if name == "" || strings.Contains(name, string(os.PathSeparator)) { + name := strings.TrimSpace(c.Query("name")) + if name == "" { c.JSON(400, gin.H{"error": "invalid name"}) return } - full := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) - if !filepath.IsAbs(full) { - if abs, errAbs := filepath.Abs(full); errAbs == nil { - full = abs - } + full, err := misc.ResolveSafeFilePathInDir(h.cfg.AuthDir, name) + if err != nil { + c.JSON(400, gin.H{"error": "invalid name"}) + return } if err := os.Remove(full); err != nil { if os.IsNotExist(err) { @@ -684,16 +745,51 @@ func (h *Handler) authIDForPath(path string) string { return path } +func (h *Handler) resolveAuthPath(path string) (string, error) { + path = strings.TrimSpace(path) + if path == "" { + return "", fmt.Errorf("auth path is empty") + } + if h == nil || h.cfg == nil { + return "", fmt.Errorf("handler configuration unavailable") + } + authDir := strings.TrimSpace(h.cfg.AuthDir) + if authDir == "" { + return "", fmt.Errorf("auth directory not configured") + } + cleanAuthDir, err := filepath.Abs(filepath.Clean(authDir)) + if err != nil { + return "", fmt.Errorf("resolve auth dir: %w", err) + } + cleanPath := filepath.Clean(path) + absPath := cleanPath + if !filepath.IsAbs(absPath) { + absPath = filepath.Join(cleanAuthDir, cleanPath) + } + absPath, err = filepath.Abs(absPath) + if err != nil { + return "", fmt.Errorf("resolve auth path: %w", err) + } + relPath, err := filepath.Rel(cleanAuthDir, absPath) + if err != nil { + return "", fmt.Errorf("resolve relative auth path: %w", err) + } + if relPath == ".." || strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("auth path escapes auth directory") + } + return absPath, nil +} + func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error { if h.authManager == nil { return nil } - if path == "" { - return fmt.Errorf("auth path is empty") + safePath, err := h.resolveAuthPath(path) + if err != nil { + return err } if data == nil { - var err error - data, err = os.ReadFile(path) + data, err = os.ReadFile(safePath) if err != nil { return fmt.Errorf("failed to read auth file: %w", err) } @@ -712,18 +808,18 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data [] } lastRefresh, hasLastRefresh := extractLastRefreshTimestamp(metadata) - authID := h.authIDForPath(path) + authID := h.authIDForPath(safePath) if authID == "" { - authID = path + authID = safePath } attr := map[string]string{ - "path": path, - "source": path, + "path": safePath, + "source": safePath, } auth := &coreauth.Auth{ ID: authID, Provider: provider, - FileName: filepath.Base(path), + FileName: filepath.Base(safePath), Label: label, Status: coreauth.StatusActive, Attributes: attr, diff --git a/pkg/llmproxy/api/handlers/management/auth_files_callback_forwarder_test.go b/pkg/llmproxy/api/handlers/management/auth_files_callback_forwarder_test.go new file mode 100644 index 0000000000..9ef810b3c9 --- /dev/null +++ b/pkg/llmproxy/api/handlers/management/auth_files_callback_forwarder_test.go @@ -0,0 +1,31 @@ +package management + +import "testing" + +func TestValidateCallbackForwarderTargetAllowsLoopbackAndLocalhost(t *testing.T) { + cases := []string{ + "http://127.0.0.1:8080/callback", + "https://localhost:9999/callback?state=abc", + "http://[::1]:1455/callback", + } + for _, target := range cases { + if _, err := validateCallbackForwarderTarget(target); err != nil { + t.Fatalf("expected target %q to be allowed: %v", target, err) + } + } +} + +func TestValidateCallbackForwarderTargetRejectsNonLocalTargets(t *testing.T) { + cases := []string{ + "", + "/relative/callback", + "ftp://127.0.0.1/callback", + "http://example.com/callback", + "https://8.8.8.8/callback", + } + for _, target := range cases { + if _, err := validateCallbackForwarderTarget(target); err == nil { + t.Fatalf("expected target %q to be rejected", target) + } + } +} diff --git a/pkg/llmproxy/api/handlers/management/management_extra_test.go b/pkg/llmproxy/api/handlers/management/management_extra_test.go index 5f3ac4cb08..9dc4861240 100644 --- a/pkg/llmproxy/api/handlers/management/management_extra_test.go +++ b/pkg/llmproxy/api/handlers/management/management_extra_test.go @@ -338,6 +338,60 @@ func TestDeleteAuthFile(t *testing.T) { } } +func TestDownloadAuthFileRejectsTraversalName(t *testing.T) { + gin.SetMode(gin.TestMode) + tmpDir := t.TempDir() + h := &Handler{cfg: &config.Config{AuthDir: tmpDir}} + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("GET", "/?name=..\\evil.json", nil) + + h.DownloadAuthFile(c) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d, body: %s", w.Code, w.Body.String()) + } +} + +func TestUploadAuthFileRejectsTraversalName(t *testing.T) { + gin.SetMode(gin.TestMode) + tmpDir := t.TempDir() + h := &Handler{ + cfg: &config.Config{AuthDir: tmpDir}, + authManager: coreauth.NewManager(nil, nil, nil), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("POST", "/?name=..\\evil.json", strings.NewReader("{}")) + + h.UploadAuthFile(c) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d, body: %s", w.Code, w.Body.String()) + } +} + +func TestDeleteAuthFileRejectsTraversalName(t *testing.T) { + gin.SetMode(gin.TestMode) + tmpDir := t.TempDir() + h := &Handler{ + cfg: &config.Config{AuthDir: tmpDir}, + authManager: coreauth.NewManager(nil, nil, nil), + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest("DELETE", "/?name=..\\evil.json", nil) + + h.DeleteAuthFile(c) + + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d, body: %s", w.Code, w.Body.String()) + } +} + func TestIsReadOnlyConfigWriteError(t *testing.T) { if !isReadOnlyConfigWriteError(&os.PathError{Op: "open", Path: "/tmp/config.yaml", Err: syscall.EROFS}) { t.Fatal("expected EROFS path error to be treated as read-only config write error") diff --git a/pkg/llmproxy/api/handlers/management/management_modelstates_test.go b/pkg/llmproxy/api/handlers/management/management_modelstates_test.go index 44d0b203f5..af3074b05f 100644 --- a/pkg/llmproxy/api/handlers/management/management_modelstates_test.go +++ b/pkg/llmproxy/api/handlers/management/management_modelstates_test.go @@ -2,6 +2,8 @@ package management import ( "context" + "os" + "path/filepath" "testing" "time" @@ -56,3 +58,21 @@ func TestRegisterAuthFromFilePreservesModelStates(t *testing.T) { t.Fatalf("expected specific model state to be preserved") } } + +func TestRegisterAuthFromFileRejectsPathOutsideAuthDir(t *testing.T) { + authDir := t.TempDir() + outsidePath := filepath.Join(t.TempDir(), "outside.json") + if err := os.WriteFile(outsidePath, []byte(`{"type":"iflow"}`), 0o600); err != nil { + t.Fatalf("write outside auth file: %v", err) + } + + h := &Handler{ + cfg: &config.Config{AuthDir: authDir}, + authManager: coreauth.NewManager(nil, nil, nil), + } + + err := h.registerAuthFromFile(context.Background(), outsidePath, nil) + if err == nil { + t.Fatal("expected error for auth path outside auth directory") + } +} diff --git a/pkg/llmproxy/api/handlers/management/oauth_sessions.go b/pkg/llmproxy/api/handlers/management/oauth_sessions.go index bc882e990e..6b82a71679 100644 --- a/pkg/llmproxy/api/handlers/management/oauth_sessions.go +++ b/pkg/llmproxy/api/handlers/management/oauth_sessions.go @@ -251,10 +251,30 @@ type oauthCallbackFilePayload struct { Error string `json:"error"` } -func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) { - if strings.TrimSpace(authDir) == "" { +func sanitizeOAuthCallbackPath(authDir, fileName string) (string, error) { + trimmedAuthDir := strings.TrimSpace(authDir) + if trimmedAuthDir == "" { return "", fmt.Errorf("auth dir is empty") } + if fileName != filepath.Base(fileName) || strings.Contains(fileName, string(os.PathSeparator)) { + return "", fmt.Errorf("invalid oauth callback file name") + } + cleanAuthDir, err := filepath.Abs(filepath.Clean(trimmedAuthDir)) + if err != nil { + return "", fmt.Errorf("resolve auth dir: %w", err) + } + filePath := filepath.Join(cleanAuthDir, fileName) + relPath, err := filepath.Rel(cleanAuthDir, filePath) + if err != nil { + return "", fmt.Errorf("resolve oauth callback file path: %w", err) + } + if relPath == ".." || strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("invalid oauth callback file path") + } + return filePath, nil +} + +func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) { canonicalProvider, err := NormalizeOAuthProvider(provider) if err != nil { return "", err @@ -264,7 +284,13 @@ func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) } fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state) - filePath := filepath.Join(authDir, fileName) + filePath, err := sanitizeOAuthCallbackPath(authDir, fileName) + if err != nil { + return "", err + } + if err := os.MkdirAll(filepath.Dir(filePath), 0o700); err != nil { + return "", fmt.Errorf("create oauth callback dir: %w", err) + } payload := oauthCallbackFilePayload{ Code: strings.TrimSpace(code), State: strings.TrimSpace(state), diff --git a/pkg/llmproxy/api/handlers/management/oauth_sessions_test.go b/pkg/llmproxy/api/handlers/management/oauth_sessions_test.go new file mode 100644 index 0000000000..88ca0ab9a8 --- /dev/null +++ b/pkg/llmproxy/api/handlers/management/oauth_sessions_test.go @@ -0,0 +1,51 @@ +package management + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestWriteOAuthCallbackFile_WritesInsideAuthDir(t *testing.T) { + authDir := t.TempDir() + state := "safe-state-123" + + filePath, err := WriteOAuthCallbackFile(authDir, "claude", state, "code-1", "") + if err != nil { + t.Fatalf("WriteOAuthCallbackFile failed: %v", err) + } + + authDirAbs, err := filepath.Abs(authDir) + if err != nil { + t.Fatalf("resolve auth dir: %v", err) + } + filePathAbs, err := filepath.Abs(filePath) + if err != nil { + t.Fatalf("resolve callback path: %v", err) + } + prefix := authDirAbs + string(os.PathSeparator) + if filePathAbs != authDirAbs && !strings.HasPrefix(filePathAbs, prefix) { + t.Fatalf("callback path escaped auth dir: %q", filePathAbs) + } + + content, err := os.ReadFile(filePathAbs) + if err != nil { + t.Fatalf("read callback file: %v", err) + } + var payload oauthCallbackFilePayload + if err := json.Unmarshal(content, &payload); err != nil { + t.Fatalf("unmarshal callback file: %v", err) + } + if payload.State != state { + t.Fatalf("unexpected state: got %q want %q", payload.State, state) + } +} + +func TestSanitizeOAuthCallbackPath_RejectsInjectedFileName(t *testing.T) { + _, err := sanitizeOAuthCallbackPath(t.TempDir(), "../escape.oauth") + if err == nil { + t.Fatal("expected error for injected callback file name") + } +} diff --git a/pkg/llmproxy/api/middleware/response_writer.go b/pkg/llmproxy/api/middleware/response_writer.go index d57d9f8bad..b1fc7d5fef 100644 --- a/pkg/llmproxy/api/middleware/response_writer.go +++ b/pkg/llmproxy/api/middleware/response_writer.go @@ -388,6 +388,7 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h if w.requestInfo == nil { return nil } + requestHeaders := sanitizeRequestHeaders(http.Header(w.requestInfo.Headers)) if loggerWithOptions, ok := w.logger.(interface { LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error @@ -395,7 +396,7 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h return loggerWithOptions.LogRequestWithOptions( w.requestInfo.URL, w.requestInfo.Method, - w.requestInfo.Headers, + requestHeaders, requestBody, statusCode, headers, @@ -413,7 +414,7 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h return w.logger.LogRequest( w.requestInfo.URL, w.requestInfo.Method, - w.requestInfo.Headers, + requestHeaders, requestBody, statusCode, headers, diff --git a/pkg/llmproxy/api/modules/amp/response_rewriter.go b/pkg/llmproxy/api/modules/amp/response_rewriter.go index 8a9cad704d..b789aeacfb 100644 --- a/pkg/llmproxy/api/modules/amp/response_rewriter.go +++ b/pkg/llmproxy/api/modules/amp/response_rewriter.go @@ -25,12 +25,23 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe return &ResponseRewriter{ ResponseWriter: w, body: &bytes.Buffer{}, - originalModel: originalModel, + originalModel: sanitizeModelIDForResponse(originalModel), } } const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap +func sanitizeModelIDForResponse(modelID string) string { + modelID = strings.TrimSpace(modelID) + if modelID == "" { + return "" + } + if strings.ContainsAny(modelID, "<>\r\n\x00") { + return "" + } + return modelID +} + func looksLikeSSEChunk(data []byte) bool { // Fallback detection: some upstreams may omit/lie about Content-Type, causing SSE to be buffered. // Heuristics are intentionally simple and cheap. diff --git a/pkg/llmproxy/api/modules/amp/response_rewriter_test.go b/pkg/llmproxy/api/modules/amp/response_rewriter_test.go index 114a9516fc..bf4c99483b 100644 --- a/pkg/llmproxy/api/modules/amp/response_rewriter_test.go +++ b/pkg/llmproxy/api/modules/amp/response_rewriter_test.go @@ -100,6 +100,15 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) { } } +func TestSanitizeModelIDForResponse(t *testing.T) { + if got := sanitizeModelIDForResponse(" gpt-5.2-codex "); got != "gpt-5.2-codex" { + t.Fatalf("expected trimmed model id, got %q", got) + } + if got := sanitizeModelIDForResponse("gpt-5", want: false}, + {name: "empty", url: " ", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isValidURL(tt.url); got != tt.want { + t.Fatalf("isValidURL(%q) = %v, want %v", tt.url, got, tt.want) + } + }) + } +} + +func TestGenerateSuccessHTMLEscapesPlatformURL(t *testing.T) { + server := NewOAuthServer(9999) + malicious := `https://console.anthropic.com/" onclick="alert('xss')` + + rendered := server.generateSuccessHTML(true, malicious) + + if strings.Contains(rendered, malicious) { + t.Fatalf("rendered html contains unescaped platform URL") + } + if strings.Contains(rendered, `onclick="alert('xss')`) { + t.Fatalf("rendered html contains unescaped injected attribute") + } + if !strings.Contains(rendered, `https://console.anthropic.com/" onclick="alert('xss')`) { + t.Fatalf("rendered html does not contain expected escaped URL") + } +} diff --git a/pkg/llmproxy/auth/claude/token.go b/pkg/llmproxy/auth/claude/token.go index f04503e6da..757b03235f 100644 --- a/pkg/llmproxy/auth/claude/token.go +++ b/pkg/llmproxy/auth/claude/token.go @@ -8,10 +8,32 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/misc" ) +func sanitizeTokenFilePath(authFilePath string) (string, error) { + trimmed := strings.TrimSpace(authFilePath) + if trimmed == "" { + return "", fmt.Errorf("token file path is empty") + } + cleaned := filepath.Clean(trimmed) + parts := strings.FieldsFunc(cleaned, func(r rune) bool { + return r == '/' || r == '\\' + }) + for _, part := range parts { + if part == ".." { + return "", fmt.Errorf("invalid token file path") + } + } + absPath, err := filepath.Abs(cleaned) + if err != nil { + return "", fmt.Errorf("failed to resolve token file path: %w", err) + } + return absPath, nil +} + // ClaudeTokenStorage stores OAuth2 token information for Anthropic Claude API authentication. // It maintains compatibility with the existing auth system while adding Claude-specific fields // for managing access tokens, refresh tokens, and user account information. @@ -48,16 +70,24 @@ type ClaudeTokenStorage struct { // Returns: // - error: An error if the operation fails, nil otherwise func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("invalid token file path: %w", err) + } + misc.LogSavingCredentials(safePath) ts.Type = "claude" + safePath, err := sanitizeTokenFilePath(authFilePath) + if err != nil { + return err + } // Create directory structure if it doesn't exist - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } // Create the token file - f, err := os.Create(authFilePath) + f, err := os.Create(safePath) if err != nil { return fmt.Errorf("failed to create token file: %w", err) } diff --git a/pkg/llmproxy/auth/claude/token_test.go b/pkg/llmproxy/auth/claude/token_test.go new file mode 100644 index 0000000000..c7ae86845e --- /dev/null +++ b/pkg/llmproxy/auth/claude/token_test.go @@ -0,0 +1,10 @@ +package claude + +import "testing" + +func TestClaudeTokenStorage_SaveTokenToFileRejectsTraversalPath(t *testing.T) { + ts := &ClaudeTokenStorage{} + if err := ts.SaveTokenToFile("/tmp/../claude-escape.json"); err == nil { + t.Fatal("expected traversal path to be rejected") + } +} diff --git a/pkg/llmproxy/auth/codex/oauth_server.go b/pkg/llmproxy/auth/codex/oauth_server.go index c3674dc7bc..75bf193e11 100644 --- a/pkg/llmproxy/auth/codex/oauth_server.go +++ b/pkg/llmproxy/auth/codex/oauth_server.go @@ -4,8 +4,10 @@ import ( "context" "errors" "fmt" + "html" "net" "net/http" + "net/url" "strings" "sync" "time" @@ -256,7 +258,18 @@ func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) { // isValidURL checks if the URL is a valid http/https URL to prevent XSS func isValidURL(urlStr string) bool { urlStr = strings.TrimSpace(urlStr) - return strings.HasPrefix(urlStr, "https://") || strings.HasPrefix(urlStr, "http://") + if urlStr == "" || strings.ContainsAny(urlStr, "\"'<>") { + return false + } + parsed, err := url.Parse(urlStr) + if err != nil || !parsed.IsAbs() { + return false + } + scheme := strings.ToLower(parsed.Scheme) + if scheme != "https" && scheme != "http" { + return false + } + return strings.TrimSpace(parsed.Host) != "" } // generateSuccessHTML creates the HTML content for the success page. @@ -270,20 +283,21 @@ func isValidURL(urlStr string) bool { // Returns: // - string: The HTML content for the success page func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string { - html := LoginSuccessHtml + pageHTML := LoginSuccessHtml + escapedURL := html.EscapeString(platformURL) // Replace platform URL placeholder - html = strings.ReplaceAll(html, "{{PLATFORM_URL}}", platformURL) + pageHTML = strings.ReplaceAll(pageHTML, "{{PLATFORM_URL}}", escapedURL) // Add setup notice if required if setupRequired { - setupNotice := strings.ReplaceAll(SetupNoticeHtml, "{{PLATFORM_URL}}", platformURL) - html = strings.Replace(html, "{{SETUP_NOTICE}}", setupNotice, 1) + setupNotice := strings.ReplaceAll(SetupNoticeHtml, "{{PLATFORM_URL}}", escapedURL) + pageHTML = strings.Replace(pageHTML, "{{SETUP_NOTICE}}", setupNotice, 1) } else { - html = strings.Replace(html, "{{SETUP_NOTICE}}", "", 1) + pageHTML = strings.Replace(pageHTML, "{{SETUP_NOTICE}}", "", 1) } - return html + return pageHTML } // sendResult sends the OAuth result to the waiting channel. diff --git a/pkg/llmproxy/auth/codex/oauth_server_test.go b/pkg/llmproxy/auth/codex/oauth_server_test.go index 799255557e..47740feb2b 100644 --- a/pkg/llmproxy/auth/codex/oauth_server_test.go +++ b/pkg/llmproxy/auth/codex/oauth_server_test.go @@ -113,8 +113,11 @@ func TestIsValidURL(t *testing.T) { }{ {"https://example.com", true}, {"http://example.com", true}, + {" https://example.com/path?q=1 ", true}, {"javascript:alert(1)", false}, {"ftp://example.com", false}, + {"https://example.com\" onclick=\"alert(1)", false}, + {"https://", false}, } for _, tc := range cases { if isValidURL(tc.url) != tc.want { @@ -122,3 +125,15 @@ func TestIsValidURL(t *testing.T) { } } } + +func TestGenerateSuccessHTML_EscapesPlatformURL(t *testing.T) { + server := NewOAuthServer(1459) + malicious := `https://example.com" onclick="alert(1)` + got := server.generateSuccessHTML(true, malicious) + if strings.Contains(got, malicious) { + t.Fatalf("expected malicious URL to be escaped in HTML output") + } + if !strings.Contains(got, "https://example.com" onclick="alert(1)") { + t.Fatalf("expected escaped URL in HTML output, got: %s", got) + } +} diff --git a/pkg/llmproxy/auth/codex/token.go b/pkg/llmproxy/auth/codex/token.go index 3f7945c405..9e21f7bd16 100644 --- a/pkg/llmproxy/auth/codex/token.go +++ b/pkg/llmproxy/auth/codex/token.go @@ -8,10 +8,32 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/misc" ) +func sanitizeTokenFilePath(authFilePath string) (string, error) { + trimmed := strings.TrimSpace(authFilePath) + if trimmed == "" { + return "", fmt.Errorf("token file path is empty") + } + cleaned := filepath.Clean(trimmed) + parts := strings.FieldsFunc(cleaned, func(r rune) bool { + return r == '/' || r == '\\' + }) + for _, part := range parts { + if part == ".." { + return "", fmt.Errorf("invalid token file path") + } + } + absPath, err := filepath.Abs(cleaned) + if err != nil { + return "", fmt.Errorf("failed to resolve token file path: %w", err) + } + return absPath, nil +} + // CodexTokenStorage stores OAuth2 token information for OpenAI Codex API authentication. // It maintains compatibility with the existing auth system while adding Codex-specific fields // for managing access tokens, refresh tokens, and user account information. @@ -44,13 +66,17 @@ type CodexTokenStorage struct { // Returns: // - error: An error if the operation fails, nil otherwise func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("invalid token file path: %w", err) + } + misc.LogSavingCredentials(safePath) ts.Type = "codex" - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } - f, err := os.Create(authFilePath) + f, err := os.Create(safePath) if err != nil { return fmt.Errorf("failed to create token file: %w", err) } diff --git a/pkg/llmproxy/auth/codex/token_test.go b/pkg/llmproxy/auth/codex/token_test.go index c55d2af966..7188dc2986 100644 --- a/pkg/llmproxy/auth/codex/token_test.go +++ b/pkg/llmproxy/auth/codex/token_test.go @@ -59,3 +59,10 @@ func TestSaveTokenToFile_MkdirFail(t *testing.T) { t.Error("expected error for invalid directory path") } } + +func TestCodexTokenStorage_SaveTokenToFileRejectsTraversalPath(t *testing.T) { + ts := &CodexTokenStorage{} + if err := ts.SaveTokenToFile("/tmp/../codex-escape.json"); err == nil { + t.Fatal("expected traversal path to be rejected") + } +} diff --git a/pkg/llmproxy/auth/copilot/token.go b/pkg/llmproxy/auth/copilot/token.go index e1117c0690..fc013c5387 100644 --- a/pkg/llmproxy/auth/copilot/token.go +++ b/pkg/llmproxy/auth/copilot/token.go @@ -72,13 +72,17 @@ type DeviceCodeResponse struct { // Returns: // - error: An error if the operation fails, nil otherwise func (ts *CopilotTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("invalid token file path: %w", err) + } + misc.LogSavingCredentials(safePath) ts.Type = "github-copilot" - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } - f, err := os.Create(authFilePath) + f, err := os.Create(safePath) if err != nil { return fmt.Errorf("failed to create token file: %w", err) } diff --git a/pkg/llmproxy/auth/copilot/token_test.go b/pkg/llmproxy/auth/copilot/token_test.go index 7ed2ac3297..05e87960e2 100644 --- a/pkg/llmproxy/auth/copilot/token_test.go +++ b/pkg/llmproxy/auth/copilot/token_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "testing" ) @@ -40,3 +41,10 @@ func TestCopilotTokenStorage_SaveTokenToFile(t *testing.T) { t.Errorf("expected type github-copilot, got %s", tsLoaded.Type) } } + +func TestCopilotTokenStorage_SaveTokenToFileRejectsTraversalPath(t *testing.T) { + ts := &CopilotTokenStorage{} + if err := ts.SaveTokenToFile("/tmp/../copilot-escape.json"); err == nil { + t.Fatal("expected traversal path to be rejected") + } +} diff --git a/pkg/llmproxy/auth/diff/models_summary.go b/pkg/llmproxy/auth/diff/models_summary.go index faa82a7640..52e35e4968 100644 --- a/pkg/llmproxy/auth/diff/models_summary.go +++ b/pkg/llmproxy/auth/diff/models_summary.go @@ -1,8 +1,6 @@ package diff import ( - "crypto/sha256" - "encoding/hex" "sort" "strings" @@ -113,9 +111,8 @@ func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummar return VertexModelsSummary{} } sort.Strings(names) - sum := sha256.Sum256([]byte(strings.Join(names, "|"))) return VertexModelsSummary{ - hash: hex.EncodeToString(sum[:]), + hash: strings.Join(names, "|"), count: len(names), } } diff --git a/pkg/llmproxy/auth/diff/oauth_excluded_test.go b/pkg/llmproxy/auth/diff/oauth_excluded_test.go index 0fbcc27114..3577c3701e 100644 --- a/pkg/llmproxy/auth/diff/oauth_excluded_test.go +++ b/pkg/llmproxy/auth/diff/oauth_excluded_test.go @@ -98,6 +98,16 @@ func TestSummarizeVertexModels(t *testing.T) { } } +func TestSummarizeVertexModels_UsesCanonicalJoinedSignature(t *testing.T) { + summary := SummarizeVertexModels([]config.VertexCompatModel{ + {Name: "m1"}, + {Alias: "alias"}, + }) + if summary.hash != "alias|m1" { + t.Fatalf("expected canonical joined signature, got %q", summary.hash) + } +} + func expectContains(t *testing.T, list []string, target string) { t.Helper() for _, entry := range list { diff --git a/pkg/llmproxy/auth/diff/openai_compat.go b/pkg/llmproxy/auth/diff/openai_compat.go index 41726db3c3..99d136bdcb 100644 --- a/pkg/llmproxy/auth/diff/openai_compat.go +++ b/pkg/llmproxy/auth/diff/openai_compat.go @@ -1,8 +1,6 @@ package diff import ( - "crypto/sha256" - "encoding/hex" "fmt" "sort" "strings" @@ -178,6 +176,5 @@ func openAICompatSignature(entry config.OpenAICompatibility) string { if len(parts) == 0 { return "" } - sum := sha256.Sum256([]byte(strings.Join(parts, "|"))) - return hex.EncodeToString(sum[:]) + return strings.Join(parts, "|") } diff --git a/pkg/llmproxy/auth/diff/openai_compat_test.go b/pkg/llmproxy/auth/diff/openai_compat_test.go index 434d989d5b..029b24c0ed 100644 --- a/pkg/llmproxy/auth/diff/openai_compat_test.go +++ b/pkg/llmproxy/auth/diff/openai_compat_test.go @@ -163,6 +163,26 @@ func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) { } } +func TestOpenAICompatSignature_DoesNotIncludeRawAPIKeyMaterial(t *testing.T) { + entry := config.OpenAICompatibility{ + Name: "provider", + APIKeyEntries: []config.OpenAICompatibilityAPIKey{ + {APIKey: "super-secret-key"}, + {APIKey: "another-secret-key"}, + }, + } + sig := openAICompatSignature(entry) + if sig == "" { + t.Fatal("expected non-empty signature") + } + if strings.Contains(sig, "super-secret-key") || strings.Contains(sig, "another-secret-key") { + t.Fatalf("signature must not include API key values: %q", sig) + } + if !strings.Contains(sig, "api_keys=2") { + t.Fatalf("expected signature to keep api key count, got %q", sig) + } +} + func TestCountOpenAIModelsSkipsBlanks(t *testing.T) { models := []config.OpenAICompatibilityModel{ {Name: "m1"}, diff --git a/pkg/llmproxy/auth/gemini/gemini_auth_test.go b/pkg/llmproxy/auth/gemini/gemini_auth_test.go index efb5cb9c88..0ef8a1b36f 100644 --- a/pkg/llmproxy/auth/gemini/gemini_auth_test.go +++ b/pkg/llmproxy/auth/gemini/gemini_auth_test.go @@ -7,6 +7,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" "time" @@ -66,6 +67,19 @@ func TestGeminiTokenStorage_SaveAndLoad(t *testing.T) { } } +func TestGeminiTokenStorage_SaveTokenToFile_RejectsTraversalPath(t *testing.T) { + ts := &GeminiTokenStorage{Token: "raw-token-data"} + badPath := t.TempDir() + "/../gemini-token.json" + + err := ts.SaveTokenToFile(badPath) + if err == nil { + t.Fatal("expected error for traversal path") + } + if !strings.Contains(err.Error(), "invalid token file path") { + t.Fatalf("expected invalid path error, got %v", err) + } +} + func TestGeminiAuth_CreateTokenStorage(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/oauth2/v1/userinfo" { diff --git a/pkg/llmproxy/auth/gemini/gemini_token.go b/pkg/llmproxy/auth/gemini/gemini_token.go index 3091d30486..b06e0f8532 100644 --- a/pkg/llmproxy/auth/gemini/gemini_token.go +++ b/pkg/llmproxy/auth/gemini/gemini_token.go @@ -47,13 +47,17 @@ type GeminiTokenStorage struct { // Returns: // - error: An error if the operation fails, nil otherwise func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("invalid token file path: %w", err) + } + misc.LogSavingCredentials(safePath) ts.Type = "gemini" - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } - f, err := os.Create(authFilePath) + f, err := os.Create(safePath) if err != nil { return fmt.Errorf("failed to create token file: %w", err) } diff --git a/pkg/llmproxy/auth/gemini/gemini_token_test.go b/pkg/llmproxy/auth/gemini/gemini_token_test.go new file mode 100644 index 0000000000..025c943792 --- /dev/null +++ b/pkg/llmproxy/auth/gemini/gemini_token_test.go @@ -0,0 +1,10 @@ +package gemini + +import "testing" + +func TestGeminiTokenStorage_SaveTokenToFileRejectsTraversalPath(t *testing.T) { + ts := &GeminiTokenStorage{} + if err := ts.SaveTokenToFile("/tmp/../gemini-escape.json"); err == nil { + t.Fatal("expected traversal path to be rejected") + } +} diff --git a/pkg/llmproxy/auth/iflow/iflow_token.go b/pkg/llmproxy/auth/iflow/iflow_token.go index 9d6ad2328b..c75dd5ec34 100644 --- a/pkg/llmproxy/auth/iflow/iflow_token.go +++ b/pkg/llmproxy/auth/iflow/iflow_token.go @@ -25,13 +25,17 @@ type IFlowTokenStorage struct { // SaveTokenToFile serialises the token storage to disk. func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("invalid token file path: %w", err) + } + misc.LogSavingCredentials(safePath) ts.Type = "iflow" - if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil { + if err = os.MkdirAll(filepath.Dir(safePath), 0o700); err != nil { return fmt.Errorf("iflow token: create directory failed: %w", err) } - f, err := os.Create(authFilePath) + f, err := os.Create(safePath) if err != nil { return fmt.Errorf("iflow token: create file failed: %w", err) } diff --git a/pkg/llmproxy/auth/iflow/iflow_token_test.go b/pkg/llmproxy/auth/iflow/iflow_token_test.go new file mode 100644 index 0000000000..cb178a59c6 --- /dev/null +++ b/pkg/llmproxy/auth/iflow/iflow_token_test.go @@ -0,0 +1,10 @@ +package iflow + +import "testing" + +func TestIFlowTokenStorage_SaveTokenToFileRejectsTraversalPath(t *testing.T) { + ts := &IFlowTokenStorage{} + if err := ts.SaveTokenToFile("/tmp/../iflow-escape.json"); err == nil { + t.Fatal("expected traversal path to be rejected") + } +} diff --git a/pkg/llmproxy/auth/kilo/kilo_token.go b/pkg/llmproxy/auth/kilo/kilo_token.go index 029cca6ef0..6a5fa30ee7 100644 --- a/pkg/llmproxy/auth/kilo/kilo_token.go +++ b/pkg/llmproxy/auth/kilo/kilo_token.go @@ -32,13 +32,17 @@ type KiloTokenStorage struct { // SaveTokenToFile serializes the Kilo token storage to a JSON file. func (ts *KiloTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("invalid token file path: %w", err) + } + misc.LogSavingCredentials(safePath) ts.Type = "kilo" - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } - f, err := os.Create(authFilePath) + f, err := os.Create(safePath) if err != nil { return fmt.Errorf("failed to create token file: %w", err) } diff --git a/pkg/llmproxy/auth/kilo/kilo_token_test.go b/pkg/llmproxy/auth/kilo/kilo_token_test.go new file mode 100644 index 0000000000..9b0785990a --- /dev/null +++ b/pkg/llmproxy/auth/kilo/kilo_token_test.go @@ -0,0 +1,10 @@ +package kilo + +import "testing" + +func TestKiloTokenStorage_SaveTokenToFileRejectsTraversalPath(t *testing.T) { + ts := &KiloTokenStorage{} + if err := ts.SaveTokenToFile("/tmp/../kilo-escape.json"); err == nil { + t.Fatal("expected traversal path to be rejected") + } +} diff --git a/pkg/llmproxy/auth/kimi/token.go b/pkg/llmproxy/auth/kimi/token.go index 39c8e94c05..29fb3ea6f6 100644 --- a/pkg/llmproxy/auth/kimi/token.go +++ b/pkg/llmproxy/auth/kimi/token.go @@ -71,14 +71,18 @@ type DeviceCodeResponse struct { // SaveTokenToFile serializes the Kimi token storage to a JSON file. func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error { - misc.LogSavingCredentials(authFilePath) + safePath, err := misc.ResolveSafeFilePath(authFilePath) + if err != nil { + return fmt.Errorf("invalid token file path: %w", err) + } + misc.LogSavingCredentials(safePath) ts.Type = "kimi" - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + if err = os.MkdirAll(filepath.Dir(safePath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } - f, err := os.Create(authFilePath) + f, err := os.Create(safePath) if err != nil { return fmt.Errorf("failed to create token file: %w", err) } diff --git a/pkg/llmproxy/auth/kimi/token_path_test.go b/pkg/llmproxy/auth/kimi/token_path_test.go new file mode 100644 index 0000000000..c4b27147e6 --- /dev/null +++ b/pkg/llmproxy/auth/kimi/token_path_test.go @@ -0,0 +1,19 @@ +package kimi + +import ( + "strings" + "testing" +) + +func TestKimiTokenStorage_SaveTokenToFile_RejectsTraversalPath(t *testing.T) { + ts := &KimiTokenStorage{AccessToken: "token"} + badPath := t.TempDir() + "/../kimi-token.json" + + err := ts.SaveTokenToFile(badPath) + if err == nil { + t.Fatal("expected error for traversal path") + } + if !strings.Contains(err.Error(), "invalid token file path") { + t.Fatalf("expected invalid path error, got %v", err) + } +} diff --git a/pkg/llmproxy/auth/kimi/token_test.go b/pkg/llmproxy/auth/kimi/token_test.go new file mode 100644 index 0000000000..36475e6449 --- /dev/null +++ b/pkg/llmproxy/auth/kimi/token_test.go @@ -0,0 +1,10 @@ +package kimi + +import "testing" + +func TestKimiTokenStorage_SaveTokenToFileRejectsTraversalPath(t *testing.T) { + ts := &KimiTokenStorage{} + if err := ts.SaveTokenToFile("/tmp/../kimi-escape.json"); err == nil { + t.Fatal("expected traversal path to be rejected") + } +} diff --git a/pkg/llmproxy/auth/kiro/sso_oidc.go b/pkg/llmproxy/auth/kiro/sso_oidc.go index 3ad450db74..f449bbca07 100644 --- a/pkg/llmproxy/auth/kiro/sso_oidc.go +++ b/pkg/llmproxy/auth/kiro/sso_oidc.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "os" + "regexp" "strings" "time" @@ -52,6 +53,7 @@ const ( var ( ErrAuthorizationPending = errors.New("authorization_pending") ErrSlowDown = errors.New("slow_down") + awsRegionPattern = regexp.MustCompile(`^[a-z]{2}(?:-[a-z0-9]+)+-\d+$`) ) // SSOOIDCClient handles AWS SSO OIDC authentication. @@ -106,6 +108,17 @@ func getOIDCEndpoint(region string) string { return fmt.Sprintf("https://oidc.%s.amazonaws.com", region) } +func validateIDCRegion(region string) (string, error) { + region = strings.TrimSpace(region) + if region == "" { + return defaultIDCRegion, nil + } + if !awsRegionPattern.MatchString(region) { + return "", fmt.Errorf("invalid region %q", region) + } + return region, nil +} + func buildIDCRefreshPayload(clientID, clientSecret, refreshToken string) map[string]string { return map[string]string{ "clientId": clientID, @@ -184,7 +197,11 @@ func promptSelect(prompt string, options []string) int { // RegisterClientWithRegion registers a new OIDC client with AWS using a specific region. func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region string) (*RegisterClientResponse, error) { - endpoint := getOIDCEndpoint(region) + validatedRegion, err := validateIDCRegion(region) + if err != nil { + return nil, err + } + endpoint := getOIDCEndpoint(validatedRegion) payload := map[string]interface{}{ "clientName": "Kiro IDE", @@ -231,7 +248,11 @@ func (c *SSOOIDCClient) RegisterClientWithRegion(ctx context.Context, region str // StartDeviceAuthorizationWithIDC starts the device authorization flow for IDC. func (c *SSOOIDCClient) StartDeviceAuthorizationWithIDC(ctx context.Context, clientID, clientSecret, startURL, region string) (*StartDeviceAuthResponse, error) { - endpoint := getOIDCEndpoint(region) + validatedRegion, err := validateIDCRegion(region) + if err != nil { + return nil, err + } + endpoint := getOIDCEndpoint(validatedRegion) payload := map[string]string{ "clientId": clientID, diff --git a/pkg/llmproxy/auth/kiro/sso_oidc_test.go b/pkg/llmproxy/auth/kiro/sso_oidc_test.go index 350716a745..5979457216 100644 --- a/pkg/llmproxy/auth/kiro/sso_oidc_test.go +++ b/pkg/llmproxy/auth/kiro/sso_oidc_test.go @@ -8,12 +8,6 @@ import ( "testing" ) -type roundTripperFunc func(*http.Request) (*http.Response, error) - -func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req) -} - func TestRefreshToken_IncludesGrantTypeAndExtensionHeaders(t *testing.T) { t.Parallel() @@ -106,3 +100,39 @@ func TestRefreshTokenWithRegion_UsesRegionHostAndGrantType(t *testing.T) { t.Fatalf("unexpected token data: %#v", got) } } + +func TestRegisterClientWithRegion_RejectsInvalidRegion(t *testing.T) { + t.Parallel() + + client := &SSOOIDCClient{ + httpClient: &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + t.Fatalf("unexpected outbound request: %s", req.URL.String()) + return nil, nil + }), + }, + } + + _, err := client.RegisterClientWithRegion(context.Background(), "us-east-1\nmalicious") + if err == nil { + t.Fatalf("expected invalid region error") + } +} + +func TestStartDeviceAuthorizationWithIDC_RejectsInvalidRegion(t *testing.T) { + t.Parallel() + + client := &SSOOIDCClient{ + httpClient: &http.Client{ + Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) { + t.Fatalf("unexpected outbound request: %s", req.URL.String()) + return nil, nil + }), + }, + } + + _, err := client.StartDeviceAuthorizationWithIDC(context.Background(), "cid", "secret", "https://view.awsapps.com/start", "../../etc/passwd") + if err == nil { + t.Fatalf("expected invalid region error") + } +} diff --git a/pkg/llmproxy/auth/kiro/token.go b/pkg/llmproxy/auth/kiro/token.go index 0484a2dc6d..7df2911107 100644 --- a/pkg/llmproxy/auth/kiro/token.go +++ b/pkg/llmproxy/auth/kiro/token.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "strings" ) // KiroTokenStorage holds the persistent token data for Kiro authentication. @@ -39,7 +40,11 @@ type KiroTokenStorage struct { // SaveTokenToFile persists the token storage to the specified file path. func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error { - dir := filepath.Dir(authFilePath) + cleanPath, err := cleanTokenPath(authFilePath, "kiro token") + if err != nil { + return err + } + dir := filepath.Dir(cleanPath) if err := os.MkdirAll(dir, 0700); err != nil { return fmt.Errorf("failed to create directory: %w", err) } @@ -49,13 +54,29 @@ func (s *KiroTokenStorage) SaveTokenToFile(authFilePath string) error { return fmt.Errorf("failed to marshal token storage: %w", err) } - if err := os.WriteFile(authFilePath, data, 0600); err != nil { + if err := os.WriteFile(cleanPath, data, 0600); err != nil { return fmt.Errorf("failed to write token file: %w", err) } return nil } +func cleanTokenPath(path, scope string) (string, error) { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "", fmt.Errorf("%s: auth file path is empty", scope) + } + clean := filepath.Clean(filepath.FromSlash(trimmed)) + if clean == "." || clean == ".." || strings.HasPrefix(clean, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("%s: auth file path is invalid", scope) + } + abs, err := filepath.Abs(clean) + if err != nil { + return "", fmt.Errorf("%s: resolve auth file path: %w", scope, err) + } + return filepath.Clean(abs), nil +} + // LoadFromFile loads token storage from the specified file path. func LoadFromFile(authFilePath string) (*KiroTokenStorage, error) { data, err := os.ReadFile(authFilePath) diff --git a/pkg/llmproxy/auth/kiro/token_extra_test.go b/pkg/llmproxy/auth/kiro/token_extra_test.go index 6c8c75ad85..32bd04e20f 100644 --- a/pkg/llmproxy/auth/kiro/token_extra_test.go +++ b/pkg/llmproxy/auth/kiro/token_extra_test.go @@ -3,6 +3,7 @@ package kiro import ( "os" "path/filepath" + "strings" "testing" ) @@ -51,3 +52,16 @@ func TestLoadFromFile_Errors(t *testing.T) { t.Error("expected error for invalid JSON") } } + +func TestKiroTokenStorageSaveTokenToFileRejectsTraversalPath(t *testing.T) { + t.Parallel() + + ts := &KiroTokenStorage{Type: "kiro", AccessToken: "token"} + err := ts.SaveTokenToFile("../kiro-token.json") + if err == nil { + t.Fatal("expected error for traversal path") + } + if !strings.Contains(err.Error(), "auth file path is invalid") { + t.Fatalf("expected invalid path error, got %v", err) + } +} diff --git a/pkg/llmproxy/auth/qwen/qwen_auth_test.go b/pkg/llmproxy/auth/qwen/qwen_auth_test.go index 8e66f4bf8e..10ad23d666 100644 --- a/pkg/llmproxy/auth/qwen/qwen_auth_test.go +++ b/pkg/llmproxy/auth/qwen/qwen_auth_test.go @@ -82,3 +82,16 @@ func TestRefreshTokens(t *testing.T) { t.Errorf("got access token %q, want new-access", resp.AccessToken) } } + +func TestQwenTokenStorageSaveTokenToFileRejectsTraversalPath(t *testing.T) { + t.Parallel() + + ts := &QwenTokenStorage{AccessToken: "token"} + err := ts.SaveTokenToFile("../qwen.json") + if err == nil { + t.Fatal("expected error for traversal path") + } + if !strings.Contains(err.Error(), "auth file path is invalid") { + t.Fatalf("expected invalid path error, got %v", err) + } +} diff --git a/pkg/llmproxy/auth/qwen/qwen_token.go b/pkg/llmproxy/auth/qwen/qwen_token.go index b17df8e346..10104bf89c 100644 --- a/pkg/llmproxy/auth/qwen/qwen_token.go +++ b/pkg/llmproxy/auth/qwen/qwen_token.go @@ -8,6 +8,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/misc" ) @@ -44,11 +45,15 @@ type QwenTokenStorage struct { func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error { misc.LogSavingCredentials(authFilePath) ts.Type = "qwen" - if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { + cleanPath, err := cleanTokenFilePath(authFilePath, "qwen token") + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(cleanPath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } - f, err := os.Create(authFilePath) + f, err := os.Create(cleanPath) if err != nil { return fmt.Errorf("failed to create token file: %w", err) } @@ -61,3 +66,19 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error { } return nil } + +func cleanTokenFilePath(path, scope string) (string, error) { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "", fmt.Errorf("%s: auth file path is empty", scope) + } + clean := filepath.Clean(filepath.FromSlash(trimmed)) + if clean == "." || clean == ".." || strings.HasPrefix(clean, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("%s: auth file path is invalid", scope) + } + abs, err := filepath.Abs(clean) + if err != nil { + return "", fmt.Errorf("%s: resolve auth file path: %w", scope, err) + } + return filepath.Clean(abs), nil +} diff --git a/pkg/llmproxy/auth/qwen/qwen_token_test.go b/pkg/llmproxy/auth/qwen/qwen_token_test.go new file mode 100644 index 0000000000..3fb4881ab5 --- /dev/null +++ b/pkg/llmproxy/auth/qwen/qwen_token_test.go @@ -0,0 +1,36 @@ +package qwen + +import ( + "os" + "path/filepath" + "testing" +) + +func TestQwenTokenStorage_SaveTokenToFile(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "qwen-token.json") + ts := &QwenTokenStorage{ + AccessToken: "access", + Email: "test@example.com", + } + + if err := ts.SaveTokenToFile(path); err != nil { + t.Fatalf("SaveTokenToFile failed: %v", err) + } + if _, err := os.Stat(path); err != nil { + t.Fatalf("expected token file to exist: %v", err) + } +} + +func TestQwenTokenStorage_SaveTokenToFile_RejectsTraversalPath(t *testing.T) { + t.Parallel() + + ts := &QwenTokenStorage{ + AccessToken: "access", + } + if err := ts.SaveTokenToFile("../qwen-token.json"); err == nil { + t.Fatal("expected traversal path to be rejected") + } +} diff --git a/pkg/llmproxy/auth/synthesizer/helpers.go b/pkg/llmproxy/auth/synthesizer/helpers.go index 3ee77354c5..a1c7ac4387 100644 --- a/pkg/llmproxy/auth/synthesizer/helpers.go +++ b/pkg/llmproxy/auth/synthesizer/helpers.go @@ -1,7 +1,8 @@ package synthesizer import ( - "crypto/sha256" + "crypto/hmac" + "crypto/sha512" "encoding/hex" "fmt" "sort" @@ -12,8 +13,10 @@ import ( coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) +const stableIDGeneratorHashKey = "auth-stable-id-generator:v1" + // StableIDGenerator generates stable, deterministic IDs for auth entries. -// It uses SHA256 hashing with collision handling via counters. +// It uses keyed HMAC-SHA512 hashing with collision handling via counters. // It is not safe for concurrent use. type StableIDGenerator struct { counters map[string]int @@ -30,7 +33,7 @@ func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string) if g == nil { return kind + ":000000000000", "000000000000" } - hasher := sha256.New() + hasher := hmac.New(sha512.New, []byte(stableIDGeneratorHashKey)) hasher.Write([]byte(kind)) for _, part := range parts { trimmed := strings.TrimSpace(part) diff --git a/pkg/llmproxy/auth/synthesizer/helpers_test.go b/pkg/llmproxy/auth/synthesizer/helpers_test.go index b21d3e109a..5840f6716e 100644 --- a/pkg/llmproxy/auth/synthesizer/helpers_test.go +++ b/pkg/llmproxy/auth/synthesizer/helpers_test.go @@ -1,6 +1,8 @@ package synthesizer import ( + "crypto/sha256" + "encoding/hex" "reflect" "strings" "testing" @@ -10,6 +12,26 @@ import ( coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) +func TestStableIDGenerator_Next_DoesNotUseLegacySHA256(t *testing.T) { + gen := NewStableIDGenerator() + id, short := gen.Next("gemini:apikey", "test-key", "https://api.example.com") + if id == "" || short == "" { + t.Fatal("expected generated IDs to be non-empty") + } + + legacyHasher := sha256.New() + legacyHasher.Write([]byte("gemini:apikey")) + legacyHasher.Write([]byte{0}) + legacyHasher.Write([]byte("test-key")) + legacyHasher.Write([]byte{0}) + legacyHasher.Write([]byte("https://api.example.com")) + legacyShort := hex.EncodeToString(legacyHasher.Sum(nil))[:12] + + if short == legacyShort { + t.Fatalf("expected short id to differ from legacy sha256 digest %q", legacyShort) + } +} + func TestNewStableIDGenerator(t *testing.T) { gen := NewStableIDGenerator() if gen == nil { diff --git a/pkg/llmproxy/auth/vertex/vertex_credentials.go b/pkg/llmproxy/auth/vertex/vertex_credentials.go index de01dd9440..2d8c107662 100644 --- a/pkg/llmproxy/auth/vertex/vertex_credentials.go +++ b/pkg/llmproxy/auth/vertex/vertex_credentials.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/misc" log "github.com/sirupsen/logrus" @@ -44,11 +45,15 @@ func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) error { } // Ensure we tag the file with the provider type. s.Type = "vertex" + cleanPath, err := cleanCredentialPath(authFilePath, "vertex credential") + if err != nil { + return err + } - if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil { + if err := os.MkdirAll(filepath.Dir(cleanPath), 0o700); err != nil { return fmt.Errorf("vertex credential: create directory failed: %w", err) } - f, err := os.Create(authFilePath) + f, err := os.Create(cleanPath) if err != nil { return fmt.Errorf("vertex credential: create file failed: %w", err) } @@ -64,3 +69,19 @@ func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) error { } return nil } + +func cleanCredentialPath(path, scope string) (string, error) { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "", fmt.Errorf("%s: auth file path is empty", scope) + } + clean := filepath.Clean(filepath.FromSlash(trimmed)) + if clean == "." || clean == ".." || strings.HasPrefix(clean, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("%s: auth file path is invalid", scope) + } + abs, err := filepath.Abs(clean) + if err != nil { + return "", fmt.Errorf("%s: resolve auth file path: %w", scope, err) + } + return filepath.Clean(abs), nil +} diff --git a/pkg/llmproxy/auth/vertex/vertex_credentials_test.go b/pkg/llmproxy/auth/vertex/vertex_credentials_test.go index d69c4311f1..91947892a1 100644 --- a/pkg/llmproxy/auth/vertex/vertex_credentials_test.go +++ b/pkg/llmproxy/auth/vertex/vertex_credentials_test.go @@ -3,6 +3,7 @@ package vertex import ( "os" "path/filepath" + "strings" "testing" ) @@ -47,3 +48,19 @@ func TestVertexCredentialStorage_NilChecks(t *testing.T) { t.Error("expected error for empty service account") } } + +func TestVertexCredentialStorage_SaveTokenToFileRejectsTraversalPath(t *testing.T) { + t.Parallel() + + s := &VertexCredentialStorage{ + ServiceAccount: map[string]any{"project_id": "p"}, + } + + err := s.SaveTokenToFile("../vertex.json") + if err == nil { + t.Fatal("expected error for traversal path") + } + if !strings.Contains(err.Error(), "auth file path is invalid") { + t.Fatalf("expected invalid path error, got %v", err) + } +} diff --git a/pkg/llmproxy/cmd/iflow_cookie.go b/pkg/llmproxy/cmd/iflow_cookie.go index 176eb7480a..f400723c4d 100644 --- a/pkg/llmproxy/cmd/iflow_cookie.go +++ b/pkg/llmproxy/cmd/iflow_cookie.go @@ -71,7 +71,7 @@ func DoIFlowCookieAuth(cfg *config.Config, options *LoginOptions) { return } - fmt.Printf("Authentication successful! API key: %s\n", tokenData.APIKey) + fmt.Println("Authentication successful.") fmt.Printf("Expires at: %s\n", tokenData.Expire) fmt.Printf("Authentication saved to: %s\n", authFilePath) } diff --git a/pkg/llmproxy/config/config.go b/pkg/llmproxy/config/config.go index 82e0732a89..fde2a02d89 100644 --- a/pkg/llmproxy/config/config.go +++ b/pkg/llmproxy/config/config.go @@ -1648,12 +1648,20 @@ func appendPath(path []string, key string) []string { if len(path) == 0 { return []string{key} } - newPath := make([]string, len(path)+1) + newPath := make([]string, checkedPathLengthPlusOne(len(path))) copy(newPath, path) newPath[len(path)] = key return newPath } +func checkedPathLengthPlusOne(pathLen int) int { + maxInt := int(^uint(0) >> 1) + if pathLen < 0 || pathLen >= maxInt { + panic(fmt.Sprintf("path length overflow: %d", pathLen)) + } + return pathLen + 1 +} + // isKnownDefaultValue returns true if the given node at the specified path // represents a known default value that should not be written to the config file. // This prevents non-zero defaults from polluting the config. diff --git a/pkg/llmproxy/config/config_test.go b/pkg/llmproxy/config/config_test.go index a18c5a6dcf..baa88143a3 100644 --- a/pkg/llmproxy/config/config_test.go +++ b/pkg/llmproxy/config/config_test.go @@ -79,3 +79,17 @@ func TestLoadConfigOptional_DirectoryPath(t *testing.T) { t.Fatal("expected non-nil config for optional directory config path") } } + +func TestCheckedPathLengthPlusOne(t *testing.T) { + if got := checkedPathLengthPlusOne(4); got != 5 { + t.Fatalf("expected 5, got %d", got) + } + + maxInt := int(^uint(0) >> 1) + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic for overflow path length") + } + }() + _ = checkedPathLengthPlusOne(maxInt) +} diff --git a/pkg/llmproxy/executor/antigravity_executor.go b/pkg/llmproxy/executor/antigravity_executor.go index d456b50951..79c6efd5a2 100644 --- a/pkg/llmproxy/executor/antigravity_executor.go +++ b/pkg/llmproxy/executor/antigravity_executor.go @@ -9,6 +9,7 @@ import ( "context" "crypto/sha256" "encoding/binary" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -213,7 +214,7 @@ attemptLoop: } if attempt+1 < attempts { delay := antigravityNoCapacityRetryDelay(attempt) - log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts) + log.Debugf("antigravity executor: no capacity, retrying in %s (attempt %d/%d)", delay, attempt+1, attempts) if errWait := antigravityWait(ctx, delay); errWait != nil { return resp, errWait } @@ -258,6 +259,15 @@ attemptLoop: return resp, err } +func antigravityModelFingerprint(model string) string { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return "" + } + sum := sha256.Sum256([]byte(trimmed)) + return hex.EncodeToString(sum[:8]) +} + // executeClaudeNonStream performs a claude non-streaming request to the Antigravity API. func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { baseModel := thinking.ParseSuffix(req.Model).ModelName @@ -758,7 +768,7 @@ attemptLoop: } if attempt+1 < attempts { delay := antigravityNoCapacityRetryDelay(attempt) - log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts) + log.Debugf("antigravity executor: no capacity, retrying in %s (attempt %d/%d)", delay, attempt+1, attempts) if errWait := antigravityWait(ctx, delay); errWait != nil { return nil, errWait } @@ -905,6 +915,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut if base == "" { base = buildBaseURL(e.cfg, auth) } + base, err = sanitizeAntigravityBaseURL(base) + if err != nil { + return cliproxyexecutor.Response{}, err + } var requestURL strings.Builder requestURL.WriteString(base) @@ -1530,6 +1544,16 @@ func resolveHost(base string) string { return strings.TrimPrefix(strings.TrimPrefix(base, "https://"), "http://") } +func sanitizeAntigravityBaseURL(base string) (string, error) { + normalized := strings.TrimSuffix(strings.TrimSpace(base), "/") + switch normalized { + case antigravityBaseURLDaily, antigravitySandboxBaseURLDaily, antigravityBaseURLProd: + return normalized, nil + default: + return "", fmt.Errorf("antigravity executor: unsupported base url %q", base) + } +} + func resolveUserAgent(auth *cliproxyauth.Auth) string { if auth != nil { if auth.Attributes != nil { diff --git a/pkg/llmproxy/executor/antigravity_executor_logging_test.go b/pkg/llmproxy/executor/antigravity_executor_logging_test.go new file mode 100644 index 0000000000..ce17fad150 --- /dev/null +++ b/pkg/llmproxy/executor/antigravity_executor_logging_test.go @@ -0,0 +1,14 @@ +package executor + +import "testing" + +func TestAntigravityModelFingerprint_RedactsRawModel(t *testing.T) { + raw := "my-sensitive-model-name" + got := antigravityModelFingerprint(raw) + if got == "" { + t.Fatal("expected non-empty fingerprint") + } + if got == raw { + t.Fatalf("fingerprint must not equal raw model: %q", got) + } +} diff --git a/pkg/llmproxy/executor/antigravity_executor_security_test.go b/pkg/llmproxy/executor/antigravity_executor_security_test.go new file mode 100644 index 0000000000..4f44c62c6b --- /dev/null +++ b/pkg/llmproxy/executor/antigravity_executor_security_test.go @@ -0,0 +1,30 @@ +package executor + +import "testing" + +func TestSanitizeAntigravityBaseURL_AllowsKnownHosts(t *testing.T) { + t.Parallel() + + cases := []string{ + antigravityBaseURLDaily, + antigravitySandboxBaseURLDaily, + antigravityBaseURLProd, + } + for _, base := range cases { + got, err := sanitizeAntigravityBaseURL(base) + if err != nil { + t.Fatalf("sanitizeAntigravityBaseURL(%q) error: %v", base, err) + } + if got != base { + t.Fatalf("sanitizeAntigravityBaseURL(%q) = %q, want %q", base, got, base) + } + } +} + +func TestSanitizeAntigravityBaseURL_RejectsUntrustedHost(t *testing.T) { + t.Parallel() + + if _, err := sanitizeAntigravityBaseURL("https://127.0.0.1:8080"); err == nil { + t.Fatal("expected error for untrusted antigravity base URL") + } +} diff --git a/pkg/llmproxy/executor/codex_websockets_executor.go b/pkg/llmproxy/executor/codex_websockets_executor.go index a5e820af33..598b87e5ea 100644 --- a/pkg/llmproxy/executor/codex_websockets_executor.go +++ b/pkg/llmproxy/executor/codex_websockets_executor.go @@ -5,6 +5,8 @@ package executor import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "fmt" "io" "net" @@ -399,7 +401,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { - log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model) + log.Debug("Executing Codex Websockets stream request") if ctx == nil { ctx = context.Background() } @@ -1295,15 +1297,34 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess } func logCodexWebsocketConnected(sessionID string, authID string, wsURL string) { - log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL)) + log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL)) } -func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string, reason string, err error) { +func logCodexWebsocketDisconnected(sessionID string, _ string, _ string, reason string, err error) { if err != nil { - log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason), err) + log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL), strings.TrimSpace(reason), err) return } - log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason)) + log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL), strings.TrimSpace(reason)) +} + +func sanitizeCodexWebsocketLogField(raw string) string { + return util.HideAPIKey(strings.TrimSpace(raw)) +} + +func sanitizeCodexWebsocketLogURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || !parsed.IsAbs() { + return util.HideAPIKey(trimmed) + } + parsed.User = nil + parsed.Fragment = "" + parsed.RawQuery = util.MaskSensitiveQuery(parsed.RawQuery) + return parsed.String() } // CodexAutoExecutor routes Codex requests to the websocket transport only when: diff --git a/pkg/llmproxy/executor/codex_websockets_executor_logging_test.go b/pkg/llmproxy/executor/codex_websockets_executor_logging_test.go new file mode 100644 index 0000000000..6fc69acef1 --- /dev/null +++ b/pkg/llmproxy/executor/codex_websockets_executor_logging_test.go @@ -0,0 +1,28 @@ +package executor + +import ( + "strings" + "testing" +) + +func TestSanitizeCodexWebsocketLogURLMasksQueryAndUserInfo(t *testing.T) { + raw := "wss://user:secret@example.com/v1/realtime?api_key=verysecret&token=abc123&foo=bar#frag" + got := sanitizeCodexWebsocketLogURL(raw) + + if strings.Contains(got, "secret") || strings.Contains(got, "abc123") || strings.Contains(got, "verysecret") { + t.Fatalf("expected sensitive values to be masked, got %q", got) + } + if strings.Contains(got, "user:") { + t.Fatalf("expected userinfo to be removed, got %q", got) + } + if strings.Contains(got, "#frag") { + t.Fatalf("expected fragment to be removed, got %q", got) + } +} + +func TestSanitizeCodexWebsocketLogFieldMasksTokenLikeValue(t *testing.T) { + got := sanitizeCodexWebsocketLogField(" sk-super-secret-token ") + if got == "sk-super-secret-token" { + t.Fatalf("expected auth field to be masked, got %q", got) + } +} diff --git a/pkg/llmproxy/executor/gemini_cli_executor.go b/pkg/llmproxy/executor/gemini_cli_executor.go index 072f70ffa6..214373bbfc 100644 --- a/pkg/llmproxy/executor/gemini_cli_executor.go +++ b/pkg/llmproxy/executor/gemini_cli_executor.go @@ -236,7 +236,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) if httpResp.StatusCode == 429 { if idx+1 < len(models) { - log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) + log.Debug("gemini cli executor: rate limited, retrying with next model") } else { log.Debug("gemini cli executor: rate limited, no additional fallback model") } @@ -373,7 +373,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) if httpResp.StatusCode == 429 { if idx+1 < len(models) { - log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1]) + log.Debug("gemini cli executor: rate limited, retrying with next model") } else { log.Debug("gemini cli executor: rate limited, no additional fallback model") } diff --git a/pkg/llmproxy/executor/iflow_executor.go b/pkg/llmproxy/executor/iflow_executor.go index 8fbf189f47..eacbeb5d0e 100644 --- a/pkg/llmproxy/executor/iflow_executor.go +++ b/pkg/llmproxy/executor/iflow_executor.go @@ -17,7 +17,6 @@ import ( iflowauth "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/auth/iflow" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/config" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/thinking" - "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -409,9 +408,9 @@ func (e *IFlowExecutor) refreshOAuthBased(ctx context.Context, auth *cliproxyaut return auth, nil } - // Log the old access token (masked) before refresh + // Log refresh start without including token material. if oldAccessToken != "" { - log.Debugf("iflow executor: refreshing access token, old: %s", util.HideAPIKey(oldAccessToken)) + log.Debug("iflow executor: refreshing access token") } svc := iflowauth.NewIFlowAuth(e.cfg, nil) @@ -435,8 +434,7 @@ func (e *IFlowExecutor) refreshOAuthBased(ctx context.Context, auth *cliproxyaut auth.Metadata["type"] = "iflow" auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339) - // Log the new access token (masked) after successful refresh - log.Debugf("iflow executor: token refresh successful, new: %s", util.HideAPIKey(tokenData.AccessToken)) + log.Debug("iflow executor: token refresh successful") if auth.Attributes == nil { auth.Attributes = make(map[string]string) diff --git a/pkg/llmproxy/executor/kiro_executor.go b/pkg/llmproxy/executor/kiro_executor.go index 5ed6a73bea..3b7916fe76 100644 --- a/pkg/llmproxy/executor/kiro_executor.go +++ b/pkg/llmproxy/executor/kiro_executor.go @@ -4,8 +4,10 @@ import ( "bufio" "bytes" "context" + "crypto/sha256" "encoding/base64" "encoding/binary" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -1645,7 +1647,7 @@ func (e *KiroExecutor) mapModelToKiro(model string) string { // Check for Haiku variants if strings.Contains(modelLower, "haiku") { - log.Debugf("kiro: unknown Haiku model '%s', mapping to claude-haiku-4.5", model) + log.Debug("kiro: unknown haiku variant, mapping to claude-haiku-4.5") return "claude-haiku-4.5" } @@ -1653,37 +1655,42 @@ func (e *KiroExecutor) mapModelToKiro(model string) string { if strings.Contains(modelLower, "sonnet") { // Check for specific version patterns if strings.Contains(modelLower, "3-7") || strings.Contains(modelLower, "3.7") { - log.Debugf("kiro: unknown Sonnet 3.7 model '%s', mapping to claude-3-7-sonnet-20250219", model) + log.Debug("kiro: unknown sonnet 3.7 variant, mapping to claude-3-7-sonnet-20250219") return "claude-3-7-sonnet-20250219" } if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") { - log.Debugf("kiro: unknown Sonnet 4.6 model '%s', mapping to claude-sonnet-4.6", model) + log.Debug("kiro: unknown sonnet 4.6 variant, mapping to claude-sonnet-4.6") return "claude-sonnet-4.6" } if strings.Contains(modelLower, "4-5") || strings.Contains(modelLower, "4.5") { - log.Debugf("kiro: unknown Sonnet 4.5 model '%s', mapping to claude-sonnet-4.5", model) + log.Debug("kiro: unknown Sonnet 4.5 model, mapping to claude-sonnet-4.5") return "claude-sonnet-4.5" } - // Default to Sonnet 4 - log.Debugf("kiro: unknown Sonnet model '%s', mapping to claude-sonnet-4", model) - return "claude-sonnet-4" - } // Check for Opus variants if strings.Contains(modelLower, "opus") { if strings.Contains(modelLower, "4-6") || strings.Contains(modelLower, "4.6") { - log.Debugf("kiro: unknown Opus 4.6 model '%s', mapping to claude-opus-4.6", model) + log.Debug("kiro: unknown Opus 4.6 model, mapping to claude-opus-4.6") return "claude-opus-4.6" } - log.Debugf("kiro: unknown Opus model '%s', mapping to claude-opus-4.5", model) + log.Debug("kiro: unknown opus variant, mapping to claude-opus-4.5") return "claude-opus-4.5" } // Final fallback to Sonnet 4.5 (most commonly used model) - log.Warnf("kiro: unknown model '%s', falling back to claude-sonnet-4.5", model) + log.Warn("kiro: unknown model variant, falling back to claude-sonnet-4.5") return "claude-sonnet-4.5" } +func kiroModelFingerprint(model string) string { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return "" + } + sum := sha256.Sum256([]byte(trimmed)) + return hex.EncodeToString(sum[:8]) +} + // EventStreamError represents an Event Stream processing error type EventStreamError struct { Type string // "fatal", "malformed" diff --git a/pkg/llmproxy/executor/kiro_executor_logging_test.go b/pkg/llmproxy/executor/kiro_executor_logging_test.go new file mode 100644 index 0000000000..a42c3bc7ea --- /dev/null +++ b/pkg/llmproxy/executor/kiro_executor_logging_test.go @@ -0,0 +1,14 @@ +package executor + +import "testing" + +func TestKiroModelFingerprint_RedactsRawModel(t *testing.T) { + raw := "user-custom-model-with-sensitive-suffix" + got := kiroModelFingerprint(raw) + if got == "" { + t.Fatal("expected non-empty fingerprint") + } + if got == raw { + t.Fatalf("fingerprint must not equal raw model: %q", got) + } +} diff --git a/pkg/llmproxy/executor/user_id_cache.go b/pkg/llmproxy/executor/user_id_cache.go index ff8efd9d1d..fc64823131 100644 --- a/pkg/llmproxy/executor/user_id_cache.go +++ b/pkg/llmproxy/executor/user_id_cache.go @@ -1,7 +1,8 @@ package executor import ( - "crypto/sha256" + "crypto/hmac" + "crypto/sha512" "encoding/hex" "sync" "time" @@ -21,6 +22,7 @@ var ( const ( userIDTTL = time.Hour userIDCacheCleanupPeriod = 15 * time.Minute + userIDCacheHashKey = "executor-user-id-cache:v1" ) func startUserIDCacheCleanup() { @@ -45,8 +47,9 @@ func purgeExpiredUserIDs() { } func userIDCacheKey(apiKey string) string { - sum := sha256.Sum256([]byte(apiKey)) - return hex.EncodeToString(sum[:]) + hasher := hmac.New(sha512.New, []byte(userIDCacheHashKey)) + hasher.Write([]byte(apiKey)) + return hex.EncodeToString(hasher.Sum(nil)) } func cachedUserID(apiKey string) string { diff --git a/pkg/llmproxy/executor/user_id_cache_test.go b/pkg/llmproxy/executor/user_id_cache_test.go index 420a3cad43..4b1ed0c2e9 100644 --- a/pkg/llmproxy/executor/user_id_cache_test.go +++ b/pkg/llmproxy/executor/user_id_cache_test.go @@ -1,6 +1,8 @@ package executor import ( + "crypto/sha256" + "encoding/hex" "testing" "time" ) @@ -84,3 +86,16 @@ func TestCachedUserID_RenewsTTLOnHit(t *testing.T) { t.Fatalf("expected TTL to renew, got %v remaining", entry.expire.Sub(soon)) } } + +func TestUserIDCacheKey_DoesNotUseLegacySHA256(t *testing.T) { + apiKey := "api-key-legacy-check" + got := userIDCacheKey(apiKey) + if got == "" { + t.Fatal("expected non-empty cache key") + } + + legacy := sha256.Sum256([]byte(apiKey)) + if got == hex.EncodeToString(legacy[:]) { + t.Fatalf("expected cache key to differ from legacy sha256") + } +} diff --git a/pkg/llmproxy/logging/request_logger.go b/pkg/llmproxy/logging/request_logger.go index b2f135cd48..2aebb888cc 100644 --- a/pkg/llmproxy/logging/request_logger.go +++ b/pkg/llmproxy/logging/request_logger.go @@ -399,7 +399,7 @@ func (l *FileRequestLogger) generateFilename(url string, requestID ...string) st // Use request ID if provided, otherwise use sequential ID var idPart string if len(requestID) > 0 && requestID[0] != "" { - idPart = requestID[0] + idPart = l.sanitizeForFilename(requestID[0]) } else { id := requestLogID.Add(1) idPart = fmt.Sprintf("%d", id) diff --git a/pkg/llmproxy/logging/request_logger_security_test.go b/pkg/llmproxy/logging/request_logger_security_test.go new file mode 100644 index 0000000000..6483597d2b --- /dev/null +++ b/pkg/llmproxy/logging/request_logger_security_test.go @@ -0,0 +1,27 @@ +package logging + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestGenerateFilename_SanitizesRequestIDForPathSafety(t *testing.T) { + t.Parallel() + + logsDir := t.TempDir() + logger := NewFileRequestLogger(true, logsDir, "", 0) + + filename := logger.generateFilename("/v1/responses", "../escape-path") + resolved := filepath.Join(logsDir, filename) + rel, err := filepath.Rel(logsDir, resolved) + if err != nil { + t.Fatalf("filepath.Rel failed: %v", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + t.Fatalf("generated filename escaped logs dir: %s", filename) + } + if strings.Contains(filename, "/") { + t.Fatalf("generated filename contains path separator: %s", filename) + } +} diff --git a/pkg/llmproxy/misc/credentials.go b/pkg/llmproxy/misc/credentials.go index b03cd788d2..86225ff7ae 100644 --- a/pkg/llmproxy/misc/credentials.go +++ b/pkg/llmproxy/misc/credentials.go @@ -24,3 +24,22 @@ func LogSavingCredentials(path string) { func LogCredentialSeparator() { log.Debug(credentialSeparator) } + +// ValidateCredentialPath rejects unsafe credential file paths and returns a cleaned path. +func ValidateCredentialPath(path string) (string, error) { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "", fmt.Errorf("credential path is empty") + } + if strings.ContainsRune(trimmed, '\x00') { + return "", fmt.Errorf("credential path contains NUL byte") + } + cleaned := filepath.Clean(trimmed) + if cleaned == "." { + return "", fmt.Errorf("credential path is invalid") + } + if cleaned != trimmed { + return "", fmt.Errorf("credential path must be clean and traversal-free") + } + return cleaned, nil +} diff --git a/pkg/llmproxy/misc/path_security.go b/pkg/llmproxy/misc/path_security.go new file mode 100644 index 0000000000..28e78e9575 --- /dev/null +++ b/pkg/llmproxy/misc/path_security.go @@ -0,0 +1,69 @@ +package misc + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +// ResolveSafeFilePath validates and normalizes a file path, rejecting path traversal components. +func ResolveSafeFilePath(path string) (string, error) { + trimmed := strings.TrimSpace(path) + if trimmed == "" { + return "", fmt.Errorf("path is empty") + } + if hasPathTraversalComponent(trimmed) { + return "", fmt.Errorf("path traversal is not allowed") + } + cleaned := filepath.Clean(trimmed) + if cleaned == "." { + return "", fmt.Errorf("path is invalid") + } + return cleaned, nil +} + +// ResolveSafeFilePathInDir resolves a file name inside baseDir and rejects paths that escape baseDir. +func ResolveSafeFilePathInDir(baseDir, fileName string) (string, error) { + base := strings.TrimSpace(baseDir) + if base == "" { + return "", fmt.Errorf("base directory is empty") + } + name := strings.TrimSpace(fileName) + if name == "" { + return "", fmt.Errorf("file name is empty") + } + if strings.Contains(name, "/") || strings.Contains(name, "\\") { + return "", fmt.Errorf("file name must not contain path separators") + } + if hasPathTraversalComponent(name) { + return "", fmt.Errorf("file name must not contain traversal components") + } + cleanName := filepath.Clean(name) + if cleanName == "." || cleanName == ".." { + return "", fmt.Errorf("file name is invalid") + } + baseAbs, err := filepath.Abs(base) + if err != nil { + return "", fmt.Errorf("resolve base directory: %w", err) + } + resolved := filepath.Clean(filepath.Join(baseAbs, cleanName)) + rel, err := filepath.Rel(baseAbs, resolved) + if err != nil { + return "", fmt.Errorf("resolve relative path: %w", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("resolved path escapes base directory") + } + return resolved, nil +} + +func hasPathTraversalComponent(path string) bool { + normalized := strings.ReplaceAll(path, "\\", "/") + for _, component := range strings.Split(normalized, "/") { + if component == ".." { + return true + } + } + return false +} diff --git a/pkg/llmproxy/misc/path_security_test.go b/pkg/llmproxy/misc/path_security_test.go new file mode 100644 index 0000000000..6eaf1d2beb --- /dev/null +++ b/pkg/llmproxy/misc/path_security_test.go @@ -0,0 +1,36 @@ +package misc + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestResolveSafeFilePathRejectsTraversal(t *testing.T) { + _, err := ResolveSafeFilePath("/tmp/../escape.json") + if err == nil { + t.Fatal("expected traversal path to be rejected") + } +} + +func TestResolveSafeFilePathInDirRejectsSeparatorsAndTraversal(t *testing.T) { + base := t.TempDir() + + if _, err := ResolveSafeFilePathInDir(base, "..\\escape.json"); err == nil { + t.Fatal("expected backslash traversal payload to be rejected") + } + if _, err := ResolveSafeFilePathInDir(base, "../escape.json"); err == nil { + t.Fatal("expected slash traversal payload to be rejected") + } +} + +func TestResolveSafeFilePathInDirResolvesInsideBaseDir(t *testing.T) { + base := t.TempDir() + path, err := ResolveSafeFilePathInDir(base, "valid.json") + if err != nil { + t.Fatalf("expected valid file name: %v", err) + } + if !strings.HasPrefix(path, filepath.Clean(base)+string(filepath.Separator)) { + t.Fatalf("expected resolved path %q under base %q", path, base) + } +} diff --git a/pkg/llmproxy/registry/model_registry.go b/pkg/llmproxy/registry/model_registry.go index 1a60f6cd40..2509afd260 100644 --- a/pkg/llmproxy/registry/model_registry.go +++ b/pkg/llmproxy/registry/model_registry.go @@ -602,7 +602,7 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) { if registration, exists := r.models[modelID]; exists { registration.QuotaExceededClients[clientID] = new(time.Now()) - log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID) + log.Debug("Marked model as quota exceeded for client") } } @@ -645,9 +645,9 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) { registration.SuspendedClients[clientID] = reason registration.LastUpdated = time.Now() if reason != "" { - log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason) + log.Debugf("Suspended client %s for model %s (reason provided)", clientID, modelID) } else { - log.Debugf("Suspended client %s for model %s", clientID, modelID) + log.Debug("Suspended client for model") } } @@ -671,7 +671,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) { } delete(registration.SuspendedClients, clientID) registration.LastUpdated = time.Now() - log.Debugf("Resumed client %s for model %s", clientID, modelID) + log.Debug("Resumed suspended client for model") } // ClientSupportsModel reports whether the client registered support for modelID. diff --git a/pkg/llmproxy/runtime/executor/codex_websockets_executor.go b/pkg/llmproxy/runtime/executor/codex_websockets_executor.go index a5e820af33..be72c632d4 100644 --- a/pkg/llmproxy/runtime/executor/codex_websockets_executor.go +++ b/pkg/llmproxy/runtime/executor/codex_websockets_executor.go @@ -5,6 +5,8 @@ package executor import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "fmt" "io" "net" @@ -399,7 +401,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { - log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model) + log.Debug("executing codex websockets stream request") if ctx == nil { ctx = context.Background() } @@ -1295,15 +1297,34 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess } func logCodexWebsocketConnected(sessionID string, authID string, wsURL string) { - log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL)) + log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL)) } -func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string, reason string, err error) { +func logCodexWebsocketDisconnected(sessionID string, _ string, _ string, reason string, err error) { if err != nil { - log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason), err) + log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL), strings.TrimSpace(reason), err) return } - log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason)) + log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), sanitizeCodexWebsocketLogField(authID), sanitizeCodexWebsocketLogURL(wsURL), strings.TrimSpace(reason)) +} + +func sanitizeCodexWebsocketLogField(raw string) string { + return util.HideAPIKey(strings.TrimSpace(raw)) +} + +func sanitizeCodexWebsocketLogURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || !parsed.IsAbs() { + return util.HideAPIKey(trimmed) + } + parsed.User = nil + parsed.Fragment = "" + parsed.RawQuery = util.MaskSensitiveQuery(parsed.RawQuery) + return parsed.String() } // CodexAutoExecutor routes Codex requests to the websocket transport only when: diff --git a/pkg/llmproxy/runtime/executor/codex_websockets_executor_logging_test.go b/pkg/llmproxy/runtime/executor/codex_websockets_executor_logging_test.go new file mode 100644 index 0000000000..6fc69acef1 --- /dev/null +++ b/pkg/llmproxy/runtime/executor/codex_websockets_executor_logging_test.go @@ -0,0 +1,28 @@ +package executor + +import ( + "strings" + "testing" +) + +func TestSanitizeCodexWebsocketLogURLMasksQueryAndUserInfo(t *testing.T) { + raw := "wss://user:secret@example.com/v1/realtime?api_key=verysecret&token=abc123&foo=bar#frag" + got := sanitizeCodexWebsocketLogURL(raw) + + if strings.Contains(got, "secret") || strings.Contains(got, "abc123") || strings.Contains(got, "verysecret") { + t.Fatalf("expected sensitive values to be masked, got %q", got) + } + if strings.Contains(got, "user:") { + t.Fatalf("expected userinfo to be removed, got %q", got) + } + if strings.Contains(got, "#frag") { + t.Fatalf("expected fragment to be removed, got %q", got) + } +} + +func TestSanitizeCodexWebsocketLogFieldMasksTokenLikeValue(t *testing.T) { + got := sanitizeCodexWebsocketLogField(" sk-super-secret-token ") + if got == "sk-super-secret-token" { + t.Fatalf("expected auth field to be masked, got %q", got) + } +} diff --git a/pkg/llmproxy/runtime/executor/iflow_executor.go b/pkg/llmproxy/runtime/executor/iflow_executor.go index 2fb4259d86..c9808a8df0 100644 --- a/pkg/llmproxy/runtime/executor/iflow_executor.go +++ b/pkg/llmproxy/runtime/executor/iflow_executor.go @@ -17,7 +17,6 @@ import ( iflowauth "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/auth/iflow" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/config" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/thinking" - "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -409,9 +408,9 @@ func (e *IFlowExecutor) refreshOAuthBased(ctx context.Context, auth *cliproxyaut return auth, nil } - // Log the old access token (masked) before refresh + // Avoid logging token material. if oldAccessToken != "" { - log.Debugf("iflow executor: refreshing access token, old: %s", util.HideAPIKey(oldAccessToken)) + log.Debug("iflow executor: refreshing access token") } svc := iflowauth.NewIFlowAuth(e.cfg, nil) @@ -435,8 +434,7 @@ func (e *IFlowExecutor) refreshOAuthBased(ctx context.Context, auth *cliproxyaut auth.Metadata["type"] = "iflow" auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339) - // Log the new access token (masked) after successful refresh - log.Debugf("iflow executor: token refresh successful, new: %s", util.HideAPIKey(tokenData.AccessToken)) + log.Debug("iflow executor: token refresh successful") if auth.Attributes == nil { auth.Attributes = make(map[string]string) diff --git a/pkg/llmproxy/store/gitstore.go b/pkg/llmproxy/store/gitstore.go index e27f177145..aa258c98e1 100644 --- a/pkg/llmproxy/store/gitstore.go +++ b/pkg/llmproxy/store/gitstore.go @@ -225,6 +225,10 @@ func (s *GitTokenStore) Save(_ context.Context, auth *cliproxyauth.Auth) (string if path == "" { return "", fmt.Errorf("auth filestore: missing file path attribute for %s", auth.ID) } + path, err = ensurePathWithinDir(path, s.baseDirSnapshot(), "auth filestore") + if err != nil { + return "", err + } if auth.Disabled { if _, statErr := os.Stat(path); os.IsNotExist(statErr) { @@ -399,14 +403,26 @@ func (s *GitTokenStore) PersistAuthFiles(_ context.Context, message string, path } func (s *GitTokenStore) resolveDeletePath(id string) (string, error) { - if strings.ContainsRune(id, os.PathSeparator) || filepath.IsAbs(id) { - return id, nil - } dir := s.baseDirSnapshot() if dir == "" { return "", fmt.Errorf("auth filestore: directory not configured") } - return filepath.Join(dir, id), nil + clean := filepath.Clean(filepath.FromSlash(strings.TrimSpace(id))) + if clean == "." || clean == "" { + return "", fmt.Errorf("auth filestore: invalid id") + } + if filepath.IsAbs(clean) || clean == ".." || strings.HasPrefix(clean, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("auth filestore: id resolves outside auth directory") + } + path := filepath.Join(dir, clean) + rel, err := filepath.Rel(dir, path) + if err != nil { + return "", fmt.Errorf("auth filestore: relative path: %w", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("auth filestore: id resolves outside auth directory") + } + return path, nil } func (s *GitTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) { diff --git a/pkg/llmproxy/store/gitstore_security_test.go b/pkg/llmproxy/store/gitstore_security_test.go new file mode 100644 index 0000000000..9d7b6340e3 --- /dev/null +++ b/pkg/llmproxy/store/gitstore_security_test.go @@ -0,0 +1,42 @@ +package store + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestResolveDeletePath_RejectsTraversalAndAbsolute(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + s := &GitTokenStore{} + s.SetBaseDir(baseDir) + + if _, err := s.resolveDeletePath("../outside.json"); err == nil { + t.Fatalf("expected traversal id to be rejected") + } + if _, err := s.resolveDeletePath(filepath.Join(baseDir, "nested", "token.json")); err == nil { + t.Fatalf("expected absolute id to be rejected") + } +} + +func TestResolveDeletePath_ReturnsPathInsideBaseDir(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + s := &GitTokenStore{} + s.SetBaseDir(baseDir) + + path, err := s.resolveDeletePath("nested/token.json") + if err != nil { + t.Fatalf("resolveDeletePath failed: %v", err) + } + rel, err := filepath.Rel(baseDir, path) + if err != nil { + t.Fatalf("filepath.Rel failed: %v", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + t.Fatalf("resolved path escaped base dir: %s", path) + } +} diff --git a/pkg/llmproxy/store/objectstore.go b/pkg/llmproxy/store/objectstore.go index 2f738ab918..9c4be6b8aa 100644 --- a/pkg/llmproxy/store/objectstore.go +++ b/pkg/llmproxy/store/objectstore.go @@ -168,6 +168,10 @@ func (s *ObjectTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (s if path == "" { return "", fmt.Errorf("object store: missing file path attribute for %s", auth.ID) } + path, err = ensurePathWithinDir(path, s.authDir, "object store") + if err != nil { + return "", err + } if auth.Disabled { if _, statErr := os.Stat(path); errors.Is(statErr, fs.ErrNotExist) { @@ -512,10 +516,7 @@ func (s *ObjectTokenStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, err } if auth.Attributes != nil { if path := strings.TrimSpace(auth.Attributes["path"]); path != "" { - if filepath.IsAbs(path) { - return path, nil - } - return filepath.Join(s.authDir, path), nil + return s.ensureManagedAuthPath(path) } } fileName := strings.TrimSpace(auth.FileName) @@ -528,7 +529,7 @@ func (s *ObjectTokenStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, err if !strings.HasSuffix(strings.ToLower(fileName), ".json") { fileName += ".json" } - return filepath.Join(s.authDir, fileName), nil + return s.ensureManagedAuthPath(fileName) } func (s *ObjectTokenStore) resolveDeletePath(id string) (string, error) { @@ -536,21 +537,47 @@ func (s *ObjectTokenStore) resolveDeletePath(id string) (string, error) { if id == "" { return "", fmt.Errorf("object store: id is empty") } - // Absolute paths are honored as-is; callers must ensure they point inside the mirror. - if filepath.IsAbs(id) { - return id, nil - } - // Treat any non-absolute id (including nested like "team/foo") as relative to the mirror authDir. - // Normalize separators and guard against path traversal. clean := filepath.Clean(filepath.FromSlash(id)) if clean == "." || clean == ".." || strings.HasPrefix(clean, ".."+string(os.PathSeparator)) { return "", fmt.Errorf("object store: invalid auth identifier %s", id) } - // Ensure .json suffix. if !strings.HasSuffix(strings.ToLower(clean), ".json") { clean += ".json" } - return filepath.Join(s.authDir, clean), nil + return s.ensureManagedAuthPath(clean) +} + +func (s *ObjectTokenStore) ensureManagedAuthPath(path string) (string, error) { + if s == nil { + return "", fmt.Errorf("object store: store not initialized") + } + authDir := strings.TrimSpace(s.authDir) + if authDir == "" { + return "", fmt.Errorf("object store: auth directory not configured") + } + absAuthDir, err := filepath.Abs(authDir) + if err != nil { + return "", fmt.Errorf("object store: resolve auth directory: %w", err) + } + candidate := strings.TrimSpace(path) + if candidate == "" { + return "", fmt.Errorf("object store: auth path is empty") + } + if !filepath.IsAbs(candidate) { + candidate = filepath.Join(absAuthDir, filepath.FromSlash(candidate)) + } + absCandidate, err := filepath.Abs(candidate) + if err != nil { + return "", fmt.Errorf("object store: resolve auth path %q: %w", path, err) + } + rel, err := filepath.Rel(absAuthDir, absCandidate) + if err != nil { + return "", fmt.Errorf("object store: compute relative auth path: %w", err) + } + if rel == "." || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("object store: path %q escapes auth directory", path) + } + return absCandidate, nil } func (s *ObjectTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) { diff --git a/pkg/llmproxy/store/objectstore_path_test.go b/pkg/llmproxy/store/objectstore_path_test.go new file mode 100644 index 0000000000..653c197670 --- /dev/null +++ b/pkg/llmproxy/store/objectstore_path_test.go @@ -0,0 +1,58 @@ +package store + +import ( + "path/filepath" + "strings" + "testing" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestObjectResolveAuthPathRejectsTraversalFromAttributes(t *testing.T) { + t.Parallel() + + store := &ObjectTokenStore{authDir: filepath.Join(t.TempDir(), "auths")} + auth := &cliproxyauth.Auth{ + Attributes: map[string]string{"path": "../escape.json"}, + } + if _, err := store.resolveAuthPath(auth); err == nil { + t.Fatalf("expected traversal path rejection") + } +} + +func TestObjectResolveAuthPathRejectsAbsoluteOutsideAuthDir(t *testing.T) { + t.Parallel() + + root := t.TempDir() + store := &ObjectTokenStore{authDir: filepath.Join(root, "auths")} + outside := filepath.Join(root, "..", "outside.json") + auth := &cliproxyauth.Auth{ + Attributes: map[string]string{"path": outside}, + } + if _, err := store.resolveAuthPath(auth); err == nil { + t.Fatalf("expected outside absolute path rejection") + } +} + +func TestObjectResolveDeletePathConstrainsToAuthDir(t *testing.T) { + t.Parallel() + + root := t.TempDir() + authDir := filepath.Join(root, "auths") + store := &ObjectTokenStore{authDir: authDir} + + got, err := store.resolveDeletePath("team/provider") + if err != nil { + t.Fatalf("resolve delete path: %v", err) + } + if !strings.HasSuffix(got, filepath.Join("team", "provider.json")) { + t.Fatalf("expected .json suffix, got %s", got) + } + rel, err := filepath.Rel(authDir, got) + if err != nil { + t.Fatalf("relative path: %v", err) + } + if strings.HasPrefix(rel, "..") || rel == "." { + t.Fatalf("path escaped auth directory: %s", got) + } +} diff --git a/pkg/llmproxy/store/path_guard.go b/pkg/llmproxy/store/path_guard.go new file mode 100644 index 0000000000..fd2c9b7eb1 --- /dev/null +++ b/pkg/llmproxy/store/path_guard.go @@ -0,0 +1,39 @@ +package store + +import ( + "fmt" + "os" + "path/filepath" + "strings" +) + +func ensurePathWithinDir(path, baseDir, scope string) (string, error) { + trimmedPath := strings.TrimSpace(path) + if trimmedPath == "" { + return "", fmt.Errorf("%s: path is empty", scope) + } + trimmedBase := strings.TrimSpace(baseDir) + if trimmedBase == "" { + return "", fmt.Errorf("%s: base directory is not configured", scope) + } + + absBase, err := filepath.Abs(trimmedBase) + if err != nil { + return "", fmt.Errorf("%s: resolve base directory: %w", scope, err) + } + absPath, err := filepath.Abs(trimmedPath) + if err != nil { + return "", fmt.Errorf("%s: resolve path: %w", scope, err) + } + cleanBase := filepath.Clean(absBase) + cleanPath := filepath.Clean(absPath) + + rel, err := filepath.Rel(cleanBase, cleanPath) + if err != nil { + return "", fmt.Errorf("%s: compute relative path: %w", scope, err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("%s: path escapes managed directory", scope) + } + return cleanPath, nil +} diff --git a/pkg/llmproxy/store/path_guard_test.go b/pkg/llmproxy/store/path_guard_test.go new file mode 100644 index 0000000000..12e5edd685 --- /dev/null +++ b/pkg/llmproxy/store/path_guard_test.go @@ -0,0 +1,57 @@ +package store + +import ( + "context" + "path/filepath" + "strings" + "testing" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestObjectTokenStoreSaveRejectsPathOutsideAuthDir(t *testing.T) { + t.Parallel() + + authDir := filepath.Join(t.TempDir(), "auths") + store := &ObjectTokenStore{authDir: authDir} + outside := filepath.Join(t.TempDir(), "outside.json") + auth := &cliproxyauth.Auth{ + ID: "outside", + Disabled: true, + Attributes: map[string]string{ + "path": outside, + }, + } + + _, err := store.Save(context.Background(), auth) + if err == nil { + t.Fatal("expected error for path outside managed auth directory") + } + if !strings.Contains(err.Error(), "escapes") { + t.Fatalf("expected managed directory error, got: %v", err) + } +} + +func TestGitTokenStoreSaveRejectsPathOutsideAuthDir(t *testing.T) { + t.Parallel() + + baseDir := filepath.Join(t.TempDir(), "repo", "auths") + store := NewGitTokenStore("", "", "") + store.SetBaseDir(baseDir) + outside := filepath.Join(t.TempDir(), "outside.json") + auth := &cliproxyauth.Auth{ + ID: "outside", + Attributes: map[string]string{ + "path": outside, + }, + Metadata: map[string]any{"type": "test"}, + } + + _, err := store.Save(context.Background(), auth) + if err == nil { + t.Fatal("expected error for path outside managed auth directory") + } + if !strings.Contains(err.Error(), "escapes") { + t.Fatalf("expected managed directory error, got: %v", err) + } +} diff --git a/pkg/llmproxy/store/postgresstore.go b/pkg/llmproxy/store/postgresstore.go index c10b5abfdd..72baae5975 100644 --- a/pkg/llmproxy/store/postgresstore.go +++ b/pkg/llmproxy/store/postgresstore.go @@ -550,29 +550,57 @@ func (s *PostgresStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, error) } if auth.Attributes != nil { if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { - return p, nil + return s.ensureManagedAuthPath(p) } } if fileName := strings.TrimSpace(auth.FileName); fileName != "" { - if filepath.IsAbs(fileName) { - return fileName, nil - } - return filepath.Join(s.authDir, fileName), nil + return s.ensureManagedAuthPath(fileName) } if auth.ID == "" { return "", fmt.Errorf("postgres store: missing id") } - if filepath.IsAbs(auth.ID) { - return auth.ID, nil - } - return filepath.Join(s.authDir, filepath.FromSlash(auth.ID)), nil + return s.ensureManagedAuthPath(auth.ID) } func (s *PostgresStore) resolveDeletePath(id string) (string, error) { - if strings.ContainsRune(id, os.PathSeparator) || filepath.IsAbs(id) { - return id, nil + id = strings.TrimSpace(id) + if id == "" { + return "", fmt.Errorf("postgres store: id is empty") } - return filepath.Join(s.authDir, filepath.FromSlash(id)), nil + return s.ensureManagedAuthPath(id) +} + +func (s *PostgresStore) ensureManagedAuthPath(path string) (string, error) { + if s == nil { + return "", fmt.Errorf("postgres store: store not initialized") + } + authDir := strings.TrimSpace(s.authDir) + if authDir == "" { + return "", fmt.Errorf("postgres store: auth directory not configured") + } + absAuthDir, err := filepath.Abs(authDir) + if err != nil { + return "", fmt.Errorf("postgres store: resolve auth directory: %w", err) + } + candidate := strings.TrimSpace(path) + if candidate == "" { + return "", fmt.Errorf("postgres store: auth path is empty") + } + if !filepath.IsAbs(candidate) { + candidate = filepath.Join(absAuthDir, filepath.FromSlash(candidate)) + } + absCandidate, err := filepath.Abs(candidate) + if err != nil { + return "", fmt.Errorf("postgres store: resolve auth path %q: %w", path, err) + } + rel, err := filepath.Rel(absAuthDir, absCandidate) + if err != nil { + return "", fmt.Errorf("postgres store: compute relative auth path: %w", err) + } + if rel == "." || rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("postgres store: path %q outside managed directory", path) + } + return absCandidate, nil } func (s *PostgresStore) relativeAuthID(path string) (string, error) { @@ -612,6 +640,30 @@ func (s *PostgresStore) absoluteAuthPath(id string) (string, error) { return path, nil } +func (s *PostgresStore) resolveManagedAuthPath(candidate string) (string, error) { + trimmed := strings.TrimSpace(candidate) + if trimmed == "" { + return "", fmt.Errorf("postgres store: auth path is empty") + } + + var resolved string + if filepath.IsAbs(trimmed) { + resolved = filepath.Clean(trimmed) + } else { + resolved = filepath.Join(s.authDir, filepath.FromSlash(trimmed)) + resolved = filepath.Clean(resolved) + } + + rel, err := filepath.Rel(s.authDir, resolved) + if err != nil { + return "", fmt.Errorf("postgres store: compute relative path: %w", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("postgres store: path %q outside managed directory", candidate) + } + return resolved, nil +} + func (s *PostgresStore) fullTableName(name string) string { if strings.TrimSpace(s.cfg.Schema) == "" { return quoteIdentifier(name) diff --git a/pkg/llmproxy/store/postgresstore_path_test.go b/pkg/llmproxy/store/postgresstore_path_test.go new file mode 100644 index 0000000000..50cf943722 --- /dev/null +++ b/pkg/llmproxy/store/postgresstore_path_test.go @@ -0,0 +1,51 @@ +package store + +import ( + "path/filepath" + "strings" + "testing" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestPostgresResolveAuthPathRejectsTraversalFromFileName(t *testing.T) { + t.Parallel() + + store := &PostgresStore{authDir: filepath.Join(t.TempDir(), "auths")} + auth := &cliproxyauth.Auth{FileName: "../escape.json"} + if _, err := store.resolveAuthPath(auth); err == nil { + t.Fatalf("expected traversal path rejection") + } +} + +func TestPostgresResolveAuthPathRejectsAbsoluteOutsideAuthDir(t *testing.T) { + t.Parallel() + + root := t.TempDir() + store := &PostgresStore{authDir: filepath.Join(root, "auths")} + outside := filepath.Join(root, "..", "outside.json") + auth := &cliproxyauth.Auth{Attributes: map[string]string{"path": outside}} + if _, err := store.resolveAuthPath(auth); err == nil { + t.Fatalf("expected outside absolute path rejection") + } +} + +func TestPostgresResolveDeletePathConstrainsToAuthDir(t *testing.T) { + t.Parallel() + + root := t.TempDir() + authDir := filepath.Join(root, "auths") + store := &PostgresStore{authDir: authDir} + + got, err := store.resolveDeletePath("team/provider.json") + if err != nil { + t.Fatalf("resolve delete path: %v", err) + } + rel, err := filepath.Rel(authDir, got) + if err != nil { + t.Fatalf("relative path: %v", err) + } + if strings.HasPrefix(rel, "..") || rel == "." { + t.Fatalf("path escaped auth directory: %s", got) + } +} diff --git a/pkg/llmproxy/store/postgresstore_test.go b/pkg/llmproxy/store/postgresstore_test.go index bf56a111c8..2e4e9b9fac 100644 --- a/pkg/llmproxy/store/postgresstore_test.go +++ b/pkg/llmproxy/store/postgresstore_test.go @@ -5,9 +5,12 @@ import ( "database/sql" "os" "path/filepath" + "strings" "testing" _ "modernc.org/sqlite" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) func TestSyncAuthFromDatabase_PreservesLocalOnlyFiles(t *testing.T) { @@ -81,6 +84,42 @@ func TestSyncAuthFromDatabase_ContinuesOnPathConflict(t *testing.T) { } } +func TestPostgresStoreSave_RejectsPathOutsideAuthDir(t *testing.T) { + t.Parallel() + + store, db := newSQLitePostgresStore(t) + t.Cleanup(func() { _ = db.Close() }) + + auth := &cliproxyauth.Auth{ + ID: "outside.json", + FileName: "../../outside.json", + Metadata: map[string]any{"type": "kiro"}, + } + _, err := store.Save(context.Background(), auth) + if err == nil { + t.Fatalf("expected save to reject path traversal") + } + if !strings.Contains(err.Error(), "outside managed directory") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestPostgresStoreDelete_RejectsAbsolutePathOutsideAuthDir(t *testing.T) { + t.Parallel() + + store, db := newSQLitePostgresStore(t) + t.Cleanup(func() { _ = db.Close() }) + + outside := filepath.Join(filepath.Dir(store.authDir), "outside.json") + err := store.Delete(context.Background(), outside) + if err == nil { + t.Fatalf("expected delete to reject absolute path outside auth dir") + } + if !strings.Contains(err.Error(), "outside managed directory") { + t.Fatalf("unexpected error: %v", err) + } +} + func newSQLitePostgresStore(t *testing.T) (*PostgresStore, *sql.DB) { t.Helper() diff --git a/pkg/llmproxy/thinking/apply.go b/pkg/llmproxy/thinking/apply.go index 384f704136..79f691fd27 100644 --- a/pkg/llmproxy/thinking/apply.go +++ b/pkg/llmproxy/thinking/apply.go @@ -100,7 +100,6 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string if applier == nil { log.WithFields(log.Fields{ "provider": providerFormat, - "model": model, }).Debug("thinking: unknown provider, passthrough |") return body, nil } @@ -126,34 +125,24 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string }).Debug("thinking: model does not support thinking, stripping config |") return StripThinkingConfig(body, providerFormat), nil } - log.WithFields(log.Fields{ - "provider": providerFormat, - "model": baseModel, - }).Debug("thinking: model does not support thinking, passthrough |") + log.Debug("thinking: model does not support thinking, passthrough |") return body, nil } // 4. Get config: suffix priority over body var config ThinkingConfig - if suffixResult.HasSuffix { - config = parseSuffixToConfig(suffixResult.RawSuffix, providerFormat, model) - log.WithFields(log.Fields{ - "provider": providerFormat, - "model": model, - "mode": config.Mode, - "budget": config.Budget, - "level": config.Level, - }).Debug("thinking: config from model suffix |") - } else { - config = extractThinkingConfig(body, providerFormat) - if hasThinkingConfig(config) { + if suffixResult.HasSuffix { + config = parseSuffixToConfig(suffixResult.RawSuffix, providerFormat, model) log.WithFields(log.Fields{ "provider": providerFormat, - "model": modelInfo.ID, "mode": config.Mode, "budget": config.Budget, "level": config.Level, - }).Debug("thinking: original config from request |") + }).Debug("thinking: config from model suffix |") + } else { + config = extractThinkingConfig(body, providerFormat) + if hasThinkingConfig(config) { + log.WithField("provider", providerFormat).Debug("thinking: request includes thinking config |") } } @@ -164,7 +153,6 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string config = ThinkingConfig{Mode: ModeAuto, Budget: -1} log.WithFields(log.Fields{ "provider": providerFormat, - "model": modelInfo.ID, "mode": config.Mode, "forced": true, }).Debug("thinking: forced thinking for thinking model |") @@ -180,11 +168,7 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string // 5. Validate and normalize configuration validated, err := ValidateConfig(config, modelInfo, fromFormat, providerFormat, suffixResult.HasSuffix) if err != nil { - log.WithFields(log.Fields{ - "provider": providerFormat, - "model": modelInfo.ID, - "error": err.Error(), - }).Warn("thinking: validation failed |") + log.Warn("thinking: validation failed |") // Return original body on validation failure (defensive programming). // This ensures callers who ignore the error won't receive nil body. // The upstream service will decide how to handle the unmodified request. @@ -201,11 +185,11 @@ func ApplyThinking(body []byte, model string, fromFormat string, toFormat string } log.WithFields(log.Fields{ - "provider": providerFormat, - "model": modelInfo.ID, - "mode": validated.Mode, - "budget": validated.Budget, - "level": validated.Level, + "provider": redactLogText(providerFormat), + "model": redactLogText(modelInfo.ID), + "mode": redactLogMode(validated.Mode), + "budget": redactLogInt(validated.Budget), + "level": redactLogLevel(validated.Level), }).Debug("thinking: processed config to apply |") // 6. Apply configuration using provider-specific applier @@ -246,9 +230,9 @@ func parseSuffixToConfig(rawSuffix, provider, model string) ThinkingConfig { // Unknown suffix format - return empty config log.WithFields(log.Fields{ - "provider": provider, - "model": model, - "raw_suffix": rawSuffix, + "provider": redactLogText(provider), + "model": redactLogText(model), + "raw_suffix": redactLogText(rawSuffix), }).Debug("thinking: unknown suffix format, treating as no config |") return ThinkingConfig{} } @@ -274,8 +258,8 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma if !hasThinkingConfig(config) { log.WithFields(log.Fields{ - "model": modelID, - "provider": toFormat, + "model": redactLogText(modelID), + "provider": redactLogText(toFormat), }).Debug("thinking: user-defined model, passthrough (no config) |") return body, nil } @@ -283,18 +267,18 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma applier := GetProviderApplier(toFormat) if applier == nil { log.WithFields(log.Fields{ - "model": modelID, - "provider": toFormat, + "model": redactLogText(modelID), + "provider": redactLogText(toFormat), }).Debug("thinking: user-defined model, passthrough (unknown provider) |") return body, nil } log.WithFields(log.Fields{ - "provider": toFormat, - "model": modelID, - "mode": config.Mode, - "budget": config.Budget, - "level": config.Level, + "provider": redactLogText(toFormat), + "model": redactLogText(modelID), + "mode": redactLogMode(config.Mode), + "budget": redactLogInt(config.Budget), + "level": redactLogLevel(config.Level), }).Debug("thinking: applying config for user-defined model (skip validation)") config = normalizeUserDefinedConfig(config, fromFormat, toFormat) diff --git a/pkg/llmproxy/thinking/apply_logging_test.go b/pkg/llmproxy/thinking/apply_logging_test.go new file mode 100644 index 0000000000..5f5902f931 --- /dev/null +++ b/pkg/llmproxy/thinking/apply_logging_test.go @@ -0,0 +1,34 @@ +package thinking + +import ( + "bytes" + "strings" + "testing" + + log "github.com/sirupsen/logrus" +) + +func TestApplyThinking_UnknownProviderLogDoesNotExposeModel(t *testing.T) { + var buf bytes.Buffer + prevOut := log.StandardLogger().Out + prevLevel := log.GetLevel() + log.SetOutput(&buf) + log.SetLevel(log.DebugLevel) + t.Cleanup(func() { + log.SetOutput(prevOut) + log.SetLevel(prevLevel) + }) + + model := "sensitive-user-model" + if _, err := ApplyThinking([]byte(`{"messages":[]}`), model, "", "unknown-provider", ""); err != nil { + t.Fatalf("ApplyThinking returned unexpected error: %v", err) + } + + logs := buf.String() + if !strings.Contains(logs, "thinking: unknown provider") { + t.Fatalf("expected unknown provider log, got %q", logs) + } + if strings.Contains(logs, model) { + t.Fatalf("log output leaked model value: %q", logs) + } +} diff --git a/pkg/llmproxy/thinking/log_redaction.go b/pkg/llmproxy/thinking/log_redaction.go new file mode 100644 index 0000000000..f2e450a5b8 --- /dev/null +++ b/pkg/llmproxy/thinking/log_redaction.go @@ -0,0 +1,34 @@ +package thinking + +import ( + "fmt" + "strings" +) + +const redactedLogValue = "[REDACTED]" + +func redactLogText(value string) string { + if strings.TrimSpace(value) == "" { + return "" + } + return redactedLogValue +} + +func redactLogInt(_ int) string { + return redactedLogValue +} + +func redactLogMode(_ ThinkingMode) string { + return redactedLogValue +} + +func redactLogLevel(_ ThinkingLevel) string { + return redactedLogValue +} + +func redactLogError(err error) string { + if err == nil { + return "" + } + return fmt.Sprintf("%T", err) +} diff --git a/pkg/llmproxy/thinking/log_redaction_test.go b/pkg/llmproxy/thinking/log_redaction_test.go new file mode 100644 index 0000000000..3c66972fce --- /dev/null +++ b/pkg/llmproxy/thinking/log_redaction_test.go @@ -0,0 +1,213 @@ +package thinking + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/registry" + log "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" +) + +type redactionTestApplier struct{} + +func (redactionTestApplier) Apply(body []byte, _ ThinkingConfig, _ *registry.ModelInfo) ([]byte, error) { + return body, nil +} + +func TestThinkingValidateLogsRedactSensitiveValues(t *testing.T) { + hook := test.NewLocal(log.StandardLogger()) + defer hook.Reset() + + previousLevel := log.GetLevel() + log.SetLevel(log.DebugLevel) + defer log.SetLevel(previousLevel) + + providerSecret := "provider-secret-l6-validate" + modelSecret := "model-secret-l6-validate" + + convertAutoToMidRange( + ThinkingConfig{Mode: ModeAuto, Budget: -1}, + ®istry.ThinkingSupport{Levels: []string{"low", "high"}}, + providerSecret, + modelSecret, + ) + + convertAutoToMidRange( + ThinkingConfig{Mode: ModeAuto, Budget: -1}, + ®istry.ThinkingSupport{Min: 1000, Max: 3000}, + providerSecret, + modelSecret, + ) + + clampLevel( + LevelMedium, + ®istry.ModelInfo{ + ID: modelSecret, + Thinking: ®istry.ThinkingSupport{ + Levels: []string{"low", "high"}, + }, + }, + providerSecret, + ) + + clampBudget( + 0, + ®istry.ModelInfo{ + ID: modelSecret, + Thinking: ®istry.ThinkingSupport{ + Min: 1024, + Max: 8192, + ZeroAllowed: false, + }, + }, + providerSecret, + ) + + logClamp(providerSecret, modelSecret, 9999, 8192, 1024, 8192) + + assertLogFieldRedacted(t, hook, "thinking: mode converted, dynamic not allowed, using medium level |", "provider") + assertLogFieldRedacted(t, hook, "thinking: mode converted, dynamic not allowed, using medium level |", "model") + assertLogFieldRedacted(t, hook, "thinking: mode converted, dynamic not allowed, using medium level |", "clamped_to") + + assertLogFieldRedacted(t, hook, "thinking: mode converted, dynamic not allowed |", "provider") + assertLogFieldRedacted(t, hook, "thinking: mode converted, dynamic not allowed |", "model") + assertLogFieldRedacted(t, hook, "thinking: mode converted, dynamic not allowed |", "clamped_to") + + assertLogFieldRedacted(t, hook, "thinking: level clamped |", "provider") + assertLogFieldRedacted(t, hook, "thinking: level clamped |", "model") + assertLogFieldRedacted(t, hook, "thinking: level clamped |", "original_value") + assertLogFieldRedacted(t, hook, "thinking: level clamped |", "clamped_to") + + assertLogFieldRedacted(t, hook, "thinking: budget zero not allowed |", "provider") + assertLogFieldRedacted(t, hook, "thinking: budget zero not allowed |", "model") + assertLogFieldRedacted(t, hook, "thinking: budget zero not allowed |", "original_value") + assertLogFieldRedacted(t, hook, "thinking: budget zero not allowed |", "min") + assertLogFieldRedacted(t, hook, "thinking: budget zero not allowed |", "max") + assertLogFieldRedacted(t, hook, "thinking: budget zero not allowed |", "clamped_to") + + assertLogFieldRedacted(t, hook, "thinking: budget clamped |", "provider") + assertLogFieldRedacted(t, hook, "thinking: budget clamped |", "model") + assertLogFieldRedacted(t, hook, "thinking: budget clamped |", "original_value") + assertLogFieldRedacted(t, hook, "thinking: budget clamped |", "min") + assertLogFieldRedacted(t, hook, "thinking: budget clamped |", "max") + assertLogFieldRedacted(t, hook, "thinking: budget clamped |", "clamped_to") +} + +func TestThinkingApplyLogsRedactSensitiveValues(t *testing.T) { + hook := test.NewLocal(log.StandardLogger()) + defer hook.Reset() + + previousLevel := log.GetLevel() + log.SetLevel(log.DebugLevel) + defer log.SetLevel(previousLevel) + + previousClaude := GetProviderApplier("claude") + RegisterProvider("claude", redactionTestApplier{}) + defer RegisterProvider("claude", previousClaude) + + modelSecret := "model-secret-l6-apply" + suffixSecret := "suffix-secret-l6-apply" + + reg := registry.GetGlobalRegistry() + clientID := "redaction-test-client-l6-apply" + reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{ + { + ID: modelSecret, + Thinking: ®istry.ThinkingSupport{ + Min: 1000, + Max: 3000, + ZeroAllowed: false, + }, + }, + }) + defer reg.RegisterClient(clientID, "claude", nil) + + _, err := ApplyThinking( + []byte(`{"thinking":{"budget_tokens":2000}}`), + modelSecret, + "claude", + "claude", + "claude", + ) + if err != nil { + t.Fatalf("ApplyThinking success path returned error: %v", err) + } + + _ = parseSuffixToConfig(suffixSecret, "claude", modelSecret) + + _, err = applyUserDefinedModel( + []byte(`{}`), + nil, + "claude", + "claude", + SuffixResult{ModelName: modelSecret}, + ) + if err != nil { + t.Fatalf("applyUserDefinedModel no-config path returned error: %v", err) + } + + _, err = applyUserDefinedModel( + []byte(`{"thinking":{"budget_tokens":2000}}`), + nil, + "claude", + "lane6-unknown-provider", + SuffixResult{ModelName: modelSecret, HasSuffix: true, RawSuffix: "high"}, + ) + if err != nil { + t.Fatalf("applyUserDefinedModel unknown-provider path returned error: %v", err) + } + + _, err = applyUserDefinedModel( + []byte(`{"thinking":{"budget_tokens":2000}}`), + nil, + "claude", + "claude", + SuffixResult{ModelName: modelSecret}, + ) + if err != nil { + t.Fatalf("applyUserDefinedModel apply path returned error: %v", err) + } + + assertLogFieldRedacted(t, hook, "thinking: processed config to apply |", "provider") + assertLogFieldRedacted(t, hook, "thinking: processed config to apply |", "model") + assertLogFieldRedacted(t, hook, "thinking: processed config to apply |", "mode") + assertLogFieldRedacted(t, hook, "thinking: processed config to apply |", "budget") + assertLogFieldRedacted(t, hook, "thinking: processed config to apply |", "level") + + assertLogFieldRedacted(t, hook, "thinking: unknown suffix format, treating as no config |", "provider") + assertLogFieldRedacted(t, hook, "thinking: unknown suffix format, treating as no config |", "model") + assertLogFieldRedacted(t, hook, "thinking: unknown suffix format, treating as no config |", "raw_suffix") + + assertLogFieldRedacted(t, hook, "thinking: user-defined model, passthrough (no config) |", "provider") + assertLogFieldRedacted(t, hook, "thinking: user-defined model, passthrough (no config) |", "model") + + assertLogFieldRedacted(t, hook, "thinking: user-defined model, passthrough (unknown provider) |", "provider") + assertLogFieldRedacted(t, hook, "thinking: user-defined model, passthrough (unknown provider) |", "model") + + assertLogFieldRedacted(t, hook, "thinking: applying config for user-defined model (skip validation)", "provider") + assertLogFieldRedacted(t, hook, "thinking: applying config for user-defined model (skip validation)", "model") + assertLogFieldRedacted(t, hook, "thinking: applying config for user-defined model (skip validation)", "mode") + assertLogFieldRedacted(t, hook, "thinking: applying config for user-defined model (skip validation)", "budget") + assertLogFieldRedacted(t, hook, "thinking: applying config for user-defined model (skip validation)", "level") +} + +func assertLogFieldRedacted(t *testing.T, hook *test.Hook, message, field string) { + t.Helper() + for _, entry := range hook.AllEntries() { + if entry.Message != message { + continue + } + value, ok := entry.Data[field] + if !ok && field == "level" { + value, ok = entry.Data["fields.level"] + } + if !ok { + t.Fatalf("log %q missing field %q", message, field) + } + if value != redactedLogValue { + t.Fatalf("log %q field %q = %v, want %q", message, field, value, redactedLogValue) + } + return + } + t.Fatalf("log %q not found", message) +} diff --git a/pkg/llmproxy/thinking/validate.go b/pkg/llmproxy/thinking/validate.go index 654de79b6a..04d9719a33 100644 --- a/pkg/llmproxy/thinking/validate.go +++ b/pkg/llmproxy/thinking/validate.go @@ -171,10 +171,10 @@ func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupp config.Level = LevelMedium config.Budget = 0 log.WithFields(log.Fields{ - "provider": provider, - "model": model, + "provider": redactLogText(provider), + "model": redactLogText(model), "original_mode": "auto", - "clamped_to": string(LevelMedium), + "clamped_to": redactLogLevel(LevelMedium), }).Debug("thinking: mode converted, dynamic not allowed, using medium level |") return config } @@ -192,10 +192,10 @@ func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupp config.Budget = mid } log.WithFields(log.Fields{ - "provider": provider, - "model": model, + "provider": redactLogText(provider), + "model": redactLogText(model), "original_mode": "auto", - "clamped_to": config.Budget, + "clamped_to": redactLogInt(config.Budget), }).Debug("thinking: mode converted, dynamic not allowed |") return config } @@ -238,10 +238,10 @@ func clampLevel(level ThinkingLevel, modelInfo *registry.ModelInfo, provider str if bestIdx >= 0 { clamped := standardLevelOrder[bestIdx] log.WithFields(log.Fields{ - "provider": provider, - "model": model, - "original_value": string(level), - "clamped_to": string(clamped), + "provider": redactLogText(provider), + "model": redactLogText(model), + "original_value": redactLogLevel(level), + "clamped_to": redactLogLevel(clamped), }).Debug("thinking: level clamped |") return clamped } @@ -270,12 +270,12 @@ func clampBudget(value int, modelInfo *registry.ModelInfo, provider string) int min, max := support.Min, support.Max if value == 0 && !support.ZeroAllowed { log.WithFields(log.Fields{ - "provider": provider, - "model": model, - "original_value": value, - "clamped_to": min, - "min": min, - "max": max, + "provider": redactLogText(provider), + "model": redactLogText(model), + "original_value": redactLogInt(value), + "clamped_to": redactLogInt(min), + "min": redactLogInt(min), + "max": redactLogInt(max), }).Warn("thinking: budget zero not allowed |") return min } @@ -368,11 +368,11 @@ func abs(x int) int { func logClamp(provider, model string, original, clampedTo, min, max int) { log.WithFields(log.Fields{ - "provider": provider, - "model": model, - "original_value": original, - "min": min, - "max": max, - "clamped_to": clampedTo, + "provider": redactLogText(provider), + "model": redactLogText(model), + "original_value": redactLogInt(original), + "min": redactLogInt(min), + "max": redactLogInt(max), + "clamped_to": redactLogInt(clampedTo), }).Debug("thinking: budget clamped |") } diff --git a/pkg/llmproxy/translator/kiro/claude/kiro_websearch.go b/pkg/llmproxy/translator/kiro/claude/kiro_websearch.go index cfc0104e9b..a670afb572 100644 --- a/pkg/llmproxy/translator/kiro/claude/kiro_websearch.go +++ b/pkg/llmproxy/translator/kiro/claude/kiro_websearch.go @@ -17,6 +17,8 @@ import ( "github.com/tidwall/sjson" ) +const maxInt = int(^uint(0) >> 1) + // cachedToolDescription stores the dynamically-fetched web_search tool description. // Written by the executor via SetWebSearchDescription, read by the translator // when building the remote_web_search tool for Kiro API requests. @@ -411,7 +413,11 @@ func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchI existingContent, _ := resp["content"].([]interface{}) // Build new content: search indicators first, then existing content - newContent := make([]interface{}, 0, len(searches)*2+len(existingContent)) + capacity, err := checkedSearchContentCapacity(len(searches), len(existingContent)) + if err != nil { + return responsePayload, err + } + newContent := make([]interface{}, 0, capacity) for _, s := range searches { // server_tool_use block @@ -459,6 +465,16 @@ func InjectSearchIndicatorsInResponse(responsePayload []byte, searches []SearchI return result, nil } +func checkedSearchContentCapacity(searchCount, existingCount int) (int, error) { + if searchCount < 0 || existingCount < 0 { + return 0, fmt.Errorf("invalid negative content sizes: searches=%d existing=%d", searchCount, existingCount) + } + if searchCount > (maxInt-existingCount)/2 { + return 0, fmt.Errorf("search indicator content capacity overflow: searches=%d existing=%d", searchCount, existingCount) + } + return searchCount*2 + existingCount, nil +} + // SearchIndicator holds the data for one search operation to inject into a response. type SearchIndicator struct { ToolUseID string diff --git a/pkg/llmproxy/translator/kiro/claude/kiro_websearch_test.go b/pkg/llmproxy/translator/kiro/claude/kiro_websearch_test.go index f3b893037c..77c7026d7b 100644 --- a/pkg/llmproxy/translator/kiro/claude/kiro_websearch_test.go +++ b/pkg/llmproxy/translator/kiro/claude/kiro_websearch_test.go @@ -80,3 +80,25 @@ func TestGenerateWebSearchEvents(t *testing.T) { t.Error("message_start event not found") } } + +func TestCheckedSearchContentCapacity(t *testing.T) { + t.Run("ok", func(t *testing.T) { + got, err := checkedSearchContentCapacity(3, 4) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != 10 { + t.Fatalf("expected 10, got %d", got) + } + }) + + t.Run("overflow", func(t *testing.T) { + _, err := checkedSearchContentCapacity(maxInt/2+1, 0) + if err == nil { + t.Fatal("expected overflow error, got nil") + } + if !strings.Contains(err.Error(), "overflow") { + t.Fatalf("expected overflow error, got: %v", err) + } + }) +} diff --git a/pkg/llmproxy/watcher/clients.go b/pkg/llmproxy/watcher/clients.go index 97b3aa7ca2..5f004aafe0 100644 --- a/pkg/llmproxy/watcher/clients.go +++ b/pkg/llmproxy/watcher/clients.go @@ -56,8 +56,14 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string } geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) - totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount - log.Debugf("loaded %d API key clients", totalAPIKeyClients) + staticCredentialClientCount := summarizeStaticCredentialClients( + geminiAPIKeyCount, + vertexCompatAPIKeyCount, + claudeAPIKeyCount, + codexAPIKeyCount, + openAICompatCount, + ) + log.Debugf("loaded %d static credential clients", staticCredentialClientCount) var authFileCount int if rescanAuth { @@ -100,7 +106,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.clientsMutex.Unlock() } - totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + totalNewClients := authFileCount + staticCredentialClientCount if w.reloadCallback != nil { log.Debugf("triggering server update callback before auth refresh") @@ -109,15 +115,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.refreshAuthState(forceAuthRefresh) - log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", - totalNewClients, - authFileCount, - geminiAPIKeyCount, - vertexCompatAPIKeyCount, - claudeAPIKeyCount, - codexAPIKeyCount, - openAICompatCount, - ) + log.Infof("%s", clientReloadSummary(totalNewClients, authFileCount, staticCredentialClientCount)) } func (w *Watcher) addOrUpdateClient(path string) { diff --git a/pkg/llmproxy/watcher/config_reload.go b/pkg/llmproxy/watcher/config_reload.go index 236ef72bf4..3da47edd5f 100644 --- a/pkg/llmproxy/watcher/config_reload.go +++ b/pkg/llmproxy/watcher/config_reload.go @@ -117,9 +117,9 @@ func (w *Watcher) reloadConfig() bool { if oldConfig != nil { details := diff.BuildConfigChangeDetails(oldConfig, newConfig) if len(details) > 0 { - log.Debugf("config changes detected:") - for _, d := range details { - log.Debugf(" %s", d) + log.Debugf("config changes detected: %d field group(s)", len(details)) + for _, line := range redactedConfigChangeLogLines(details) { + log.Debug(line) } } else { log.Debugf("no material config field changes detected") diff --git a/pkg/llmproxy/watcher/diff/models_summary.go b/pkg/llmproxy/watcher/diff/models_summary.go index faa82a7640..326c23ac27 100644 --- a/pkg/llmproxy/watcher/diff/models_summary.go +++ b/pkg/llmproxy/watcher/diff/models_summary.go @@ -1,7 +1,8 @@ package diff import ( - "crypto/sha256" + "crypto/hmac" + "crypto/sha512" "encoding/hex" "sort" "strings" @@ -9,6 +10,8 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/config" ) +const vertexModelsSummaryHashKey = "watcher-vertex-models-summary:v1" + type GeminiModelsSummary struct { hash string count int @@ -113,9 +116,10 @@ func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummar return VertexModelsSummary{} } sort.Strings(names) - sum := sha256.Sum256([]byte(strings.Join(names, "|"))) + hasher := hmac.New(sha512.New, []byte(vertexModelsSummaryHashKey)) + hasher.Write([]byte(strings.Join(names, "|"))) return VertexModelsSummary{ - hash: hex.EncodeToString(sum[:]), + hash: hex.EncodeToString(hasher.Sum(nil)), count: len(names), } } diff --git a/pkg/llmproxy/watcher/diff/oauth_excluded_test.go b/pkg/llmproxy/watcher/diff/oauth_excluded_test.go index 0fbcc27114..1ddd7c769d 100644 --- a/pkg/llmproxy/watcher/diff/oauth_excluded_test.go +++ b/pkg/llmproxy/watcher/diff/oauth_excluded_test.go @@ -1,6 +1,8 @@ package diff import ( + "crypto/sha256" + "encoding/hex" "testing" "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/config" @@ -98,6 +100,21 @@ func TestSummarizeVertexModels(t *testing.T) { } } +func TestSummarizeVertexModels_DoesNotUseLegacySHA256(t *testing.T) { + summary := SummarizeVertexModels([]config.VertexCompatModel{ + {Name: "m1"}, + {Name: "m2"}, + }) + if summary.hash == "" { + t.Fatal("expected non-empty hash") + } + + legacy := sha256.Sum256([]byte("m1|m2")) + if summary.hash == hex.EncodeToString(legacy[:]) { + t.Fatalf("expected vertex hash to differ from legacy sha256") + } +} + func expectContains(t *testing.T, list []string, target string) { t.Helper() for _, entry := range list { diff --git a/pkg/llmproxy/watcher/diff/openai_compat.go b/pkg/llmproxy/watcher/diff/openai_compat.go index 41726db3c3..dfbeafee21 100644 --- a/pkg/llmproxy/watcher/diff/openai_compat.go +++ b/pkg/llmproxy/watcher/diff/openai_compat.go @@ -1,7 +1,8 @@ package diff import ( - "crypto/sha256" + "crypto/hmac" + "crypto/sha512" "encoding/hex" "fmt" "sort" @@ -10,6 +11,8 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/pkg/llmproxy/config" ) +const openAICompatSignatureHashKey = "watcher-openai-compat-signature:v1" + // DiffOpenAICompatibility produces human-readable change descriptions. func DiffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string { changes := make([]string, 0) @@ -178,6 +181,7 @@ func openAICompatSignature(entry config.OpenAICompatibility) string { if len(parts) == 0 { return "" } - sum := sha256.Sum256([]byte(strings.Join(parts, "|"))) - return hex.EncodeToString(sum[:]) + hasher := hmac.New(sha512.New, []byte(openAICompatSignatureHashKey)) + hasher.Write([]byte(strings.Join(parts, "|"))) + return hex.EncodeToString(hasher.Sum(nil)) } diff --git a/pkg/llmproxy/watcher/diff/openai_compat_test.go b/pkg/llmproxy/watcher/diff/openai_compat_test.go index 434d989d5b..4e2907c0f3 100644 --- a/pkg/llmproxy/watcher/diff/openai_compat_test.go +++ b/pkg/llmproxy/watcher/diff/openai_compat_test.go @@ -1,6 +1,8 @@ package diff import ( + "crypto/sha256" + "encoding/hex" "strings" "testing" @@ -161,6 +163,20 @@ func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) { if sigC := openAICompatSignature(c); sigC == sigB { t.Fatalf("expected signature to change when models change, got %s", sigC) } + +} + +func TestOpenAICompatSignature_DoesNotUseLegacySHA256(t *testing.T) { + entry := config.OpenAICompatibility{Name: "provider"} + got := openAICompatSignature(entry) + if got == "" { + t.Fatal("expected non-empty signature") + } + + legacy := sha256.Sum256([]byte("name=provider")) + if got == hex.EncodeToString(legacy[:]) { + t.Fatalf("expected signature to differ from legacy sha256") + } } func TestCountOpenAIModelsSkipsBlanks(t *testing.T) { diff --git a/pkg/llmproxy/watcher/logging_helpers.go b/pkg/llmproxy/watcher/logging_helpers.go new file mode 100644 index 0000000000..b4cd3ae225 --- /dev/null +++ b/pkg/llmproxy/watcher/logging_helpers.go @@ -0,0 +1,24 @@ +package watcher + +import "fmt" + +func summarizeStaticCredentialClients(gemini, vertex, claude, codex, openAICompat int) int { + return gemini + vertex + claude + codex + openAICompat +} + +func clientReloadSummary(totalClients, authFileCount, staticCredentialClients int) string { + return fmt.Sprintf( + "full client load complete - %d clients (%d auth files + %d static credential clients)", + totalClients, + authFileCount, + staticCredentialClients, + ) +} + +func redactedConfigChangeLogLines(details []string) []string { + lines := make([]string, 0, len(details)) + for i := range details { + lines = append(lines, fmt.Sprintf(" change[%d] recorded (redacted)", i+1)) + } + return lines +} diff --git a/pkg/llmproxy/watcher/logging_safety_test.go b/pkg/llmproxy/watcher/logging_safety_test.go new file mode 100644 index 0000000000..2dd7424e5a --- /dev/null +++ b/pkg/llmproxy/watcher/logging_safety_test.go @@ -0,0 +1,40 @@ +package watcher + +import ( + "strings" + "testing" +) + +func TestRedactedConfigChangeLogLines(t *testing.T) { + lines := redactedConfigChangeLogLines([]string{ + "api-key: sk-live-abc123", + "oauth-token: bearer secret", + }) + if len(lines) != 2 { + t.Fatalf("expected 2 lines, got %d", len(lines)) + } + for _, line := range lines { + if strings.Contains(line, "sk-live-abc123") || strings.Contains(line, "secret") { + t.Fatalf("sensitive content leaked in redacted line: %q", line) + } + if !strings.Contains(line, "redacted") { + t.Fatalf("expected redacted marker in line: %q", line) + } + } +} + +func TestClientReloadSummary(t *testing.T) { + got := clientReloadSummary(9, 4, 5) + if !strings.Contains(got, "9 clients") { + t.Fatalf("expected total client count, got %q", got) + } + if !strings.Contains(got, "4 auth files") { + t.Fatalf("expected auth file count, got %q", got) + } + if !strings.Contains(got, "5 static credential clients") { + t.Fatalf("expected static credential count, got %q", got) + } + if strings.Contains(strings.ToLower(got), "api key") { + t.Fatalf("summary should not mention api keys directly: %q", got) + } +} diff --git a/pkg/llmproxy/watcher/synthesizer/helpers.go b/pkg/llmproxy/watcher/synthesizer/helpers.go index 3ee77354c5..dc31c7136f 100644 --- a/pkg/llmproxy/watcher/synthesizer/helpers.go +++ b/pkg/llmproxy/watcher/synthesizer/helpers.go @@ -1,7 +1,8 @@ package synthesizer import ( - "crypto/sha256" + "crypto/hmac" + "crypto/sha512" "encoding/hex" "fmt" "sort" @@ -12,8 +13,10 @@ import ( coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) +const stableIDGeneratorHashKey = "watcher-stable-id-generator:v1" + // StableIDGenerator generates stable, deterministic IDs for auth entries. -// It uses SHA256 hashing with collision handling via counters. +// It uses keyed HMAC-SHA512 hashing with collision handling via counters. // It is not safe for concurrent use. type StableIDGenerator struct { counters map[string]int @@ -30,7 +33,7 @@ func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string) if g == nil { return kind + ":000000000000", "000000000000" } - hasher := sha256.New() + hasher := hmac.New(sha512.New, []byte(stableIDGeneratorHashKey)) hasher.Write([]byte(kind)) for _, part := range parts { trimmed := strings.TrimSpace(part) diff --git a/pkg/llmproxy/watcher/synthesizer/helpers_test.go b/pkg/llmproxy/watcher/synthesizer/helpers_test.go index b21d3e109a..5840f6716e 100644 --- a/pkg/llmproxy/watcher/synthesizer/helpers_test.go +++ b/pkg/llmproxy/watcher/synthesizer/helpers_test.go @@ -1,6 +1,8 @@ package synthesizer import ( + "crypto/sha256" + "encoding/hex" "reflect" "strings" "testing" @@ -10,6 +12,26 @@ import ( coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) +func TestStableIDGenerator_Next_DoesNotUseLegacySHA256(t *testing.T) { + gen := NewStableIDGenerator() + id, short := gen.Next("gemini:apikey", "test-key", "https://api.example.com") + if id == "" || short == "" { + t.Fatal("expected generated IDs to be non-empty") + } + + legacyHasher := sha256.New() + legacyHasher.Write([]byte("gemini:apikey")) + legacyHasher.Write([]byte{0}) + legacyHasher.Write([]byte("test-key")) + legacyHasher.Write([]byte{0}) + legacyHasher.Write([]byte("https://api.example.com")) + legacyShort := hex.EncodeToString(legacyHasher.Sum(nil))[:12] + + if short == legacyShort { + t.Fatalf("expected short id to differ from legacy sha256 digest %q", legacyShort) + } +} + func TestNewStableIDGenerator(t *testing.T) { gen := NewStableIDGenerator() if gen == nil { diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index f8b325e05a..5242f8b7ab 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -473,7 +473,7 @@ func appendAPIResponse(c *gin.Context, data []byte) { if existing, exists := c.Get("API_RESPONSE"); exists { if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 { - combined := make([]byte, 0, len(existingBytes)+len(data)+1) + combined := make([]byte, 0, len(existingBytes)) combined = append(combined, existingBytes...) if existingBytes[len(existingBytes)-1] != '\n' { combined = append(combined, '\n') diff --git a/sdk/api/handlers/handlers_append_response_test.go b/sdk/api/handlers/handlers_append_response_test.go new file mode 100644 index 0000000000..784a968381 --- /dev/null +++ b/sdk/api/handlers/handlers_append_response_test.go @@ -0,0 +1,27 @@ +package handlers + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestAppendAPIResponse_AppendsWithNewline(t *testing.T) { + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Set("API_RESPONSE", []byte("first")) + + appendAPIResponse(ginCtx, []byte("second")) + + value, exists := ginCtx.Get("API_RESPONSE") + if !exists { + t.Fatal("expected API_RESPONSE to be set") + } + got, ok := value.([]byte) + if !ok { + t.Fatalf("expected []byte API_RESPONSE, got %T", value) + } + if string(got) != "first\nsecond" { + t.Fatalf("unexpected API_RESPONSE: %q", string(got)) + } +} diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 4715d7f7b1..c09bca9712 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -162,14 +162,7 @@ func (s *FileTokenStore) Delete(ctx context.Context, id string) error { } func (s *FileTokenStore) resolveDeletePath(id string) (string, error) { - if strings.ContainsRune(id, os.PathSeparator) || filepath.IsAbs(id) { - return id, nil - } - dir := s.baseDirSnapshot() - if dir == "" { - return "", fmt.Errorf("auth filestore: directory not configured") - } - return filepath.Join(dir, id), nil + return s.resolveManagedPath(id) } func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) { @@ -274,29 +267,47 @@ func (s *FileTokenStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, error } if auth.Attributes != nil { if p := strings.TrimSpace(auth.Attributes["path"]); p != "" { - return p, nil + return s.resolveManagedPath(p) } } if fileName := strings.TrimSpace(auth.FileName); fileName != "" { - if filepath.IsAbs(fileName) { - return fileName, nil - } - if dir := s.baseDirSnapshot(); dir != "" { - return filepath.Join(dir, fileName), nil - } - return fileName, nil + return s.resolveManagedPath(fileName) } if auth.ID == "" { return "", fmt.Errorf("auth filestore: missing id") } - if filepath.IsAbs(auth.ID) { - return auth.ID, nil + return s.resolveManagedPath(auth.ID) +} + +func (s *FileTokenStore) resolveManagedPath(candidate string) (string, error) { + trimmed := strings.TrimSpace(candidate) + if trimmed == "" { + return "", fmt.Errorf("auth filestore: path is empty") } - dir := s.baseDirSnapshot() - if dir == "" { + baseDir := s.baseDirSnapshot() + if baseDir == "" { return "", fmt.Errorf("auth filestore: directory not configured") } - return filepath.Join(dir, auth.ID), nil + absBase, err := filepath.Abs(baseDir) + if err != nil { + return "", fmt.Errorf("auth filestore: resolve base directory: %w", err) + } + + var resolved string + if filepath.IsAbs(trimmed) { + resolved = filepath.Clean(trimmed) + } else { + resolved = filepath.Join(absBase, filepath.FromSlash(trimmed)) + resolved = filepath.Clean(resolved) + } + rel, err := filepath.Rel(absBase, resolved) + if err != nil { + return "", fmt.Errorf("auth filestore: resolve relative path: %w", err) + } + if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + return "", fmt.Errorf("auth filestore: path %q escapes base directory", candidate) + } + return resolved, nil } func (s *FileTokenStore) labelFor(metadata map[string]any) string { diff --git a/sdk/auth/filestore_deletepath_test.go b/sdk/auth/filestore_deletepath_test.go new file mode 100644 index 0000000000..e37dfd28fa --- /dev/null +++ b/sdk/auth/filestore_deletepath_test.go @@ -0,0 +1,49 @@ +package auth + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestFileTokenStoreResolveDeletePathRejectsEscapeInputs(t *testing.T) { + t.Parallel() + + store := NewFileTokenStore() + store.SetBaseDir(t.TempDir()) + + absolute := filepath.Join(t.TempDir(), "outside.json") + cases := []string{ + "../outside.json", + absolute, + } + for _, id := range cases { + if _, err := store.resolveDeletePath(id); err == nil { + t.Fatalf("expected id %q to be rejected", id) + } + } +} + +func TestFileTokenStoreDeleteRemovesFileWithinBaseDir(t *testing.T) { + t.Parallel() + + baseDir := t.TempDir() + store := NewFileTokenStore() + store.SetBaseDir(baseDir) + + target := filepath.Join(baseDir, "nested", "auth.json") + if err := os.MkdirAll(filepath.Dir(target), 0o700); err != nil { + t.Fatalf("create nested dir: %v", err) + } + if err := os.WriteFile(target, []byte(`{"ok":true}`), 0o600); err != nil { + t.Fatalf("write target file: %v", err) + } + + if err := store.Delete(context.Background(), "nested/auth.json"); err != nil { + t.Fatalf("delete auth file: %v", err) + } + if _, err := os.Stat(target); !os.IsNotExist(err) { + t.Fatalf("expected target to be deleted, stat err=%v", err) + } +} diff --git a/sdk/auth/filestore_test.go b/sdk/auth/filestore_test.go index 9e135ad4c9..a773423916 100644 --- a/sdk/auth/filestore_test.go +++ b/sdk/auth/filestore_test.go @@ -1,6 +1,13 @@ package auth -import "testing" +import ( + "context" + "path/filepath" + "strings" + "testing" + + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) func TestExtractAccessToken(t *testing.T) { t.Parallel() @@ -78,3 +85,42 @@ func TestExtractAccessToken(t *testing.T) { }) } } + +func TestFileTokenStoreSave_RejectsPathOutsideBaseDir(t *testing.T) { + t.Parallel() + + store := NewFileTokenStore() + baseDir := t.TempDir() + store.SetBaseDir(baseDir) + + auth := &cliproxyauth.Auth{ + ID: "outside.json", + FileName: "../../outside.json", + Metadata: map[string]any{"type": "kiro"}, + } + + _, err := store.Save(context.Background(), auth) + if err == nil { + t.Fatalf("expected save to reject path traversal") + } + if !strings.Contains(err.Error(), "escapes base directory") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestFileTokenStoreDelete_RejectsAbsolutePathOutsideBaseDir(t *testing.T) { + t.Parallel() + + store := NewFileTokenStore() + baseDir := t.TempDir() + store.SetBaseDir(baseDir) + + outside := filepath.Join(filepath.Dir(baseDir), "outside.json") + err := store.Delete(context.Background(), outside) + if err == nil { + t.Fatalf("expected delete to reject absolute path outside base dir") + } + if !strings.Contains(err.Error(), "escapes base directory") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index bb855d716e..25eaffb51d 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -3,6 +3,8 @@ package auth import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "io" @@ -1739,7 +1741,7 @@ func (m *Manager) checkRefreshes(ctx context.Context) { if !m.shouldRefresh(a, now) { continue } - log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ) + log.Debugf("checking refresh for %s kind=%s", authLogRef(a), typ) if exec := m.executorFor(a.Provider); exec == nil { continue @@ -1999,10 +2001,10 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { cloned := auth.Clone() updated, err := exec.Refresh(ctx, cloned) if err != nil && errors.Is(err, context.Canceled) { - log.Debugf("refresh canceled for %s, %s", auth.Provider, auth.ID) + log.Debugf("refresh canceled for %s", authLogRef(auth)) return } - log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err) + log.Debugf("refresh completed for %s err=%v", authLogRef(auth), err) now := time.Now() if err != nil { m.mu.Lock() @@ -2143,6 +2145,25 @@ func formatOauthIdentity(auth *Auth, provider string, accountInfo string) string return strings.Join(parts, " ") } +func authLogRef(auth *Auth) string { + if auth == nil { + return "provider=unknown auth_id_hash=none" + } + provider := strings.TrimSpace(auth.Provider) + if provider == "" { + provider = "unknown" + } + identifier := strings.TrimSpace(auth.ID) + if identifier == "" { + identifier = strings.TrimSpace(auth.FileName) + } + if identifier == "" { + return "provider=" + provider + " auth_id_hash=none" + } + sum := sha256.Sum256([]byte(identifier)) + return "provider=" + provider + " auth_id_hash=" + hex.EncodeToString(sum[:6]) +} + // InjectCredentials delegates per-provider HTTP request preparation when supported. // If the registered executor for the auth provider implements RequestPreparer, // it will be invoked to modify the request (e.g., add headers). diff --git a/sdk/cliproxy/auth/conductor_logging_test.go b/sdk/cliproxy/auth/conductor_logging_test.go new file mode 100644 index 0000000000..816a61bdfd --- /dev/null +++ b/sdk/cliproxy/auth/conductor_logging_test.go @@ -0,0 +1,23 @@ +package auth + +import ( + "strings" + "testing" +) + +func TestAuthLogRef(t *testing.T) { + auth := &Auth{ + ID: "sensitive-auth-id-12345", + Provider: "claude", + } + got := authLogRef(auth) + if !strings.Contains(got, "provider=claude") { + t.Fatalf("expected provider in log ref, got %q", got) + } + if strings.Contains(got, auth.ID) { + t.Fatalf("log ref leaked raw auth id: %q", got) + } + if !strings.Contains(got, "auth_id_hash=") { + t.Fatalf("expected auth hash marker in log ref, got %q", got) + } +} diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 2ccc5649d8..f60685dbaa 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -2,8 +2,8 @@ package auth import ( "crypto/sha256" - "encoding/hex" "encoding/json" + "fmt" "strconv" "strings" "sync" @@ -133,7 +133,7 @@ func stableAuthIndex(seed string) string { return "" } sum := sha256.Sum256([]byte(seed)) - return hex.EncodeToString(sum[:8]) + return fmt.Sprintf("%x", sum[:]) } // EnsureIndex returns a stable index derived from the auth file name or API key. @@ -148,10 +148,6 @@ func (a *Auth) EnsureIndex() string { seed := strings.TrimSpace(a.FileName) if seed != "" { seed = "file:" + seed - } else if a.Attributes != nil { - if apiKey := strings.TrimSpace(a.Attributes["api_key"]); apiKey != "" { - seed = "api_key:" + apiKey - } } if seed == "" { if id := strings.TrimSpace(a.ID); id != "" { diff --git a/sdk/cliproxy/auth/types_test.go b/sdk/cliproxy/auth/types_test.go index 8249b0635b..51ddda45fb 100644 --- a/sdk/cliproxy/auth/types_test.go +++ b/sdk/cliproxy/auth/types_test.go @@ -1,6 +1,9 @@ package auth -import "testing" +import ( + "strings" + "testing" +) func TestToolPrefixDisabled(t *testing.T) { var a *Auth @@ -33,3 +36,29 @@ func TestToolPrefixDisabled(t *testing.T) { t.Error("should return false when set to false") } } + +func TestStableAuthIndex_UsesFullDigest(t *testing.T) { + idx := stableAuthIndex("seed-value") + if len(idx) != 64 { + t.Fatalf("stableAuthIndex length = %d, want 64", len(idx)) + } +} + +func TestEnsureIndex_DoesNotUseAPIKeySeed(t *testing.T) { + a := &Auth{ + ID: "auth-id-1", + Attributes: map[string]string{ + "api_key": "sensitive-token", + }, + } + idx := a.EnsureIndex() + if idx == "" { + t.Fatal("expected non-empty index") + } + if idx != stableAuthIndex("id:"+a.ID) { + t.Fatalf("EnsureIndex = %q, want id-derived index", idx) + } + if strings.Contains(idx, "sensitive-token") { + t.Fatal("index should not include API key material") + } +}