Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 162 additions & 57 deletions internals/proxy/middlewares/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package middlewares
import (
"context"
"encoding/base64"
"errors"
"maps"
"net/http"
"net/url"
"slices"
"strings"

"github.com/codeshelldev/gotl/pkg/logger"
log "github.com/codeshelldev/gotl/pkg/logger"
"github.com/codeshelldev/secured-signal-api/internals/config"
)
Expand All @@ -17,79 +20,157 @@ var Auth Middleware = Middleware{
Use: authHandler,
}

func authHandler(next http.Handler) http.Handler {
tokenKeys := maps.Keys(config.ENV.CONFIGS)
tokens := slices.Collect(tokenKeys)
type AuthMethod struct {
Name string
Authenticate func(req *http.Request, tokens []string) (bool, error)
}

if tokens == nil {
tokens = []string{}
}
var BearerAuth = AuthMethod {
Name: "Bearer",
Authenticate: func(req *http.Request, tokens []string) (bool, error) {
header := req.Header.Get("Authorization")

return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if len(tokens) <= 0 {
next.ServeHTTP(w, req)
return
headerParts := strings.SplitN(header, " ", 2)

if len(headerParts) != 2 {
return false, nil
}

authHeader := req.Header.Get("Authorization")
if strings.ToLower(headerParts[0]) == "bearer" {
if isValidToken(tokens, headerParts[1]) {
return true, nil
}

return false, errors.New("invalid Bearer token")
}

authQuery := req.URL.Query().Get("@authorization")
return false, nil
},
}

var authType authType = None
var BasicAuth = AuthMethod {
Name: "Basic",
Authenticate: func(req *http.Request, tokens []string) (bool, error) {
header := req.Header.Get("Authorization")

var authToken string
if strings.TrimSpace(header) == "" {
return false, nil
}

success := false
headerParts := strings.SplitN(header, " ", 2)

if authHeader != "" {
authBody := strings.Split(authHeader, " ")
if len(headerParts) != 2 {
return false, nil
}

authType = getAuthType(authBody[0])
authToken = authBody[1]
if strings.ToLower(headerParts[0]) == "basic" {
base64Bytes, err := base64.StdEncoding.DecodeString(headerParts[1])

switch authType {
case Bearer:
if isValidToken(tokens, authToken) {
success = true
}
case Basic:
basicAuthBody, err := base64.StdEncoding.DecodeString(authToken)
if err != nil {
log.Error("Could not decode Basic auth payload: ", err.Error())
return false, errors.New("invalid base64 in Basic auth")
}

if err != nil {
log.Error("Could not decode Basic Auth Payload: ", err.Error())
}
parts := strings.SplitN(string(base64Bytes), ":", 2)

basicAuth := string(basicAuthBody)
basicAuthParts := strings.Split(basicAuth, ":")
if len(parts) != 2 {
return false, errors.New("Basic auth must be user:password")
}

user := "api"
authToken = basicAuthParts[1]
user, password := parts[0], parts[1]

if basicAuthParts[0] == user && isValidToken(tokens, authToken) {
success = true
}
if strings.ToLower(user) == "api" && isValidToken(tokens, password) {
return true, nil
}

} else if authQuery != "" {
authType = Query
return false, errors.New("invalid user:password")
}

authToken = strings.TrimSpace(authQuery)
return false, nil
},
}

if isValidToken(tokens, authToken) {
success = true
var QueryAuth = AuthMethod {
Name: "Query",
Authenticate: func(req *http.Request, tokens []string) (bool, error) {
const authQuery = "@authorization"

modifiedQuery := req.URL.Query()
auth := req.URL.Query().Get(authQuery)

modifiedQuery.Del("@authorization")
if strings.TrimSpace(auth) == "" {
return false, nil
}

req.URL.RawQuery = modifiedQuery.Encode()
}
if isValidToken(tokens, auth) {
query := req.URL.Query()

query.Del(authQuery)

req.URL.RawQuery = query.Encode()

return true, nil
}

return false, errors.New("invalid Query token")
},
}

var PathAuth = AuthMethod {
Name: "Path",
Authenticate: func(req *http.Request, tokens []string) (bool, error) {
parts := strings.Split(req.URL.Path, "/")

if len(parts) == 0 {
return false, nil
}

unescaped, err := url.PathUnescape(parts[1])

if err != nil {
return false, nil
}

auth, exists := strings.CutPrefix(unescaped, "auth=")

if !exists {
return false, nil
}

if isValidToken(tokens, auth) {
return true, nil
}

return false, errors.New("invalid Path token")
},
}

func authHandler(next http.Handler) http.Handler {
tokenKeys := maps.Keys(config.ENV.CONFIGS)
tokens := slices.Collect(tokenKeys)

if tokens == nil {
tokens = []string{}
}

var authChain = NewAuthChain().
Use(BearerAuth).
Use(BasicAuth).
Use(QueryAuth).
Use(PathAuth)

return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if len(tokens) <= 0 {
next.ServeHTTP(w, req)
return
}

var authToken string

success, _ := authChain.Eval(req, tokens)

if !success {
w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"")

log.Warn("User failed ", string(authType), " Auth")
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
Expand All @@ -101,17 +182,41 @@ func authHandler(next http.Handler) http.Handler {
})
}

func getAuthType(str string) authType {
switch str {
case "Bearer":
return Bearer
case "Basic":
return Basic
default:
return None
}
}

func isValidToken(tokens []string, match string) bool {
return slices.Contains(tokens, match)
}

type AuthChain struct {
methods []AuthMethod
}

func NewAuthChain() *AuthChain {
return &AuthChain{}
}

func (chain *AuthChain) Use(method AuthMethod) *AuthChain {
chain.methods = append(chain.methods, method)

logger.Debug("Registered ", method.Name, " auth")

return chain
}

func (chain *AuthChain) Eval(req *http.Request, tokens []string) (bool, error) {
var err error
var success bool

for _, method := range chain.methods {
success, err = method.Authenticate(req, tokens)

if err != nil {
logger.Warn("User failed ", method.Name, " auth: ", err.Error())
}

if success {
return success, nil
}
}

return false, err
}
9 changes: 0 additions & 9 deletions internals/proxy/middlewares/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ type Context struct {
Next http.Handler
}

type authType string

const (
Bearer authType = "Bearer"
Basic authType = "Basic"
Query authType = "Query"
None authType = "None"
)

type contextKey string

const tokenKey contextKey = "token"
Expand Down
2 changes: 1 addition & 1 deletion internals/proxy/middlewares/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func NewChain() *Chain {
func (chain *Chain) Use(middleware Middleware) *Chain {
chain.middlewares = append(chain.middlewares, middleware)

logger.Debug("Registered ", middleware.Name)
logger.Debug("Registered ", middleware.Name, " middleware")

return chain
}
Expand Down