diff --git a/internals/proxy/middlewares/auth.go b/internals/proxy/middlewares/auth.go index 98ffa6a2..11b0eadd 100644 --- a/internals/proxy/middlewares/auth.go +++ b/internals/proxy/middlewares/auth.go @@ -12,6 +12,7 @@ import ( "github.com/codeshelldev/gotl/pkg/logger" log "github.com/codeshelldev/gotl/pkg/logger" + "github.com/codeshelldev/gotl/pkg/request" "github.com/codeshelldev/secured-signal-api/internals/config" ) @@ -22,12 +23,12 @@ var Auth Middleware = Middleware{ type AuthMethod struct { Name string - Authenticate func(req *http.Request, tokens []string) (bool, error) + Authenticate func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) } var BearerAuth = AuthMethod { Name: "Bearer", - Authenticate: func(req *http.Request, tokens []string) (bool, error) { + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) { header := req.Header.Get("Authorization") headerParts := strings.SplitN(header, " ", 2) @@ -50,7 +51,7 @@ var BearerAuth = AuthMethod { var BasicAuth = AuthMethod { Name: "Basic", - Authenticate: func(req *http.Request, tokens []string) (bool, error) { + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) { header := req.Header.Get("Authorization") if strings.TrimSpace(header) == "" { @@ -90,9 +91,50 @@ var BasicAuth = AuthMethod { }, } +var BodyAuth = AuthMethod { + Name: "Body", + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) { + const authField = "auth" + + body, err := request.GetReqBody(req) + + if err != nil { + return false, nil + } + + body.Write(req) + + if body.Empty { + return false, nil + } + + value, exists := body.Data[authField] + + if !exists { + return false, nil + } + + auth, ok := value.(string) + + if !ok { + return false, nil + } + + if isValidToken(tokens, auth) { + delete(body.Data, authField) + + body.Write(req) + + return true, nil + } + + return false, errors.New("invalid Body token") + }, +} + var QueryAuth = AuthMethod { Name: "Query", - Authenticate: func(req *http.Request, tokens []string) (bool, error) { + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) { const authQuery = "@authorization" auth := req.URL.Query().Get(authQuery) @@ -117,7 +159,7 @@ var QueryAuth = AuthMethod { var PathAuth = AuthMethod { Name: "Path", - Authenticate: func(req *http.Request, tokens []string) (bool, error) { + Authenticate: func(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) { parts := strings.Split(req.URL.Path, "/") if len(parts) == 0 { @@ -155,6 +197,7 @@ func authHandler(next http.Handler) http.Handler { var authChain = NewAuthChain(). Use(BearerAuth). Use(BasicAuth). + Use(BodyAuth). Use(QueryAuth). Use(PathAuth) @@ -166,11 +209,11 @@ func authHandler(next http.Handler) http.Handler { var authToken string - success, _ := authChain.Eval(req, tokens) + success, _ := authChain.Eval(w, req, tokens) if !success { + logger.Warn("User failed to provide auth") w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"") - http.Error(w, "Unauthorized", http.StatusUnauthorized) return } @@ -202,12 +245,12 @@ func (chain *AuthChain) Use(method AuthMethod) *AuthChain { return chain } -func (chain *AuthChain) Eval(req *http.Request, tokens []string) (bool, error) { +func (chain *AuthChain) Eval(w http.ResponseWriter, req *http.Request, tokens []string) (bool, error) { var err error var success bool for _, method := range chain.methods { - success, err = method.Authenticate(req, tokens) + success, err = method.Authenticate(w, req, tokens) if err != nil { logger.Warn("User failed ", method.Name, " auth: ", err.Error())