Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions data/defaults.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
api:
auth:
methods: [bearer, basic, body]

service:
port: 8880

logLevel: info
logLevel: info

settings:
message:
Expand Down
11 changes: 11 additions & 0 deletions internals/config/structure/structure.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ type SERVICE struct {
type API struct {
URL string `koanf:"url" env>aliases:".apiurl"`
TOKENS []string `koanf:"tokens" env>aliases:".apitokens,.apitoken" aliases:"token"`
AUTH AUTH `koanf:"auth"`
}

type AUTH struct {
METHODS []string `koanf:"methods" env>aliases:".authmethods"`
TOKENS []Token `koanf:"tokens" aliases:"token"`
}

type Token struct {
Set []string `koanf:"set"`
Methods []string `koanf:"methods"`
}

type SETTINGS struct {
Expand Down
15 changes: 13 additions & 2 deletions internals/config/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func NormalizeTokens() {
}

func InitTokens() {
apiTokens := DEFAULT.API.TOKENS
apiTokens := parseAuthTokens(*DEFAULT)

for _, token := range apiTokens {
ENV.CONFIGS[token] = DEFAULT
Expand Down Expand Up @@ -80,14 +80,25 @@ func parseTokenConfigs(configArray []structure.CONFIG) map[string]structure.CONF
configs := map[string]structure.CONFIG{}

for _, config := range configArray {
for _, token := range config.API.TOKENS {
tokens := parseAuthTokens(config)
for _, token := range tokens {
configs[token] = config
}
}

return configs
}

func parseAuthTokens(config structure.CONFIG) []string {
tokens := config.API.TOKENS

for _, token := range config.API.AUTH.TOKENS {
tokens = append(tokens, token.Set...)
}

return tokens
}

func getSchemeTagByPointer(config any, tag string, fieldPointer any) string {
v := reflect.ValueOf(config)
if v.Kind() == reflect.Pointer {
Expand Down
185 changes: 119 additions & 66 deletions internals/proxy/middlewares/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/codeshelldev/gotl/pkg/logger"
"github.com/codeshelldev/gotl/pkg/request"
"github.com/codeshelldev/secured-signal-api/internals/config"
"github.com/codeshelldev/secured-signal-api/internals/config/structure"
)

var Auth Middleware = Middleware{
Expand All @@ -22,59 +23,6 @@ var Auth Middleware = Middleware{
const tokenKey contextKey = "token"
const isAuthKey contextKey = "isAuthenticated"

func authHandler(next http.Handler) http.Handler {
var authChain = NewAuthChain().
Use(BearerAuth).
Use(BasicAuth).
Use(BodyAuth).
Use(QueryAuth).
Use(PathAuth)

return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
tokenKeys := maps.Keys(config.ENV.CONFIGS)
tokens := slices.Collect(tokenKeys)

if tokens == nil {
tokens = []string{}
}

if config.ENV.INSECURE || len(tokens) <= 0 {
next.ServeHTTP(w, req)
return
}

token, _ := authChain.Eval(w, req, tokens)

if token == "" {
onUnauthorized(w)

req = setContext(req, isAuthKey, false)
} else {
req = setContext(req, isAuthKey, true)
req = setContext(req, tokenKey, token)
}

next.ServeHTTP(w, req)
})
}

var InternalAuthRequirement Middleware = Middleware{
Name: "_Auth_Requirement",
Use: authRequirementHandler,
}

func authRequirementHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
isAuthenticated := getContext[bool](req, isAuthKey)

if !isAuthenticated {
return
}

next.ServeHTTP(w, req)
})
}

type AuthMethod struct {
Name string
Authenticate func(w http.ResponseWriter, req *http.Request, tokens []string) (string, error)
Expand Down Expand Up @@ -240,16 +188,6 @@ var PathAuth = AuthMethod{
},
}

func onUnauthorized(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"")

http.Error(w, "Unauthorized", http.StatusUnauthorized)
}

func isValidToken(tokens []string, match string) bool {
return slices.Contains(tokens, match)
}

type AuthChain struct {
methods []AuthMethod
}
Expand All @@ -266,7 +204,7 @@ func (chain *AuthChain) Use(method AuthMethod) *AuthChain {
return chain
}

func (chain *AuthChain) Eval(w http.ResponseWriter, req *http.Request, tokens []string) (string, error) {
func (chain *AuthChain) Eval(w http.ResponseWriter, req *http.Request, tokens []string) (AuthMethod, string, error) {
var err error
var token string

Expand All @@ -278,11 +216,126 @@ func (chain *AuthChain) Eval(w http.ResponseWriter, req *http.Request, tokens []
}

if token != "" {
return token, nil


return method, token, nil
}
}

logger.Warn("Client failed to provide any auth")

return "", err
return AuthMethod{}, "", err
}

func authHandler(next http.Handler) http.Handler {
var authChain = NewAuthChain().
Use(BearerAuth).
Use(BasicAuth).
Use(BodyAuth).
Use(QueryAuth).
Use(PathAuth)

return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
tokenKeys := maps.Keys(config.ENV.CONFIGS)
tokens := slices.Collect(tokenKeys)

if tokens == nil {
tokens = []string{}
}

if config.ENV.INSECURE || len(tokens) <= 0 {
next.ServeHTTP(w, req)
return
}

method, token, _ := authChain.Eval(w, req, tokens)

if token == "" {
onUnauthorized(w)

req = setContext(req, isAuthKey, false)
} else {
conf := getConfigWithoutDefault(token)

allowedMethods := conf.API.AUTH.METHODS

if allowedMethods == nil {
allowedMethods = getConfig("").API.AUTH.METHODS
}

if isAuthMethodAllowed(method, token, conf.API.TOKENS, conf.API.AUTH.METHODS, conf.API.AUTH.TOKENS) {
req = setContext(req, isAuthKey, true)
req = setContext(req, tokenKey, token)
} else {
logger.Warn("Client tried using disabled auth method: ", method.Name)

onUnauthorized(w)

req = setContext(req, isAuthKey, false)
}
}

next.ServeHTTP(w, req)
})
}

