Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions internal/runtime/executor/user_id_cache.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package executor

import (
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"os"
"sync"
"time"
)
Expand All @@ -16,8 +18,11 @@ var (
userIDCache = make(map[string]userIDCacheEntry)
userIDCacheMu sync.RWMutex
userIDCacheCleanupOnce sync.Once
userIDCacheHashKey = resolveUserIDCacheHashKey()
)

const userIDCacheHashFallback = "executor-user-id-cache:hmac-sha256-v1"

const (
userIDTTL = time.Hour
userIDCacheCleanupPeriod = 15 * time.Minute
Expand Down Expand Up @@ -45,8 +50,16 @@ func purgeExpiredUserIDs() {
}

func userIDCacheKey(apiKey string) string {
sum := sha256.Sum256([]byte(apiKey))
return hex.EncodeToString(sum[:])
hasher := hmac.New(sha256.New, userIDCacheHashKey)
_, _ = hasher.Write([]byte(apiKey))
return hex.EncodeToString(hasher.Sum(nil))
}

func resolveUserIDCacheHashKey() []byte {
if env := os.Getenv("CLIPROXY_USER_ID_CACHE_HASH_KEY"); env != "" {
return []byte(env)
}
return []byte(userIDCacheHashFallback)
}

func cachedUserID(apiKey string) string {
Expand Down
150 changes: 115 additions & 35 deletions pkg/llmproxy/api/handlers/management/api_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@
c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
return
}
parsedURL, errParseURL := url.Parse(urlStr)
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
safeURL, parsedURL, errSanitizeURL := sanitizeAPICallURL(urlStr)
if errSanitizeURL != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": errSanitizeURL.Error()})
return
}
if errValidateURL := validateAPICallURL(parsedURL); errValidateURL != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": errValidateURL.Error()})
if errResolve := validateResolvedHostIPs(parsedURL.Hostname()); errResolve != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": errResolve.Error()})
return
}

Expand Down Expand Up @@ -212,7 +212,7 @@
}
}

req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, safeURL, requestBody)
if errNewRequest != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
return
Expand All @@ -226,11 +226,15 @@
req.Header.Set(key, value)
}
if hostOverride != "" {
if !isAllowedHostOverride(parsedURL, hostOverride) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid host override"})
return
}
req.Host = hostOverride
}

httpClient := &http.Client{
Timeout: defaultAPICallTimeout,

Check failure

Code scanning / CodeQL

Uncontrolled data used in network request Critical

The
URL
of this request depends on a
user-provided value
.
}
httpClient.Transport = h.apiCallTransport(auth)

Expand Down Expand Up @@ -268,8 +272,8 @@

// If this is a GitHub Copilot token endpoint response, try to enrich with quota information
if resp.StatusCode == http.StatusOK &&
strings.Contains(urlStr, "copilot_internal") &&
strings.Contains(urlStr, "/token") {
strings.Contains(safeURL, "copilot_internal") &&
strings.Contains(safeURL, "/token") {
response = h.enrichCopilotTokenResponse(c.Request.Context(), response, auth, urlStr)
}

Expand Down Expand Up @@ -298,6 +302,35 @@
return ""
}

func isAllowedHostOverride(parsedURL *url.URL, override string) bool {
if parsedURL == nil {
return false
}
trimmed := strings.TrimSpace(override)
if trimmed == "" {
return false
}
if strings.ContainsAny(trimmed, " \r\n\t") {
return false
}

requestHost := strings.TrimSpace(parsedURL.Host)
requestHostname := strings.TrimSpace(parsedURL.Hostname())
if requestHost == "" {
return false
}
if strings.EqualFold(trimmed, requestHost) {
return true
}
if strings.EqualFold(trimmed, requestHostname) {
return true
}
if len(trimmed) > 2 && trimmed[0] == '[' && trimmed[len(trimmed)-1] == ']' {
return false
}
return false
}

func validateAPICallURL(parsedURL *url.URL) error {
if parsedURL == nil {
return fmt.Errorf("invalid url")
Expand All @@ -306,13 +339,13 @@
if scheme != "http" && scheme != "https" {
return fmt.Errorf("unsupported url scheme")
}
if parsedURL.User != nil {
return fmt.Errorf("target host is not allowed")
}
host := strings.TrimSpace(parsedURL.Hostname())
if host == "" {
return fmt.Errorf("invalid url host")
}
if parsedURL.User != nil {
return fmt.Errorf("target user info is not allowed")
}
if strings.EqualFold(host, "localhost") {
return fmt.Errorf("target host is not allowed")
}
Expand All @@ -324,6 +357,42 @@
return nil
}

func sanitizeAPICallURL(raw string) (string, *url.URL, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return "", nil, fmt.Errorf("missing url")
}
parsedURL, errParseURL := url.Parse(trimmed)
if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
return "", nil, fmt.Errorf("invalid url")
}
if errValidateURL := validateAPICallURL(parsedURL); errValidateURL != nil {
return "", nil, errValidateURL
}
parsedURL.Fragment = ""
return parsedURL.String(), parsedURL, nil
}

