diff --git a/internals/proxy/middlewares/auth.go b/internals/proxy/middlewares/auth.go index 9d0f2725..98ffa6a2 100644 --- a/internals/proxy/middlewares/auth.go +++ b/internals/proxy/middlewares/auth.go @@ -3,11 +3,14 @@ package middlewares import ( "context" "encoding/base64" + "errors" "maps" "net/http" + "net/url" "slices" "strings" + "github.com/codeshelldev/gotl/pkg/logger" log "github.com/codeshelldev/gotl/pkg/logger" "github.com/codeshelldev/secured-signal-api/internals/config" ) @@ -17,79 +20,157 @@ var Auth Middleware = Middleware{ Use: authHandler, } -func authHandler(next http.Handler) http.Handler { - tokenKeys := maps.Keys(config.ENV.CONFIGS) - tokens := slices.Collect(tokenKeys) +type AuthMethod struct { + Name string + Authenticate func(req *http.Request, tokens []string) (bool, error) +} - if tokens == nil { - tokens = []string{} - } +var BearerAuth = AuthMethod { + Name: "Bearer", + Authenticate: func(req *http.Request, tokens []string) (bool, error) { + header := req.Header.Get("Authorization") - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if len(tokens) <= 0 { - next.ServeHTTP(w, req) - return + headerParts := strings.SplitN(header, " ", 2) + + if len(headerParts) != 2 { + return false, nil } - authHeader := req.Header.Get("Authorization") + if strings.ToLower(headerParts[0]) == "bearer" { + if isValidToken(tokens, headerParts[1]) { + return true, nil + } + + return false, errors.New("invalid Bearer token") + } - authQuery := req.URL.Query().Get("@authorization") + return false, nil + }, +} - var authType authType = None +var BasicAuth = AuthMethod { + Name: "Basic", + Authenticate: func(req *http.Request, tokens []string) (bool, error) { + header := req.Header.Get("Authorization") - var authToken string + if strings.TrimSpace(header) == "" { + return false, nil + } - success := false + headerParts := strings.SplitN(header, " ", 2) - if authHeader != "" { - authBody := strings.Split(authHeader, " ") + if len(headerParts) != 2 { + return false, nil + } - authType = getAuthType(authBody[0]) - authToken = authBody[1] + if strings.ToLower(headerParts[0]) == "basic" { + base64Bytes, err := base64.StdEncoding.DecodeString(headerParts[1]) - switch authType { - case Bearer: - if isValidToken(tokens, authToken) { - success = true - } - case Basic: - basicAuthBody, err := base64.StdEncoding.DecodeString(authToken) + if err != nil { + log.Error("Could not decode Basic auth payload: ", err.Error()) + return false, errors.New("invalid base64 in Basic auth") + } - if err != nil { - log.Error("Could not decode Basic Auth Payload: ", err.Error()) - } + parts := strings.SplitN(string(base64Bytes), ":", 2) - basicAuth := string(basicAuthBody) - basicAuthParts := strings.Split(basicAuth, ":") + if len(parts) != 2 { + return false, errors.New("Basic auth must be user:password") + } - user := "api" - authToken = basicAuthParts[1] + user, password := parts[0], parts[1] - if basicAuthParts[0] == user && isValidToken(tokens, authToken) { - success = true - } + if strings.ToLower(user) == "api" && isValidToken(tokens, password) { + return true, nil } - } else if authQuery != "" { - authType = Query + return false, errors.New("invalid user:password") + } - authToken = strings.TrimSpace(authQuery) + return false, nil + }, +} - if isValidToken(tokens, authToken) { - success = true +var QueryAuth = AuthMethod { + Name: "Query", + Authenticate: func(req *http.Request, tokens []string) (bool, error) { + const authQuery = "@authorization" - modifiedQuery := req.URL.Query() + auth := req.URL.Query().Get(authQuery) - modifiedQuery.Del("@authorization") + if strings.TrimSpace(auth) == "" { + return false, nil + } - req.URL.RawQuery = modifiedQuery.Encode() - } + if isValidToken(tokens, auth) { + query := req.URL.Query() + + query.Del(authQuery) + + req.URL.RawQuery = query.Encode() + + return true, nil + } + + return false, errors.New("invalid Query token") + }, +} + +var PathAuth = AuthMethod { + Name: "Path", + Authenticate: func(req *http.Request, tokens []string) (bool, error) { + parts := strings.Split(req.URL.Path, "/") + + if len(parts) == 0 { + return false, nil + } + + unescaped, err := url.PathUnescape(parts[1]) + + if err != nil { + return false, nil + } + + auth, exists := strings.CutPrefix(unescaped, "auth=") + + if !exists { + return false, nil + } + + if isValidToken(tokens, auth) { + return true, nil + } + + return false, errors.New("invalid Path token") + }, +} + +func authHandler(next http.Handler) http.Handler { + tokenKeys := maps.Keys(config.ENV.CONFIGS) + tokens := slices.Collect(tokenKeys) + + if tokens == nil { + tokens = []string{} + } + + var authChain = NewAuthChain(). + Use(BearerAuth). + Use(BasicAuth). + Use(QueryAuth). + Use(PathAuth) + + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if len(tokens) <= 0 { + next.ServeHTTP(w, req) + return } + var authToken string + + success, _ := authChain.Eval(req, tokens) + if !success { w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"") - log.Warn("User failed ", string(authType), " Auth") http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -101,17 +182,41 @@ func authHandler(next http.Handler) http.Handler { }) } -func getAuthType(str string) authType { - switch str { - case "Bearer": - return Bearer - case "Basic": - return Basic - default: - return None - } -} - func isValidToken(tokens []string, match string) bool { return slices.Contains(tokens, match) } + +type AuthChain struct { + methods []AuthMethod +} + +func NewAuthChain() *AuthChain { + return &AuthChain{} +} + +func (chain *AuthChain) Use(method AuthMethod) *AuthChain { + chain.methods = append(chain.methods, method) + + logger.Debug("Registered ", method.Name, " auth") + + return chain +} + +func (chain *AuthChain) Eval(req *http.Request, tokens []string) (bool, error) { + var err error + var success bool + + for _, method := range chain.methods { + success, err = method.Authenticate(req, tokens) + + if err != nil { + logger.Warn("User failed ", method.Name, " auth: ", err.Error()) + } + + if success { + return success, nil + } + } + + return false, err +} \ No newline at end of file diff --git a/internals/proxy/middlewares/common.go b/internals/proxy/middlewares/common.go index 6984b471..1f2a9878 100644 --- a/internals/proxy/middlewares/common.go +++ b/internals/proxy/middlewares/common.go @@ -11,15 +11,6 @@ type Context struct { Next http.Handler } -type authType string - -const ( - Bearer authType = "Bearer" - Basic authType = "Basic" - Query authType = "Query" - None authType = "None" -) - type contextKey string const tokenKey contextKey = "token" diff --git a/internals/proxy/middlewares/middleware.go b/internals/proxy/middlewares/middleware.go index cc222e35..3588e161 100644 --- a/internals/proxy/middlewares/middleware.go +++ b/internals/proxy/middlewares/middleware.go @@ -22,7 +22,7 @@ func NewChain() *Chain { func (chain *Chain) Use(middleware Middleware) *Chain { chain.middlewares = append(chain.middlewares, middleware) - logger.Debug("Registered ", middleware.Name) + logger.Debug("Registered ", middleware.Name, " middleware") return chain }