var InternalAuthRequirement Middleware = Middleware{
Name: "_Auth_Requirement",
Use: authRequirementHandler,
}

func authRequirementHandler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
isAuthenticated := getContext[bool](req, isAuthKey)

if !isAuthenticated {
return
}

next.ServeHTTP(w, req)
})
}

func onUnauthorized(w http.ResponseWriter) {
w.Header().Set("WWW-Authenticate", "Basic realm=\"Login Required\", Bearer realm=\"Access Token Required\"")

http.Error(w, "Unauthorized", http.StatusUnauthorized)
}

func isValidToken(tokens []string, match string) bool {
return slices.Contains(tokens, match)
}

type AuthToken struct {
Token string
Methods []string
}

func getTokenMethodMap(rawTokens []string, defaultMethods []string, tokenMethodSet []structure.Token) map[string][]string {
tokenMethodMap := map[string][]string{}

for _, token := range rawTokens {
tokenMethodMap[token] = defaultMethods
}

for _, set := range tokenMethodSet {
for _, token := range set.Set {
tokenMethodMap[token] = set.Methods
}
}

return tokenMethodMap
}

func isAuthMethodAllowed(method AuthMethod, token string, rawTokens []string, defaultMethods []string, tokenMethodSet []structure.Token) bool {
if (len(defaultMethods) == 0 || defaultMethods == nil) && (len(tokenMethodSet) == 0 || tokenMethodSet == nil) {
// default: allow all
return true
}

tokenMethodMap := getTokenMethodMap(rawTokens, defaultMethods, tokenMethodSet)

return slices.ContainsFunc(tokenMethodMap[token], func(try string) bool {
return strings.EqualFold(try, method.Name)
})
}