func validateResolvedHostIPs(host string) error {
trimmed := strings.TrimSpace(host)
if trimmed == "" {
return fmt.Errorf("invalid url host")
}
resolved, errLookup := net.LookupIP(trimmed)
if errLookup != nil {
return fmt.Errorf("target host resolution failed")
}
for _, ip := range resolved {
if ip == nil {
continue
}
if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return fmt.Errorf("target host is not allowed")
}
}
return nil
}

func tokenValueForAuth(auth *coreauth.Auth) string {
if auth == nil {
return ""
Expand Down Expand Up @@ -728,10 +797,12 @@
}

func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
hasAuthProxy := false
var proxyCandidates []string
if auth != nil {
if proxyStr := strings.TrimSpace(auth.ProxyURL); proxyStr != "" {
proxyCandidates = append(proxyCandidates, proxyStr)
hasAuthProxy = true
}
}
if h != nil && h.cfg != nil {
Expand All @@ -741,9 +812,14 @@
}

for _, proxyStr := range proxyCandidates {
if transport := buildProxyTransport(proxyStr); transport != nil {
transport, errBuild := buildProxyTransportWithError(proxyStr)
if transport != nil {
return transport
}
if hasAuthProxy {
return &transportFailureRoundTripper{err: fmt.Errorf("authentication proxy misconfigured: %v", errBuild)}
}
log.Debugf("failed to setup API call proxy from URL: %s, trying next candidate", proxyStr)
}

transport, ok := http.DefaultTransport.(*http.Transport)
Expand All @@ -755,20 +831,20 @@
return clone
}

func buildProxyTransport(proxyStr string) *http.Transport {
func buildProxyTransportWithError(proxyStr string) (*http.Transport, error) {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
return nil
return nil, fmt.Errorf("proxy URL is empty")
}

proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
return nil, fmt.Errorf("parse proxy URL failed: %w", errParse)
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
return nil, fmt.Errorf("missing proxy scheme or host: %s", proxyStr)
}

if proxyURL.Scheme == "socks5" {
Expand All @@ -781,22 +857,30 @@
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}, nil
}

if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}, nil
}

log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
}

type transportFailureRoundTripper struct {
err error
}

func (t *transportFailureRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
return nil, t.err
}

// headerContainsValue checks whether a header map contains a target value (case-insensitive key and value).
Expand Down Expand Up @@ -1221,6 +1305,16 @@
log.WithError(errQuotaURL).Debug("enrichCopilotTokenResponse: rejected token URL for quota request")
return response
}
parsedQuotaURL, errParseQuotaURL := url.Parse(quotaURL)
if errParseQuotaURL != nil {
return response
}
if errValidate := validateAPICallURL(parsedQuotaURL); errValidate != nil {
return response
}
if errResolve := validateResolvedHostIPs(parsedQuotaURL.Hostname()); errResolve != nil {
return response
}

req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodGet, quotaURL, nil)
if errNewRequest != nil {
Expand Down Expand Up @@ -1367,27 +1461,13 @@
if errParse != nil {
return "", errParse
}
if parsedURL == nil || !parsedURL.IsAbs() {
return "", fmt.Errorf("invalid token url")
}
if parsedURL.User != nil {
return "", fmt.Errorf("token url must not include user info")
return "", fmt.Errorf("unsupported host %q", parsedURL.Hostname())
}
host := strings.ToLower(parsedURL.Hostname())
if host == "" {
return "", fmt.Errorf("token url host is required")
}
if parsedURL.Scheme != "https" {
return "", fmt.Errorf("unsupported scheme %q", parsedURL.Scheme)
}
if strings.EqualFold(host, "localhost") {
return "", fmt.Errorf("token url host is not allowed")
}
if ip := net.ParseIP(host); ip != nil {
if ip.IsLoopback() || ip.IsPrivate() || ip.IsUnspecified() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return "", fmt.Errorf("token url host is not allowed")
}
}
switch host {
case "api.github.com", "api.githubcopilot.com":
return fmt.Sprintf("https://%s/copilot_pkg/llmproxy/user", host), nil
Expand Down
Loading
Loading