Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internals/config/structure/structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type CONFIG struct {
}

type SERVICE struct {
HOSTNAMES []string `koanf:"hostnames" env>aliases:".hostnames"`
PORT string `koanf:"port" env>aliases:".port"`
LOG_LEVEL string `koanf:"loglevel" env>aliases:".loglevel"`
}
Expand Down
46 changes: 46 additions & 0 deletions internals/proxy/middlewares/hostname.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package middlewares

import (
"net/http"
"net/url"
"slices"
)

var Hostname Middleware = Middleware{
Name: "Hostname",
Use: hostnameHandler,
}

func hostnameHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
logger := getLogger(req)

conf := getConfigByReq(req)

hostnames := conf.SERVICE.HOSTNAMES

if hostnames == nil {
hostnames = getConfig("").SERVICE.HOSTNAMES
}

if len(hostnames) > 0 {
URL := getContext[*url.URL](req, originURLKey)

hostname := URL.Hostname()

if hostname == "" {
logger.Error("Encountered empty hostname")
http.Error(w, "Bad Request: invalid hostname", http.StatusBadRequest)
return
}

if !slices.Contains(hostnames, hostname) {
logger.Warn("Client tried using Token with wrong hostname")
onUnauthorized(w)
return
}
}

next.ServeHTTP(w, req)
})
}
30 changes: 27 additions & 3 deletions internals/proxy/middlewares/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"net"
"net/http"
"net/url"
"strings"
)

Expand All @@ -14,6 +15,7 @@ var InternalProxy Middleware = Middleware{

const trustedProxyKey contextKey = "isProxyTrusted"
const clientIPKey contextKey = "clientIP"
const originURLKey contextKey = "originURL"

func proxyHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
Expand All @@ -32,6 +34,8 @@ func proxyHandler(next http.Handler) http.Handler {

host, _, _ := net.SplitHostPort(req.RemoteAddr)

originUrl := req.Proto + "://" + req.URL.Host

ip = net.ParseIP(host)

if len(rawTrustedProxies) != 0 {
Expand All @@ -50,10 +54,30 @@ func proxyHandler(next http.Handler) http.Handler {
if realIP != nil {
ip = realIP
}

XFHost := req.Header.Get("X-Forwarded-Host")
XFProto := req.Header.Get("X-Forwarded-Proto")
XFPort := req.Header.Get("X-Forwarded-Port")

if XFHost == "" || XFProto == "" || XFPort == "" {
logger.Warn("Missing X-Forwarded-* headers")
}

originUrl = XFProto + "://" + XFHost + ":" + XFPort
}

originURL, err := url.Parse(originUrl)

if err != nil {
logger.Error("Could not parse Url: ", originUrl)
http.Error(w, "Bad Request: invalid Url", http.StatusBadRequest)
return
}

req = setContext(req, clientIPKey, ip)
req = setContext(req, trustedProxyKey, trusted)
req = setContext(req, originURLKey, originURL)

req = setContext(req, clientIPKey, ip)

next.ServeHTTP(w, req)
})
Expand Down Expand Up @@ -123,13 +147,13 @@ func getRealIP(req *http.Request) (net.IP, error) {
realIP := net.ParseIP(strings.TrimSpace(ips[0]))

if realIP == nil {
return nil, errors.New("malformed x-forwarded-for header")
return nil, errors.New("malformed X-Forwarded-For header")
}

return realIP, nil
}

return nil, errors.New("no x-forwarded-for header present")
return nil, errors.New("no X-Forwarded-For header present")
}

func isIPInList(ip net.IP, list []*net.IPNet) bool {
Expand Down
3 changes: 2 additions & 1 deletion internals/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ func (proxy Proxy) Init() http.Handler {
Use(m.InternalClientIP).
Use(m.RequestLogger).
Use(m.InternalAuthRequirement).
Use(m.IPFilter).
Use(m.Port).
Use(m.Hostname).
Use(m.IPFilter).
Use(m.RateLimit).
Use(m.Template).
Use(m.Endpoints).
Expand Down