Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 27 additions & 48 deletions internals/proxy/middlewares/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"net"
"net/http"
"net/url"
"slices"
"strings"
)

Expand All @@ -14,6 +13,7 @@ var InternalProxy Middleware = Middleware{
Use: proxyHandler,
}

const trustedProxyKey contextKey = "isProxyTrusted"
const clientIPKey contextKey = "clientIP"
const originURLKey contextKey = "originURL"

Expand Down Expand Up @@ -41,45 +41,36 @@ func proxyHandler(next http.Handler) http.Handler {
rawTrustedProxies = getConfig("").SETTINGS.ACCESS.TRUSTED_PROXIES
}

var trusted bool
var ip net.IP
var originUrl string

host, _, _ := net.SplitHostPort(req.RemoteAddr)

originUrl := parseOrigin(req.Proto, req.Host)

ip = net.ParseIP(host)

if len(rawTrustedProxies) != 0 {
trustedProxies := parseIPsAndIPNets(rawTrustedProxies)

var forwardedEntries []ForwardedEntry
trusted = isIPInList(ip, trustedProxies)
}

if trusted {
var forwardedEntries []ForwardedEntry

if req.Header.Get("Forwarded") != "" {
if req.Header.Get("Forwarded") != "" {
forwardedEntries = parseForwarded(req.Header.Get("Forwarded"))
} else {
forwardedEntries = parseXForwardedHeaders(req.Header)
}

if len(forwardedEntries) != 0 {
originInfo := getOriginFromForwarded(forwardedEntries, trustedProxies)
ip = originInfo.IP

originUrl = originInfo.Proto + "://" + originInfo.Host
}
}

if ip == nil {
host, _, _ := net.SplitHostPort(req.RemoteAddr)

ip = net.ParseIP(host)
}

if originUrl == "" {
originUrl = req.Proto + "://" + req.Host
if len(forwardedEntries) != 0 {
ip = parseForIP(forwardedEntries[0].For)

if !strings.Contains(req.Host, ":") {
if req.Proto == "https" {
originUrl += ":443"
} else {
originUrl += ":80"
}
}
}
originUrl = parseOrigin(forwardedEntries[0].Proto, forwardedEntries[0].Host)
}
}

originURL, err := url.Parse(originUrl)

Expand All @@ -89,6 +80,7 @@ func proxyHandler(next http.Handler) http.Handler {
return
}

req = setContext(req, trustedProxyKey, trusted)
req = setContext(req, originURLKey, originURL)

req = setContext(req, clientIPKey, ip)
Expand All @@ -97,29 +89,16 @@ func proxyHandler(next http.Handler) http.Handler {
})
}

func getOriginFromForwarded(entries []ForwardedEntry, trusted []*net.IPNet) OriginInfo {
var origin OriginInfo

// reverse to place origin client last
slices.Reverse(entries)

for _, entry := range entries {
ip := parseForIP(entry.For)

if ip == nil {
continue
}

// ip not trusted => use as client ip
if !isIPInList(ip, trusted) {
origin.IP = ip
origin.Proto = entry.Proto
origin.Host = entry.Host
break
func parseOrigin(proto, host string) string {
if !strings.Contains(host, ":") {
if proto == "https" {
host += ":443"
} else {
host += ":80"
}
}

return origin
return proto + "://" + host
}

func parseForIP(value string) net.IP {
Expand Down