diff --git a/internals/proxy/middlewares/proxy.go b/internals/proxy/middlewares/proxy.go index 7d0f8d3..b5daad5 100644 --- a/internals/proxy/middlewares/proxy.go +++ b/internals/proxy/middlewares/proxy.go @@ -5,7 +5,6 @@ import ( "net" "net/http" "net/url" - "slices" "strings" ) @@ -14,6 +13,7 @@ var InternalProxy Middleware = Middleware{ Use: proxyHandler, } +const trustedProxyKey contextKey = "isProxyTrusted" const clientIPKey contextKey = "clientIP" const originURLKey contextKey = "originURL" @@ -41,45 +41,36 @@ func proxyHandler(next http.Handler) http.Handler { rawTrustedProxies = getConfig("").SETTINGS.ACCESS.TRUSTED_PROXIES } + var trusted bool var ip net.IP - var originUrl string + + host, _, _ := net.SplitHostPort(req.RemoteAddr) + + originUrl := parseOrigin(req.Proto, req.Host) + + ip = net.ParseIP(host) if len(rawTrustedProxies) != 0 { trustedProxies := parseIPsAndIPNets(rawTrustedProxies) - var forwardedEntries []ForwardedEntry + trusted = isIPInList(ip, trustedProxies) + } + + if trusted { + var forwardedEntries []ForwardedEntry - if req.Header.Get("Forwarded") != "" { + if req.Header.Get("Forwarded") != "" { forwardedEntries = parseForwarded(req.Header.Get("Forwarded")) } else { forwardedEntries = parseXForwardedHeaders(req.Header) } - if len(forwardedEntries) != 0 { - originInfo := getOriginFromForwarded(forwardedEntries, trustedProxies) - ip = originInfo.IP - - originUrl = originInfo.Proto + "://" + originInfo.Host - } - } - - if ip == nil { - host, _, _ := net.SplitHostPort(req.RemoteAddr) - - ip = net.ParseIP(host) - } - - if originUrl == "" { - originUrl = req.Proto + "://" + req.Host + if len(forwardedEntries) != 0 { + ip = parseForIP(forwardedEntries[0].For) - if !strings.Contains(req.Host, ":") { - if req.Proto == "https" { - originUrl += ":443" - } else { - originUrl += ":80" - } - } - } + originUrl = parseOrigin(forwardedEntries[0].Proto, forwardedEntries[0].Host) + } + } originURL, err := url.Parse(originUrl) @@ -89,6 +80,7 @@ func proxyHandler(next http.Handler) http.Handler { return } + req = setContext(req, trustedProxyKey, trusted) req = setContext(req, originURLKey, originURL) req = setContext(req, clientIPKey, ip) @@ -97,29 +89,16 @@ func proxyHandler(next http.Handler) http.Handler { }) } -func getOriginFromForwarded(entries []ForwardedEntry, trusted []*net.IPNet) OriginInfo { - var origin OriginInfo - - // reverse to place origin client last - slices.Reverse(entries) - - for _, entry := range entries { - ip := parseForIP(entry.For) - - if ip == nil { - continue - } - - // ip not trusted => use as client ip - if !isIPInList(ip, trusted) { - origin.IP = ip - origin.Proto = entry.Proto - origin.Host = entry.Host - break +func parseOrigin(proto, host string) string { + if !strings.Contains(host, ":") { + if proto == "https" { + host += ":443" + } else { + host += ":80" } } - return origin + return proto + "://" + host } func parseForIP(value string) net.IP {