From cc357f35efdb664e0b1f752a5333ff1b9d10536a Mon Sep 17 00:00:00 2001 From: Stavros Date: Thu, 7 May 2026 19:14:05 +0300 Subject: [PATCH 01/22] feat: add new logger --- internal/utils/logger/logger.go | 157 ++++++++++++++++++++++++ internal/utils/logger/logger_test.go | 173 +++++++++++++++++++++++++++ 2 files changed, 330 insertions(+) create mode 100644 internal/utils/logger/logger.go create mode 100644 internal/utils/logger/logger_test.go diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go new file mode 100644 index 00000000..18d319fb --- /dev/null +++ b/internal/utils/logger/logger.go @@ -0,0 +1,157 @@ +package logger + +import ( + "io" + "os" + "strings" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/tinyauthapp/tinyauth/internal/model" +) + +type Logger struct { + HTTP zerolog.Logger + App zerolog.Logger + config model.LogConfig + base zerolog.Logger + audit zerolog.Logger + writer io.Writer +} + +func NewLogger() *Logger { + return &Logger{ + writer: os.Stderr, + config: model.LogConfig{ + Level: "error", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{ + Enabled: true, + }, + App: model.LogStreamConfig{ + Enabled: true, + }, + // No reason to enabled audit by default since it will be surpressed by the log level + }, + }, + } +} + +func (l *Logger) WithConfig(cfg model.LogConfig) *Logger { + l.config = cfg + return l +} + +func (l *Logger) WithSimpleConfig() *Logger { + l.config = model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + return l +} + +func (l *Logger) WithTestConfig() *Logger { + l.config = model.LogConfig{ + Level: "trace", + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, + }, + } + return l +} + +func (l *Logger) WithWriter(writer io.Writer) *Logger { + l.writer = writer + return l +} + +func (l *Logger) Init() { + base := log.With(). + Timestamp(). + Caller(). + Logger(). + Level(l.parseLogLevel(l.config.Level)).Output(l.writer) + + if !l.config.Json { + base = base.Output(zerolog.ConsoleWriter{ + Out: l.writer, + TimeFormat: time.RFC3339, + }) + } + + l.base = base + l.audit = l.createLogger("audit", l.config.Streams.Audit) + l.HTTP = l.createLogger("http", l.config.Streams.HTTP) + l.App = l.createLogger("app", l.config.Streams.App) +} + +func (l *Logger) parseLogLevel(level string) zerolog.Level { + if level == "" { + return zerolog.InfoLevel + } + parsed, err := zerolog.ParseLevel(strings.ToLower(level)) + if err != nil { + log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to error") + parsed = zerolog.ErrorLevel + } + return parsed +} + +func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerolog.Logger { + if !cfg.Enabled { + return zerolog.Nop() + } + sub := l.base.With().Str("stream", component).Logger() + if cfg.Level != "" { + sub = sub.Level(l.parseLogLevel(cfg.Level)) + } + return sub +} + +func (l *Logger) AuditLoginSuccess(username, provider, ip string) { + l.audit.Info(). + CallerSkipFrame(1). + Str("event", "login"). + Str("result", "success"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Send() +} + +func (l *Logger) AuditLoginFailure(username, provider, ip, reason string) { + l.audit.Warn(). + CallerSkipFrame(1). + Str("event", "login"). + Str("result", "failure"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Str("reason", reason). + Send() +} + +func (l *Logger) AuditLogout(username, provider, ip string) { + l.audit.Info(). + CallerSkipFrame(1). + Str("event", "logout"). + Str("result", "success"). + Str("username", username). + Str("provider", provider). + Str("ip", ip). + Send() +} + +// Used for testing +func (l *Logger) GetConfig() model.LogConfig { + return l.config +} diff --git a/internal/utils/logger/logger_test.go b/internal/utils/logger/logger_test.go new file mode 100644 index 00000000..66387a5f --- /dev/null +++ b/internal/utils/logger/logger_test.go @@ -0,0 +1,173 @@ +package logger_test + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" +) + +func TestLogger(t *testing.T) { + type testCase struct { + description string + run func(t *testing.T) + } + + tests := []testCase{ + { + description: "Should create a simple logger with the expected config", + run: func(t *testing.T) { + l := logger.NewLogger().WithSimpleConfig() + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + }) + }, + }, + { + description: "Should create a test logger with the expected config", + run: func(t *testing.T) { + l := logger.NewLogger().WithTestConfig() + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "trace", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: true}, + }, + }) + }, + }, + { + description: "Should create a logger with a custom config", + run: func(t *testing.T) { + customCfg := model.LogConfig{ + Level: "debug", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg) + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, customCfg) + }, + }, + { + description: "Default logger should use error type and log json", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + l := logger.NewLogger().WithWriter(&buf) + l.Init() + + cfg := l.GetConfig() + + assert.Equal(t, cfg, model.LogConfig{ + Level: "error", + Json: true, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + }) + + l.App.Error().Msg("test") + + var entry map[string]any + err := json.Unmarshal(buf.Bytes(), &entry) + require.NoError(t, err) + + assert.Equal(t, "test", entry["message"]) + assert.Equal(t, "app", entry["stream"]) + assert.Equal(t, "error", entry["level"]) + assert.NotEmpty(t, entry["time"]) + }, + }, + { + description: "Should default to error level if an invalid level is provided", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + customCfg := model.LogConfig{ + Level: "invalid", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: true}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf) + l.Init() + + assert.Equal(t, zerolog.ErrorLevel, l.App.GetLevel()) + assert.Equal(t, zerolog.ErrorLevel, l.HTTP.GetLevel()) + + // should not get logged + l.AuditLoginFailure("test", "test", "test", "test") + + assert.Empty(t, buf.String()) + }, + }, + { + description: "Should use nop logger for disabled streams", + run: func(t *testing.T) { + buf := bytes.Buffer{} + + customCfg := model.LogConfig{ + Level: "info", + Json: false, + Streams: model.LogStreams{ + HTTP: model.LogStreamConfig{Enabled: false}, + App: model.LogStreamConfig{Enabled: true}, + Audit: model.LogStreamConfig{Enabled: false}, + }, + } + + l := logger.NewLogger().WithConfig(customCfg).WithWriter(&buf) + l.Init() + + assert.Equal(t, zerolog.Disabled, l.HTTP.GetLevel()) + + l.App.Info().Msg("test") + + l.AuditLoginFailure("test", "test", "test", "test") + + assert.NotEmpty(t, buf.String()) + assert.Equal(t, 119, buf.Len()) // it's the length of the test log entry + }, + }, + } + + for _, test := range tests { + t.Run(test.description, test.run) + } +} From 592c221b2dcc8cde5410a8e822df49aa0ad32ac1 Mon Sep 17 00:00:00 2001 From: Stavros Date: Thu, 7 May 2026 22:31:51 +0300 Subject: [PATCH 02/22] refactor: use one struct for context handling and cancellation --- cmd/tinyauth/tinyauth.go | 6 - internal/bootstrap/app_bootstrap.go | 372 +++++++++++++------- internal/bootstrap/db_bootstrap.go | 21 +- internal/bootstrap/router_bootstrap.go | 33 +- internal/bootstrap/service_bootstrap.go | 84 ++--- internal/service/access_controls_service.go | 9 +- internal/utils/logger/logger.go | 4 +- 7 files changed, 323 insertions(+), 206 deletions(-) diff --git a/cmd/tinyauth/tinyauth.go b/cmd/tinyauth/tinyauth.go index f5bbb19f..b6293718 100644 --- a/cmd/tinyauth/tinyauth.go +++ b/cmd/tinyauth/tinyauth.go @@ -7,7 +7,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/loaders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/rs/zerolog/log" "github.com/tinyauthapp/paerser/cli" @@ -109,11 +108,6 @@ func main() { } func runCmd(cfg model.Config) error { - logger := tlog.NewLogger(cfg.Log) - logger.Init() - - tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth") - app := bootstrap.NewBootstrapApp(cfg) err := app.Setup() diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 5b342c48..5b10b192 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -3,98 +3,137 @@ package bootstrap import ( "bytes" "context" + "database/sql" "encoding/json" + "errors" "fmt" + "net" "net/http" "net/url" "os" + "os/signal" "sort" "strings" + "syscall" "time" + "github.com/gin-gonic/gin" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type BootstrapApp struct { - config model.Config - context struct { - appUrl string - uuid string - cookieDomain string - sessionCookieName string - csrfCookieName string - redirectCookieName string - oauthSessionCookieName string - localUsers *[]model.LocalUser - oauthProviders map[string]model.OAuthServiceConfig - oauthWhitelist []string - configuredProviders []controller.Provider - oidcClients []model.OIDCClientConfig - } +type Services struct { + accessControlService *service.AccessControlsService + authService *service.AuthService + dockerService *service.DockerService + kubernetesService *service.KubernetesService + ldapService *service.LdapService + oauthBrokerService *service.OAuthBrokerService + oidcService *service.OIDCService +} + +type RuntimeConfig struct { + appUrl string + uuid string + cookieDomain string + sessionCookieName string + csrfCookieName string + redirectCookieName string + oauthSessionCookieName string + localUsers []model.LocalUser + oauthProviders map[string]model.OAuthServiceConfig + oauthWhitelist []string + configuredProviders []controller.Provider + oidcClients []model.OIDCClientConfig + labelProvider service.LabelProvider +} + +type App struct { + config model.Config + runtime RuntimeConfig services Services + log *logger.Logger + ctx context.Context + cancel context.CancelFunc + queries *repository.Queries + router *gin.Engine + db *sql.DB } -func NewBootstrapApp(config model.Config) *BootstrapApp { - return &BootstrapApp{ +func NewBootstrapApp(config model.Config) *App { + return &App{ config: config, } } -func (app *BootstrapApp) Setup() error { +func (app *App) Setup() error { + // create context + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + app.ctx = ctx + app.cancel = cancel + + // setup logger + log := logger.NewLogger().WithConfig(app.config.Log) + log.Init() + app.log = log + // get app url if app.config.AppURL == "" { - return fmt.Errorf("app URL cannot be empty, perhaps config loading failed") + return errors.New("app url cannot be empty, perhaps config loading failed") } appUrl, err := url.Parse(app.config.AppURL) if err != nil { - return err + return fmt.Errorf("failed to parse app url: %w", err) } - app.context.appUrl = appUrl.Scheme + "://" + appUrl.Host + app.runtime.appUrl = appUrl.Scheme + "://" + appUrl.Host // validate session config if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { - return fmt.Errorf("session max lifetime cannot be less than session expiry") + return errors.New("session max lifetime cannot be less than session expiry") } - // Parse users + // parse users users, err := utils.GetUsers(app.config.Auth.Users, app.config.Auth.UsersFile, app.config.Auth.UserAttributes) if err != nil { - return err + return fmt.Errorf("failed to load users: %w", err) } - app.context.localUsers = users + app.runtime.localUsers = *users + // load oauth whitelist oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) + if err != nil { - return err + return fmt.Errorf("failed to load oauth whitelist: %w", err) } - app.context.oauthWhitelist = oauthWhitelist + app.runtime.oauthWhitelist = oauthWhitelist - // Setup OAuth providers - app.context.oauthProviders = app.config.OAuth.Providers + // Setup oauth providers + app.runtime.oauthProviders = app.config.OAuth.Providers - for name, provider := range app.context.oauthProviders { + for id, provider := range app.runtime.oauthProviders { secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) provider.ClientSecret = secret provider.ClientSecretFile = "" if provider.RedirectURL == "" { - provider.RedirectURL = app.context.appUrl + "/api/oauth/callback/" + name + provider.RedirectURL = app.runtime.appUrl + "/api/oauth/callback/" + id } - app.context.oauthProviders[name] = provider + app.runtime.oauthProviders[id] = provider } - for id, provider := range app.context.oauthProviders { + // set presets for built-in providers + for id, provider := range app.runtime.oauthProviders { if provider.Name == "" { if name, ok := model.OverrideProviders[id]; ok { provider.Name = name @@ -102,70 +141,63 @@ func (app *BootstrapApp) Setup() error { provider.Name = utils.Capitalize(id) } } - app.context.oauthProviders[id] = provider + app.runtime.oauthProviders[id] = provider } - // Setup OIDC clients + // setup oidc clients for id, client := range app.config.OIDC.Clients { client.ID = id - app.context.oidcClients = append(app.context.oidcClients, client) + app.runtime.oidcClients = append(app.runtime.oidcClients, client) } - // Get cookie domain + // cookie domain cookieDomainResolver := utils.GetCookieDomain + if !app.config.Auth.SubdomainsEnabled { - tlog.App.Info().Msg("Subdomains disabled, automatic authentication for proxied apps will not work") + app.log.App.Warn().Msg("Subdomains are disabled, using standalone cookie domain resolver which will not work with subdomains") cookieDomainResolver = utils.GetStandaloneCookieDomain } - cookieDomain, err := cookieDomainResolver(app.context.appUrl) + cookieDomain, err := cookieDomainResolver(app.runtime.appUrl) if err != nil { - return err + return fmt.Errorf("failed to get cookie domain: %w", err) } - app.context.cookieDomain = cookieDomain + app.runtime.cookieDomain = cookieDomain - // Cookie names - app.context.uuid = utils.GenerateUUID(appUrl.Hostname()) - cookieId := strings.Split(app.context.uuid, "-")[0] - app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) - app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) - app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) - app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + // cookie names + app.runtime.uuid = utils.GenerateUUID(appUrl.Hostname()) - // Dumps - tlog.App.Trace().Interface("config", app.config).Msg("Config dump") - tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump") - tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump") - tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain") - tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name") - tlog.App.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name") - tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name") + cookieId := strings.Split(app.runtime.uuid, "-")[0] // first 8 characters of the uuid should be good enough - // Database - db, err := app.SetupDatabase(app.config.Database.Path) + app.runtime.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) + app.runtime.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) + app.runtime.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) + app.runtime.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + + // database + err = app.SetupDatabase() if err != nil { return fmt.Errorf("failed to setup database: %w", err) } - // Queries - queries := repository.New(db) + // queries + queries := repository.New(app.db) + app.queries = queries - // Services - services, err := app.initServices(queries) + // services + err = app.setupServices() if err != nil { return fmt.Errorf("failed to initialize services: %w", err) } - app.services = services - - // Configured providers + // configured providers configuredProviders := make([]controller.Provider, 0) - for id, provider := range app.context.oauthProviders { + for id, provider := range app.runtime.oauthProviders { configuredProviders = append(configuredProviders, controller.Provider{ Name: provider.Name, ID: id, @@ -177,7 +209,7 @@ func (app *BootstrapApp) Setup() error { return configuredProviders[i].Name < configuredProviders[j].Name }) - if services.authService.LocalAuthConfigured() { + if app.services.authService.LocalAuthConfigured() { configuredProviders = append(configuredProviders, controller.Provider{ Name: "Local", ID: "local", @@ -185,7 +217,7 @@ func (app *BootstrapApp) Setup() error { }) } - if services.authService.LDAPAuthConfigured() { + if app.services.authService.LDAPAuthConfigured() { configuredProviders = append(configuredProviders, controller.Provider{ Name: "LDAP", ID: "ldap", @@ -193,77 +225,150 @@ func (app *BootstrapApp) Setup() error { }) } - tlog.App.Debug().Interface("providers", configuredProviders).Msg("Authentication providers") - if len(configuredProviders) == 0 { - return fmt.Errorf("no authentication providers configured") + return errors.New("no authentication providers configured") + } + + for _, provider := range app.runtime.configuredProviders { + app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider") } - app.context.configuredProviders = configuredProviders + app.runtime.configuredProviders = configuredProviders - // Setup router - router, err := app.setupRouter() + // setup router + err = app.setupRouter() if err != nil { return fmt.Errorf("failed to setup routes: %w", err) } - // Start db cleanup routine - tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanupRoutine(queries) + // start db cleanup routine + app.log.App.Debug().Msg("Starting database cleanup routine") + go app.dbCleanupRoutine() - // If analytics are not disabled, start heartbeat + // if analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { - tlog.App.Debug().Msg("Starting heartbeat routine") + app.log.App.Debug().Msg("Starting heartbeat routine") go app.heartbeatRoutine() } - // If we have an socket path, bind to it - if app.config.Server.SocketPath != "" { - if _, err := os.Stat(app.config.Server.SocketPath); err == nil { - tlog.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) - err := os.Remove(app.config.Server.SocketPath) - if err != nil { - return fmt.Errorf("failed to remove existing socket file: %w", err) - } - } + // create err channel to listen for server errors + errChan := make(chan error, 1) - tlog.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) - if err := router.RunUnix(app.config.Server.SocketPath); err != nil { - tlog.App.Fatal().Err(err).Msg("Failed to start server") - } + // serve unix + go func() { + errChan <- app.serveUnix() + }() + + // serve to http + go func() { + errChan <- app.serveHTTP() + }() + // monitor cancellation and server errors + select { + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Shutting down application") return nil + case err := <-errChan: + if err != nil { + return fmt.Errorf("server error: %w", err) + } } - // Start server + return nil +} + +func (app *App) serveHTTP() error { address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) - tlog.App.Info().Msgf("Starting server on %s", address) - if err := router.Run(address); err != nil { - tlog.App.Fatal().Err(err).Msg("Failed to start server") + + app.log.App.Info().Msgf("Starting server on %s", address) + + server := &http.Server{ + Addr: address, + Handler: app.router.Handler(), + } + + go func() { + <-app.ctx.Done() + app.log.App.Debug().Msg("Shutting down server") + server.Close() + }() + + err := server.ListenAndServe() + + if err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("failed to start server: %w", err) + } + + return nil +} + +func (app *App) serveUnix() error { + if app.config.Server.SocketPath == "" { + return nil + } + + _, err := os.Stat(app.config.Server.SocketPath) + + if err == nil { + app.log.App.Info().Msgf("Removing existing socket file %s", app.config.Server.SocketPath) + err := os.Remove(app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to remove existing socket file: %w", err) + } + } + + app.log.App.Info().Msgf("Starting server on unix socket %s", app.config.Server.SocketPath) + + listener, err := net.Listen("unix", app.config.Server.SocketPath) + + if err != nil { + return fmt.Errorf("failed to create unix socket listner: %w", err) + } + + defer listener.Close() + defer os.Remove(app.config.Server.SocketPath) + + go func() { + <-app.ctx.Done() + app.log.App.Debug().Msg("Shutting down server") + listener.Close() + os.Remove(app.config.Server.SocketPath) + }() + + server := &http.Server{ + Handler: app.router.Handler(), + } + + err = server.Serve(listener) + + if err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("failed to start server: %w", err) } return nil } -func (app *BootstrapApp) heartbeatRoutine() { +func (app *App) heartbeatRoutine() { ticker := time.NewTicker(time.Duration(12) * time.Hour) defer ticker.Stop() - type heartbeat struct { + type Heartbeat struct { UUID string `json:"uuid"` Version string `json:"version"` } - var body heartbeat + var body Heartbeat - body.UUID = app.context.uuid + body.UUID = app.runtime.uuid body.Version = model.Version bodyJson, err := json.Marshal(body) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to marshal heartbeat body") + app.log.App.Error().Err(err).Msg("Failed to marshal heartbeat body, heartbeat routine will not start") return } @@ -273,43 +378,58 @@ func (app *BootstrapApp) heartbeatRoutine() { heartbeatURL := model.APIServer + "/v1/instances/heartbeat" - for range ticker.C { - tlog.App.Debug().Msg("Sending heartbeat") + for { + select { + case <-ticker.C: + app.log.App.Debug().Msg("Sending heartbeat") - req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) + req, err := http.NewRequest(http.MethodPost, heartbeatURL, bytes.NewReader(bodyJson)) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create heartbeat request") - continue - } + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to create heartbeat request") + continue + } - req.Header.Add("Content-Type", "application/json") + req.Header.Add("Content-Type", "application/json") - res, err := client.Do(req) + res, err := client.Do(req) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to send heartbeat") - continue - } + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to send heartbeat") + continue + } - res.Body.Close() + res.Body.Close() - if res.StatusCode != 200 && res.StatusCode != 201 { - tlog.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") + if res.StatusCode != 200 && res.StatusCode != 201 { + app.log.App.Debug().Str("status", res.Status).Msg("Heartbeat returned non-200/201 status") + } + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Stopping heartbeat routine") + ticker.Stop() + return } } } -func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { +func (app *App) dbCleanupRoutine() { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() - ctx := context.Background() - for range ticker.C { - tlog.App.Debug().Msg("Cleaning up old database sessions") - err := queries.DeleteExpiredSessions(ctx, time.Now().Unix()) - if err != nil { - tlog.App.Error().Err(err).Msg("Failed to clean up old database sessions") + for { + select { + case <-ticker.C: + app.log.App.Debug().Msg("Running database cleanup") + + err := app.queries.DeleteExpiredSessions(app.ctx, time.Now().Unix()) + + if err != nil { + app.log.App.Error().Err(err).Msg("Failed to delete expired sessions") + } + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Stopping database cleanup routine") + ticker.Stop() + return } } } diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 3f48f793..5ef5c9dc 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -14,17 +14,17 @@ import ( _ "modernc.org/sqlite" ) -func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { - dir := filepath.Dir(databasePath) +func (app *App) SetupDatabase() error { + dir := filepath.Dir(app.config.Database.Path) if err := os.MkdirAll(dir, 0750); err != nil { - return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err) + return fmt.Errorf("failed to create database directory %s: %w", dir, err) } - db, err := sql.Open("sqlite", databasePath) + db, err := sql.Open("sqlite", app.config.Database.Path) if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + return fmt.Errorf("failed to open database: %w", err) } // Limit to 1 connection to sequence writes, this may need to be revisited in the future @@ -34,24 +34,25 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { migrations, err := iofs.New(assets.Migrations, "migrations") if err != nil { - return nil, fmt.Errorf("failed to create migrations: %w", err) + return fmt.Errorf("failed to create migrations: %w", err) } target, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { - return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err) + return fmt.Errorf("failed to create sqlite3 instance: %w", err) } migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target) if err != nil { - return nil, fmt.Errorf("failed to create migrator: %w", err) + return fmt.Errorf("failed to create migrator: %w", err) } if err := migrator.Up(); err != nil && err != migrate.ErrNoChange { - return nil, fmt.Errorf("failed to migrate database: %w", err) + return fmt.Errorf("failed to migrate database: %w", err) } - return db, nil + app.db = db + return nil } diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index a746be79..7310fa43 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -13,7 +13,7 @@ import ( var DEV_MODES = []string{"main", "test", "development"} -func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { +func (app *App) setupRouter() error { if !slices.Contains(DEV_MODES, model.Version) { gin.SetMode(gin.ReleaseMode) } @@ -25,19 +25,19 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { err := engine.SetTrustedProxies(app.config.Auth.TrustedProxies) if err != nil { - return nil, fmt.Errorf("failed to set trusted proxies: %w", err) + return fmt.Errorf("failed to set trusted proxies: %w", err) } } contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - CookieDomain: app.context.cookieDomain, - SessionCookieName: app.context.sessionCookieName, + CookieDomain: app.runtime.cookieDomain, + SessionCookieName: app.runtime.sessionCookieName, }, app.services.authService, app.services.oauthBrokerService) err := contextMiddleware.Init() if err != nil { - return nil, fmt.Errorf("failed to initialize context middleware: %w", err) + return fmt.Errorf("failed to initialize context middleware: %w", err) } engine.Use(contextMiddleware.Middleware()) @@ -47,7 +47,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { err = uiMiddleware.Init() if err != nil { - return nil, fmt.Errorf("failed to initialize UI middleware: %w", err) + return fmt.Errorf("failed to initialize UI middleware: %w", err) } engine.Use(uiMiddleware.Middleware()) @@ -57,7 +57,7 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { err = zerologMiddleware.Init() if err != nil { - return nil, fmt.Errorf("failed to initialize zerolog middleware: %w", err) + return fmt.Errorf("failed to initialize zerolog middleware: %w", err) } engine.Use(zerologMiddleware.Middleware()) @@ -65,10 +65,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { apiRouter := engine.Group("/api") contextController := controller.NewContextController(controller.ContextControllerConfig{ - Providers: app.context.configuredProviders, + Providers: app.runtime.configuredProviders, Title: app.config.UI.Title, AppURL: app.config.AppURL, - CookieDomain: app.context.cookieDomain, + CookieDomain: app.runtime.cookieDomain, ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage, BackgroundImage: app.config.UI.BackgroundImage, OAuthAutoRedirect: app.config.OAuth.AutoRedirect, @@ -80,10 +80,10 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ AppURL: app.config.AppURL, SecureCookie: app.config.Auth.SecureCookie, - CSRFCookieName: app.context.csrfCookieName, - RedirectCookieName: app.context.redirectCookieName, - CookieDomain: app.context.cookieDomain, - OAuthSessionCookieName: app.context.oauthSessionCookieName, + CSRFCookieName: app.runtime.csrfCookieName, + RedirectCookieName: app.runtime.redirectCookieName, + CookieDomain: app.runtime.cookieDomain, + OAuthSessionCookieName: app.runtime.oauthSessionCookieName, SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, }, apiRouter, app.services.authService) @@ -100,8 +100,8 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { proxyController.SetupRoutes() userController := controller.NewUserController(controller.UserControllerConfig{ - CookieDomain: app.context.cookieDomain, - SessionCookieName: app.context.sessionCookieName, + CookieDomain: app.runtime.cookieDomain, + SessionCookieName: app.runtime.sessionCookieName, }, apiRouter, app.services.authService) userController.SetupRoutes() @@ -121,5 +121,6 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { wellknownController.SetupRoutes() - return engine, nil + app.router = engine + return nil } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 09485bd0..b3261180 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -1,26 +1,14 @@ package bootstrap import ( + "fmt" "os" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) -type Services struct { - accessControlService *service.AccessControlsService - authService *service.AuthService - dockerService *service.DockerService - kubernetesService *service.KubernetesService - ldapService *service.LdapService - oauthBrokerService *service.OAuthBrokerService - oidcService *service.OIDCService -} - -func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { - services := Services{} - +func (app *App) setupServices() error { ldapService := service.NewLdapService(service.LdapServiceConfig{ Address: app.config.LDAP.Address, BindDN: app.config.LDAP.BindDN, @@ -35,81 +23,85 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er err := ldapService.Init() if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to setup LDAP service, starting without it") + app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") ldapService.Unconfigure() } - services.ldapService = ldapService - - var labelProvider service.LabelProvider - var dockerService *service.DockerService - var kubernetesService *service.KubernetesService + app.services.ldapService = ldapService useKubernetes := app.config.LabelProvider == "kubernetes" || (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") if useKubernetes { - tlog.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService = service.NewKubernetesService() + app.log.App.Debug().Msg("Using Kubernetes label provider") + + kubernetesService := service.NewKubernetesService() + err = kubernetesService.Init() + if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize kubernetes service: %w", err) } - services.kubernetesService = kubernetesService - labelProvider = kubernetesService + + app.services.kubernetesService = kubernetesService + app.runtime.labelProvider = service.LabelProviderKubernetes } else { tlog.App.Debug().Msg("Using Docker label provider") - dockerService = service.NewDockerService() + + dockerService := service.NewDockerService() + err = dockerService.Init() + if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize docker service: %w", err) } - services.dockerService = dockerService - labelProvider = dockerService + + app.services.dockerService = dockerService + app.runtime.labelProvider = service.LabelProviderDocker } - accessControlsService := service.NewAccessControlsService(labelProvider, app.config.Apps) + accessControlsService := service.NewAccessControlsService(app.runtime.labelProvider, app.config.Apps) err = accessControlsService.Init() if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize access controls service: %w", err) } - services.accessControlService = accessControlsService + app.services.accessControlService = accessControlsService - oauthBrokerService := service.NewOAuthBrokerService(app.context.oauthProviders) + oauthBrokerService := service.NewOAuthBrokerService(app.runtime.oauthProviders) err = oauthBrokerService.Init() if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize oauth broker service: %w", err) } - services.oauthBrokerService = oauthBrokerService + app.services.oauthBrokerService = oauthBrokerService authService := service.NewAuthService(service.AuthServiceConfig{ - LocalUsers: app.context.localUsers, - OauthWhitelist: app.context.oauthWhitelist, + LocalUsers: &app.runtime.localUsers, + OauthWhitelist: app.runtime.oauthWhitelist, SessionExpiry: app.config.Auth.SessionExpiry, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, SecureCookie: app.config.Auth.SecureCookie, - CookieDomain: app.context.cookieDomain, + CookieDomain: app.runtime.cookieDomain, LoginTimeout: app.config.Auth.LoginTimeout, LoginMaxRetries: app.config.Auth.LoginMaxRetries, - SessionCookieName: app.context.sessionCookieName, + SessionCookieName: app.runtime.sessionCookieName, IP: app.config.Auth.IP, LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, - }, services.ldapService, queries, services.oauthBrokerService) + }, app.services.ldapService, app.queries, app.services.oauthBrokerService) err = authService.Init() if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize auth service: %w", err) } - services.authService = authService + app.services.authService = authService oidcService := service.NewOIDCService(service.OIDCServiceConfig{ Clients: app.config.OIDC.Clients, @@ -117,15 +109,15 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er PublicKeyPath: app.config.OIDC.PublicKeyPath, Issuer: app.config.AppURL, SessionExpiry: app.config.Auth.SessionExpiry, - }, queries) + }, app.queries) err = oidcService.Init() if err != nil { - return Services{}, err + return fmt.Errorf("failed to initialize oidc service: %w", err) } - services.oidcService = oidcService + app.services.oidcService = oidcService - return services, nil + return nil } diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index fd57bf39..c16c5a25 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -7,7 +7,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) -type LabelProvider interface { +type LabelProvider int + +const ( + LabelProviderDocker LabelProvider = iota + LabelProviderKubernetes +) + +type LabelProviderImpl interface { GetLabels(appDomain string) (*model.App, error) } diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go index 18d319fb..d85af79e 100644 --- a/internal/utils/logger/logger.go +++ b/internal/utils/logger/logger.go @@ -77,7 +77,6 @@ func (l *Logger) WithWriter(writer io.Writer) *Logger { func (l *Logger) Init() { base := log.With(). Timestamp(). - Caller(). Logger(). Level(l.parseLogLevel(l.config.Level)).Output(l.writer) @@ -114,6 +113,9 @@ func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerol if cfg.Level != "" { sub = sub.Level(l.parseLogLevel(cfg.Level)) } + if sub.GetLevel() == zerolog.DebugLevel { + sub = sub.With().Caller().Logger() + } return sub } From 112a30f6b2f636231c655d3d68c90a5631aebd75 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 8 May 2026 16:39:01 +0300 Subject: [PATCH 03/22] refactor: rework logging and config in controllers --- internal/bootstrap/app_bootstrap.go | 87 ++++++-------- internal/bootstrap/db_bootstrap.go | 2 +- internal/bootstrap/router_bootstrap.go | 45 ++----- internal/bootstrap/service_bootstrap.go | 22 ++-- internal/controller/context_controller.go | 81 ++++++------- internal/controller/oauth_controller.go | 103 ++++++++-------- internal/controller/oidc_controller.go | 79 ++++++------ internal/controller/proxy_controller.go | 85 +++++++------ internal/controller/resources_controller.go | 19 ++- internal/controller/user_controller.go | 120 ++++++++++--------- internal/controller/well_known_controller.go | 14 +-- internal/model/runtime.go | 30 +++++ internal/service/access_controls_service.go | 7 -- internal/utils/tlog/log_audit.go | 39 ------ internal/utils/tlog/log_wrapper.go | 97 --------------- internal/utils/tlog/log_wrapper_test.go | 93 -------------- 16 files changed, 335 insertions(+), 588 deletions(-) create mode 100644 internal/model/runtime.go delete mode 100644 internal/utils/tlog/log_audit.go delete mode 100644 internal/utils/tlog/log_wrapper.go delete mode 100644 internal/utils/tlog/log_wrapper_test.go diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 5b10b192..268d0d30 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -18,7 +18,6 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" @@ -36,25 +35,9 @@ type Services struct { oidcService *service.OIDCService } -type RuntimeConfig struct { - appUrl string - uuid string - cookieDomain string - sessionCookieName string - csrfCookieName string - redirectCookieName string - oauthSessionCookieName string - localUsers []model.LocalUser - oauthProviders map[string]model.OAuthServiceConfig - oauthWhitelist []string - configuredProviders []controller.Provider - oidcClients []model.OIDCClientConfig - labelProvider service.LabelProvider -} - -type App struct { +type BootstrapApp struct { config model.Config - runtime RuntimeConfig + runtime model.RuntimeConfig services Services log *logger.Logger ctx context.Context @@ -64,13 +47,13 @@ type App struct { db *sql.DB } -func NewBootstrapApp(config model.Config) *App { - return &App{ +func NewBootstrapApp(config model.Config) *BootstrapApp { + return &BootstrapApp{ config: config, } } -func (app *App) Setup() error { +func (app *BootstrapApp) Setup() error { // create context ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) app.ctx = ctx @@ -92,7 +75,7 @@ func (app *App) Setup() error { return fmt.Errorf("failed to parse app url: %w", err) } - app.runtime.appUrl = appUrl.Scheme + "://" + appUrl.Host + app.runtime.AppURL = appUrl.Scheme + "://" + appUrl.Host // validate session config if app.config.Auth.SessionMaxLifetime != 0 && app.config.Auth.SessionMaxLifetime < app.config.Auth.SessionExpiry { @@ -106,7 +89,7 @@ func (app *App) Setup() error { return fmt.Errorf("failed to load users: %w", err) } - app.runtime.localUsers = *users + app.runtime.LocalUsers = *users // load oauth whitelist oauthWhitelist, err := utils.GetStringList(app.config.OAuth.Whitelist, app.config.OAuth.WhitelistFile) @@ -115,25 +98,25 @@ func (app *App) Setup() error { return fmt.Errorf("failed to load oauth whitelist: %w", err) } - app.runtime.oauthWhitelist = oauthWhitelist + app.runtime.OAuthWhitelist = oauthWhitelist // Setup oauth providers - app.runtime.oauthProviders = app.config.OAuth.Providers + app.runtime.OAuthProviders = app.config.OAuth.Providers - for id, provider := range app.runtime.oauthProviders { + for id, provider := range app.runtime.OAuthProviders { secret := utils.GetSecret(provider.ClientSecret, provider.ClientSecretFile) provider.ClientSecret = secret provider.ClientSecretFile = "" if provider.RedirectURL == "" { - provider.RedirectURL = app.runtime.appUrl + "/api/oauth/callback/" + id + provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id } - app.runtime.oauthProviders[id] = provider + app.runtime.OAuthProviders[id] = provider } // set presets for built-in providers - for id, provider := range app.runtime.oauthProviders { + for id, provider := range app.runtime.OAuthProviders { if provider.Name == "" { if name, ok := model.OverrideProviders[id]; ok { provider.Name = name @@ -141,13 +124,13 @@ func (app *App) Setup() error { provider.Name = utils.Capitalize(id) } } - app.runtime.oauthProviders[id] = provider + app.runtime.OAuthProviders[id] = provider } // setup oidc clients for id, client := range app.config.OIDC.Clients { client.ID = id - app.runtime.oidcClients = append(app.runtime.oidcClients, client) + app.runtime.OIDCClients = append(app.runtime.OIDCClients, client) } // cookie domain @@ -158,23 +141,23 @@ func (app *App) Setup() error { cookieDomainResolver = utils.GetStandaloneCookieDomain } - cookieDomain, err := cookieDomainResolver(app.runtime.appUrl) + cookieDomain, err := cookieDomainResolver(app.runtime.AppURL) if err != nil { return fmt.Errorf("failed to get cookie domain: %w", err) } - app.runtime.cookieDomain = cookieDomain + app.runtime.CookieDomain = cookieDomain // cookie names - app.runtime.uuid = utils.GenerateUUID(appUrl.Hostname()) + app.runtime.UUID = utils.GenerateUUID(appUrl.Hostname()) - cookieId := strings.Split(app.runtime.uuid, "-")[0] // first 8 characters of the uuid should be good enough + cookieId := strings.Split(app.runtime.UUID, "-")[0] // first 8 characters of the uuid should be good enough - app.runtime.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) - app.runtime.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) - app.runtime.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) - app.runtime.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) + app.runtime.SessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) + app.runtime.CSRFCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) + app.runtime.RedirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) + app.runtime.OAuthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) // database err = app.SetupDatabase() @@ -195,10 +178,10 @@ func (app *App) Setup() error { } // configured providers - configuredProviders := make([]controller.Provider, 0) + configuredProviders := make([]model.Provider, 0) - for id, provider := range app.runtime.oauthProviders { - configuredProviders = append(configuredProviders, controller.Provider{ + for id, provider := range app.runtime.OAuthProviders { + configuredProviders = append(configuredProviders, model.Provider{ Name: provider.Name, ID: id, OAuth: true, @@ -210,7 +193,7 @@ func (app *App) Setup() error { }) if app.services.authService.LocalAuthConfigured() { - configuredProviders = append(configuredProviders, controller.Provider{ + configuredProviders = append(configuredProviders, model.Provider{ Name: "Local", ID: "local", OAuth: false, @@ -218,7 +201,7 @@ func (app *App) Setup() error { } if app.services.authService.LDAPAuthConfigured() { - configuredProviders = append(configuredProviders, controller.Provider{ + configuredProviders = append(configuredProviders, model.Provider{ Name: "LDAP", ID: "ldap", OAuth: false, @@ -229,11 +212,11 @@ func (app *App) Setup() error { return errors.New("no authentication providers configured") } - for _, provider := range app.runtime.configuredProviders { + for _, provider := range app.runtime.ConfiguredProviders { app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider") } - app.runtime.configuredProviders = configuredProviders + app.runtime.ConfiguredProviders = configuredProviders // setup router err = app.setupRouter() @@ -279,7 +262,7 @@ func (app *App) Setup() error { return nil } -func (app *App) serveHTTP() error { +func (app *BootstrapApp) serveHTTP() error { address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) app.log.App.Info().Msgf("Starting server on %s", address) @@ -304,7 +287,7 @@ func (app *App) serveHTTP() error { return nil } -func (app *App) serveUnix() error { +func (app *BootstrapApp) serveUnix() error { if app.config.Server.SocketPath == "" { return nil } @@ -351,7 +334,7 @@ func (app *App) serveUnix() error { return nil } -func (app *App) heartbeatRoutine() { +func (app *BootstrapApp) heartbeatRoutine() { ticker := time.NewTicker(time.Duration(12) * time.Hour) defer ticker.Stop() @@ -362,7 +345,7 @@ func (app *App) heartbeatRoutine() { var body Heartbeat - body.UUID = app.runtime.uuid + body.UUID = app.runtime.UUID body.Version = model.Version bodyJson, err := json.Marshal(body) @@ -412,7 +395,7 @@ func (app *App) heartbeatRoutine() { } } -func (app *App) dbCleanupRoutine() { +func (app *BootstrapApp) dbCleanupRoutine() { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 5ef5c9dc..4644036b 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -14,7 +14,7 @@ import ( _ "modernc.org/sqlite" ) -func (app *App) SetupDatabase() error { +func (app *BootstrapApp) SetupDatabase() error { dir := filepath.Dir(app.config.Database.Path) if err := os.MkdirAll(dir, 0750); err != nil { diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 7310fa43..2250fb19 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -13,7 +13,7 @@ import ( var DEV_MODES = []string{"main", "test", "development"} -func (app *App) setupRouter() error { +func (app *BootstrapApp) setupRouter() error { if !slices.Contains(DEV_MODES, model.Version) { gin.SetMode(gin.ReleaseMode) } @@ -30,8 +30,8 @@ func (app *App) setupRouter() error { } contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - CookieDomain: app.runtime.cookieDomain, - SessionCookieName: app.runtime.sessionCookieName, + CookieDomain: app.runtime.CookieDomain, + SessionCookieName: app.runtime.SessionCookieName, }, app.services.authService, app.services.oauthBrokerService) err := contextMiddleware.Init() @@ -64,52 +64,27 @@ func (app *App) setupRouter() error { apiRouter := engine.Group("/api") - contextController := controller.NewContextController(controller.ContextControllerConfig{ - Providers: app.runtime.configuredProviders, - Title: app.config.UI.Title, - AppURL: app.config.AppURL, - CookieDomain: app.runtime.cookieDomain, - ForgotPasswordMessage: app.config.UI.ForgotPasswordMessage, - BackgroundImage: app.config.UI.BackgroundImage, - OAuthAutoRedirect: app.config.OAuth.AutoRedirect, - WarningsEnabled: app.config.UI.WarningsEnabled, - }, apiRouter) + contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter) contextController.SetupRoutes() - oauthController := controller.NewOAuthController(controller.OAuthControllerConfig{ - AppURL: app.config.AppURL, - SecureCookie: app.config.Auth.SecureCookie, - CSRFCookieName: app.runtime.csrfCookieName, - RedirectCookieName: app.runtime.redirectCookieName, - CookieDomain: app.runtime.cookieDomain, - OAuthSessionCookieName: app.runtime.oauthSessionCookieName, - SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, - }, apiRouter, app.services.authService) + oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) oauthController.SetupRoutes() - oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter) + oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter) oidcController.SetupRoutes() - proxyController := controller.NewProxyController(controller.ProxyControllerConfig{ - AppURL: app.config.AppURL, - }, apiRouter, app.services.accessControlService, app.services.authService) + proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) proxyController.SetupRoutes() - userController := controller.NewUserController(controller.UserControllerConfig{ - CookieDomain: app.runtime.cookieDomain, - SessionCookieName: app.runtime.sessionCookieName, - }, apiRouter, app.services.authService) + userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) userController.SetupRoutes() - resourcesController := controller.NewResourcesController(controller.ResourcesControllerConfig{ - Path: app.config.Resources.Path, - Enabled: app.config.Resources.Enabled, - }, &engine.RouterGroup) + resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup) resourcesController.SetupRoutes() @@ -117,7 +92,7 @@ func (app *App) setupRouter() error { healthController.SetupRoutes() - wellknownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, app.services.oidcService, engine) + wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) wellknownController.SetupRoutes() diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index b3261180..9f44540d 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -4,11 +4,11 @@ import ( "fmt" "os" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) -func (app *App) setupServices() error { +func (app *BootstrapApp) setupServices() error { ldapService := service.NewLdapService(service.LdapServiceConfig{ Address: app.config.LDAP.Address, BindDN: app.config.LDAP.BindDN, @@ -44,9 +44,9 @@ func (app *App) setupServices() error { } app.services.kubernetesService = kubernetesService - app.runtime.labelProvider = service.LabelProviderKubernetes + app.runtime.LabelProvider = model.LabelProviderKubernetes } else { - tlog.App.Debug().Msg("Using Docker label provider") + app.log.App.Debug().Msg("Using Docker label provider") dockerService := service.NewDockerService() @@ -57,10 +57,10 @@ func (app *App) setupServices() error { } app.services.dockerService = dockerService - app.runtime.labelProvider = service.LabelProviderDocker + app.runtime.LabelProvider = model.LabelProviderDocker } - accessControlsService := service.NewAccessControlsService(app.runtime.labelProvider, app.config.Apps) + accessControlsService := service.NewAccessControlsService(app.runtime.LabelProvider, app.config.Apps) err = accessControlsService.Init() @@ -70,7 +70,7 @@ func (app *App) setupServices() error { app.services.accessControlService = accessControlsService - oauthBrokerService := service.NewOAuthBrokerService(app.runtime.oauthProviders) + oauthBrokerService := service.NewOAuthBrokerService(app.runtime.OAuthProviders) err = oauthBrokerService.Init() @@ -81,15 +81,15 @@ func (app *App) setupServices() error { app.services.oauthBrokerService = oauthBrokerService authService := service.NewAuthService(service.AuthServiceConfig{ - LocalUsers: &app.runtime.localUsers, - OauthWhitelist: app.runtime.oauthWhitelist, + LocalUsers: &app.runtime.LocalUsers, + OauthWhitelist: app.runtime.OAuthWhitelist, SessionExpiry: app.config.Auth.SessionExpiry, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, SecureCookie: app.config.Auth.SecureCookie, - CookieDomain: app.runtime.cookieDomain, + CookieDomain: app.runtime.CookieDomain, LoginTimeout: app.config.Auth.LoginTimeout, LoginMaxRetries: app.config.Auth.LoginMaxRetries, - SessionCookieName: app.runtime.sessionCookieName, + SessionCookieName: app.runtime.SessionCookieName, IP: app.config.Auth.IP, LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index f939ba55..491cb0b8 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -5,7 +5,7 @@ import ( "net/url" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" ) @@ -24,48 +24,40 @@ type UserContextResponse struct { } type AppContextResponse struct { - Status int `json:"status"` - Message string `json:"message"` - Providers []Provider `json:"providers"` - Title string `json:"title"` - AppURL string `json:"appUrl"` - CookieDomain string `json:"cookieDomain"` - ForgotPasswordMessage string `json:"forgotPasswordMessage"` - BackgroundImage string `json:"backgroundImage"` - OAuthAutoRedirect string `json:"oauthAutoRedirect"` - WarningsEnabled bool `json:"warningsEnabled"` -} - -type Provider struct { - Name string `json:"name"` - ID string `json:"id"` - OAuth bool `json:"oauth"` -} - -type ContextControllerConfig struct { - Providers []Provider - Title string - AppURL string - CookieDomain string - ForgotPasswordMessage string - BackgroundImage string - OAuthAutoRedirect string - WarningsEnabled bool + Status int `json:"status"` + Message string `json:"message"` + Providers []model.Provider `json:"providers"` + Title string `json:"title"` + AppURL string `json:"appUrl"` + CookieDomain string `json:"cookieDomain"` + ForgotPasswordMessage string `json:"forgotPasswordMessage"` + BackgroundImage string `json:"backgroundImage"` + OAuthAutoRedirect string `json:"oauthAutoRedirect"` + WarningsEnabled bool `json:"warningsEnabled"` } type ContextController struct { - config ContextControllerConfig - router *gin.RouterGroup + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + router *gin.RouterGroup } -func NewContextController(config ContextControllerConfig, router *gin.RouterGroup) *ContextController { - if !config.WarningsEnabled { - tlog.App.Warn().Msg("UI warnings are disabled. This may expose users to security risks. Proceed with caution.") +func NewContextController( + log *logger.Logger, + config model.Config, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, +) *ContextController { + if !config.UI.WarningsEnabled { + log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") } return &ContextController{ - config: config, - router: router, + log: log, + config: config, + runtime: runtimeConfig, + router: router, } } @@ -79,7 +71,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Debug().Err(err).Msg("No user context found in request") + controller.log.App.Error().Err(err).Msg("Failed to create user context from request") c.JSON(200, UserContextResponse{ Status: 401, Message: "Unauthorized", @@ -106,8 +98,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { func (controller *ContextController) appContextHandler(c *gin.Context) { appUrl, err := url.Parse(controller.config.AppURL) + if err != nil { - tlog.App.Error().Err(err).Msg("Failed to parse app URL") + controller.log.App.Error().Err(err).Msg("Failed to parse app URL") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -118,13 +111,13 @@ func (controller *ContextController) appContextHandler(c *gin.Context) { c.JSON(200, AppContextResponse{ Status: 200, Message: "Success", - Providers: controller.config.Providers, - Title: controller.config.Title, + Providers: controller.runtime.ConfiguredProviders, + Title: controller.config.UI.Title, AppURL: fmt.Sprintf("%s://%s", appUrl.Scheme, appUrl.Host), - CookieDomain: controller.config.CookieDomain, - ForgotPasswordMessage: controller.config.ForgotPasswordMessage, - BackgroundImage: controller.config.BackgroundImage, - OAuthAutoRedirect: controller.config.OAuthAutoRedirect, - WarningsEnabled: controller.config.WarningsEnabled, + CookieDomain: controller.runtime.CookieDomain, + ForgotPasswordMessage: controller.config.UI.ForgotPasswordMessage, + BackgroundImage: controller.config.UI.BackgroundImage, + OAuthAutoRedirect: controller.config.OAuth.AutoRedirect, + WarningsEnabled: controller.config.UI.WarningsEnabled, }) } diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 7f6d6ce0..902ee3de 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -6,10 +6,11 @@ import ( "strings" "time" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" @@ -19,27 +20,27 @@ type OAuthRequest struct { Provider string `uri:"provider" binding:"required"` } -type OAuthControllerConfig struct { - CSRFCookieName string - OAuthSessionCookieName string - RedirectCookieName string - SecureCookie bool - AppURL string - CookieDomain string - SubdomainsEnabled bool -} - type OAuthController struct { - config OAuthControllerConfig - router *gin.RouterGroup - auth *service.AuthService + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + router *gin.RouterGroup + auth *service.AuthService } -func NewOAuthController(config OAuthControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *OAuthController { +func NewOAuthController( + log *logger.Logger, + config model.Config, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, + auth *service.AuthService, +) *OAuthController { return &OAuthController{ - config: config, - router: router, - auth: auth, + log: log, + config: config, + runtime: runtimeConfig, + router: router, + auth: auth, } } @@ -54,7 +55,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -67,7 +68,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { err = c.BindQuery(&reqParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind query parameters") + controller.log.App.Error().Err(err).Msg("Failed to bind query parameters") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -76,10 +77,10 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { } if !controller.isOidcRequest(reqParams) { - isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.config.CookieDomain) + isRedirectSafe := utils.IsRedirectSafe(reqParams.RedirectURI, controller.runtime.CookieDomain) if !isRedirectSafe { - tlog.App.Warn().Str("redirect_uri", reqParams.RedirectURI).Msg("Unsafe redirect URI detected, ignoring") + controller.log.App.Warn().Str("redirectUri", reqParams.RedirectURI).Msg("Unsafe redirect URI, ignoring") reqParams.RedirectURI = "" } } @@ -87,7 +88,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { sessionId, _, err := controller.auth.NewOAuthSession(req.Provider, reqParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create OAuth session") + controller.log.App.Error().Err(err).Msg("Failed to create new OAuth session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -98,7 +99,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { authUrl, err := controller.auth.GetOAuthURL(sessionId) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth URL") + controller.log.App.Error().Err(err).Msg("Failed to get OAuth URL for session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -106,7 +107,7 @@ func (controller *OAuthController) oauthURLHandler(c *gin.Context) { return } - c.SetCookie(controller.config.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.SecureCookie, true) + c.SetCookie(controller.runtime.OAuthSessionCookieName, sessionId, int(time.Hour.Seconds()), "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) c.JSON(200, gin.H{ "status": 200, @@ -120,7 +121,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -128,20 +129,20 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } - sessionIdCookie, err := c.Cookie(controller.config.OAuthSessionCookieName) + sessionIdCookie, err := c.Cookie(controller.runtime.OAuthSessionCookieName) if err != nil { - tlog.App.Warn().Err(err).Msg("OAuth session cookie missing") + controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } - c.SetCookie(controller.config.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.SecureCookie, true) + c.SetCookie(controller.runtime.OAuthSessionCookieName, "", -1, "/", controller.getCookieDomain(), controller.config.Auth.SecureCookie, true) oauthPendingSession, err := controller.auth.GetOAuthPendingSession(sessionIdCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth pending session") + controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } @@ -150,7 +151,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { state := c.Query("state") if state != oauthPendingSession.State { - tlog.App.Warn().Err(err).Msg("CSRF token mismatch") + controller.log.App.Warn().Msg("OAuth state mismatch") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } @@ -159,7 +160,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { _, err = controller.auth.GetOAuthToken(sessionIdCookie, code) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to exchange code for token") + controller.log.App.Error().Err(err).Msg("Failed to exchange code for token") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } @@ -167,21 +168,21 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) if user.Email == "" { - tlog.App.Error().Msg("OAuth provider did not return an email") + controller.log.App.Warn().Msg("OAuth provider did not return an email") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } if !controller.auth.IsEmailWhitelisted(user.Email) { - tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted") - tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted") + controller.log.App.Warn().Str("email", user.Email).Msg("Email not whitelisted, denying access") + controller.log.AuditLoginFailure(user.Email, req.Provider, c.ClientIP(), "email not whitelisted") queries, err := query.Values(UnauthorizedQuery{ Username: user.Email, }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } @@ -193,33 +194,33 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { var name string if strings.TrimSpace(user.Name) != "" { - tlog.App.Debug().Msg("Using name from OAuth provider") + controller.log.App.Debug().Msg("Using name from OAuth provider") name = user.Name } else { - tlog.App.Debug().Msg("No name from OAuth provider, using pseudo name") + controller.log.App.Debug().Msg("No name from OAuth provider, generating from email") name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1]) } var username string if strings.TrimSpace(user.PreferredUsername) != "" { - tlog.App.Debug().Msg("Using preferred username from OAuth provider") + controller.log.App.Debug().Msg("Using preferred username from OAuth provider") username = user.PreferredUsername } else { - tlog.App.Debug().Msg("No preferred username from OAuth provider, using pseudo username") + controller.log.App.Debug().Msg("No preferred username from OAuth provider, generating from email") username = strings.Replace(user.Email, "@", "_", 1) } svc, err := controller.auth.GetOAuthService(sessionIdCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get OAuth service for session") + controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } if svc.ID() != req.Provider { - tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", svc.ID(), req.Provider) + controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } @@ -234,25 +235,25 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { OAuthSub: user.Sub, } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") + controller.log.App.Debug().Msg("Creating session cookie for user") cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Msg("Failed to create session cookie") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } http.SetCookie(c.Writer, cookie) - tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) + controller.log.AuditLoginSuccess(sessionCookie.Username, sessionCookie.Provider, c.ClientIP()) if controller.isOidcRequest(oauthPendingSession.CallbackParams) { - tlog.App.Debug().Msg("OIDC request, redirecting to authorize page") + controller.log.App.Debug().Msg("OIDC request detected, redirecting to authorization endpoint with callback params") queries, err := query.Values(oauthPendingSession.CallbackParams) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode OIDC callback query") + controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } @@ -266,7 +267,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") + controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) return } @@ -286,8 +287,8 @@ func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) } func (controller *OAuthController) getCookieDomain() string { - if controller.config.SubdomainsEnabled { - return "." + controller.config.CookieDomain + if controller.config.Auth.SubdomainsEnabled { + return "." + controller.runtime.CookieDomain } - return controller.config.CookieDomain + return controller.runtime.CookieDomain } diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 5e3f75f5..e5a139c9 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -13,13 +13,11 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type OIDCControllerConfig struct{} - type OIDCController struct { - config OIDCControllerConfig + log *logger.Logger router *gin.RouterGroup oidc *service.OIDCService } @@ -58,9 +56,12 @@ type ClientCredentials struct { ClientSecret string } -func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController { +func NewOIDCController( + log *logger.Logger, + oidcService *service.OIDCService, + router *gin.RouterGroup) *OIDCController { return &OIDCController{ - config: config, + log: log, oidc: oidcService, router: router, } @@ -80,7 +81,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { err := c.BindUri(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind URI") + controller.log.App.Error().Err(err).Msg("Failed to bind URI") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -91,7 +92,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { client, ok := controller.oidc.GetClient(req.ClientID) if !ok { - tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found") + controller.log.App.Warn().Str("clientId", req.ClientID).Msg("Client not found") c.JSON(404, gin.H{ "status": 404, "message": "Client not found", @@ -142,7 +143,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { err = controller.oidc.ValidateAuthorizeParams(req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to validate authorize params") + controller.log.App.Warn().Err(err).Msg("Failed to validate authorize params") if err.Error() != "invalid_request_uri" { controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State) return @@ -174,7 +175,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to insert user info into database") + controller.log.App.Error().Err(err).Msg("Failed to store user info") controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State) return } @@ -198,7 +199,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) { if !controller.oidc.IsConfigured() { - tlog.App.Warn().Msg("OIDC not configured") + controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") c.JSON(404, gin.H{ "error": "not_found", }) @@ -209,7 +210,7 @@ func (controller *OIDCController) Token(c *gin.Context) { err := c.Bind(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind token request") + controller.log.App.Warn().Err(err).Msg("Failed to bind token request") c.JSON(400, gin.H{ "error": "invalid_request", }) @@ -218,7 +219,7 @@ func (controller *OIDCController) Token(c *gin.Context) { err = controller.oidc.ValidateGrantType(req.GrantType) if err != nil { - tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type") + controller.log.App.Warn().Err(err).Msg("Invalid grant type") c.JSON(400, gin.H{ "error": err.Error(), }) @@ -233,12 +234,12 @@ func (controller *OIDCController) Token(c *gin.Context) { // If it fails, we try basic auth if creds.ClientID == "" || creds.ClientSecret == "" { - tlog.App.Debug().Msg("Tried form values and they are empty, trying basic auth") + controller.log.App.Debug().Msg("Client credentials not found in form, trying basic auth") clientId, clientSecret, ok := c.Request.BasicAuth() if !ok { - tlog.App.Error().Msg("Missing authorization header") + controller.log.App.Warn().Msg("Client credentials not found in basic auth") c.Header("www-authenticate", `Basic realm="Tinyauth OIDC Token Endpoint"`) c.JSON(400, gin.H{ "error": "invalid_client", @@ -255,7 +256,7 @@ func (controller *OIDCController) Token(c *gin.Context) { client, ok := controller.oidc.GetClient(creds.ClientID) if !ok { - tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Client not found") + controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Client not found") c.JSON(400, gin.H{ "error": "invalid_client", }) @@ -263,7 +264,7 @@ func (controller *OIDCController) Token(c *gin.Context) { } if client.ClientSecret != creds.ClientSecret { - tlog.App.Warn().Str("client_id", creds.ClientID).Msg("Invalid client secret") + controller.log.App.Warn().Str("clientId", creds.ClientID).Msg("Invalid client secret") c.JSON(400, gin.H{ "error": "invalid_client", }) @@ -277,30 +278,30 @@ func (controller *OIDCController) Token(c *gin.Context) { entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID) if err != nil { if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil { - tlog.App.Error().Err(err).Msg("Failed to delete access token by code hash") + controller.log.App.Error().Err(err).Msg("Failed to delete code") } if errors.Is(err, service.ErrCodeNotFound) { - tlog.App.Warn().Msg("Code not found") + controller.log.App.Warn().Msg("Code not found") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } if errors.Is(err, service.ErrCodeExpired) { - tlog.App.Warn().Msg("Code expired") + controller.log.App.Warn().Msg("Code expired") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } if errors.Is(err, service.ErrInvalidClient) { - tlog.App.Warn().Msg("Invalid client ID") + controller.log.App.Warn().Msg("Code does not belong to client") c.JSON(400, gin.H{ "error": "invalid_client", }) return } - tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry") + controller.log.App.Error().Err(err).Msg("Failed to get code entry") c.JSON(400, gin.H{ "error": "server_error", }) @@ -308,7 +309,7 @@ func (controller *OIDCController) Token(c *gin.Context) { } if entry.RedirectURI != req.RedirectURI { - tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch") + controller.log.App.Warn().Msg("Redirect URI does not match") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -318,7 +319,7 @@ func (controller *OIDCController) Token(c *gin.Context) { ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier) if !ok { - tlog.App.Warn().Msg("PKCE validation failed") + controller.log.App.Warn().Msg("PKCE validation failed") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -328,7 +329,7 @@ func (controller *OIDCController) Token(c *gin.Context) { tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to generate access token") + controller.log.App.Error().Err(err).Msg("Failed to generate access token") c.JSON(400, gin.H{ "error": "server_error", }) @@ -341,7 +342,7 @@ func (controller *OIDCController) Token(c *gin.Context) { if err != nil { if errors.Is(err, service.ErrTokenExpired) { - tlog.App.Error().Err(err).Msg("Refresh token expired") + controller.log.App.Warn().Msg("Refresh token expired") c.JSON(400, gin.H{ "error": "invalid_grant", }) @@ -349,14 +350,14 @@ func (controller *OIDCController) Token(c *gin.Context) { } if errors.Is(err, service.ErrInvalidClient) { - tlog.App.Error().Err(err).Msg("Invalid client") + controller.log.App.Warn().Msg("Refresh token does not belong to client") c.JSON(400, gin.H{ "error": "invalid_grant", }) return } - tlog.App.Error().Err(err).Msg("Failed to refresh access token") + controller.log.App.Error().Err(err).Msg("Failed to refresh access token") c.JSON(400, gin.H{ "error": "server_error", }) @@ -374,7 +375,7 @@ func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) { if !controller.oidc.IsConfigured() { - tlog.App.Warn().Msg("OIDC not configured") + controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") c.JSON(404, gin.H{ "error": "not_found", }) @@ -387,7 +388,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if authorization != "" { tokenType, bearerToken, ok := strings.Cut(authorization, " ") if !ok { - tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header") + controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid authorization header") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -395,7 +396,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } if strings.ToLower(tokenType) != "bearer" { - tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type") + controller.log.App.Warn().Msg("OIDC userinfo accessed with non-bearer token") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -405,7 +406,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { token = bearerToken } else if c.Request.Method == http.MethodPost { if c.ContentType() != "application/x-www-form-urlencoded" { - tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") + controller.log.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type") c.JSON(400, gin.H{ "error": "invalid_request", }) @@ -413,14 +414,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } token = c.PostForm("access_token") if token == "" { - tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body") + controller.log.App.Warn().Msg("OIDC userinfo POST accessed without access_token") c.JSON(401, gin.H{ "error": "invalid_request", }) return } } else { - tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header") + controller.log.App.Warn().Msg("OIDC userinfo accessed without authorization header or POST body") c.JSON(401, gin.H{ "error": "invalid_request", }) @@ -431,14 +432,14 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { if err != nil { if errors.Is(err, service.ErrTokenNotFound) { - tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token") + controller.log.App.Warn().Msg("OIDC userinfo accessed with invalid token") c.JSON(401, gin.H{ "error": "invalid_grant", }) return } - tlog.App.Err(err).Msg("Failed to get token entry") + controller.log.App.Error().Err(err).Msg("Failed to get access token") c.JSON(401, gin.H{ "error": "server_error", }) @@ -447,7 +448,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { // If we don't have the openid scope, return an error if !slices.Contains(strings.Split(entry.Scope, ","), "openid") { - tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope") + controller.log.App.Warn().Msg("OIDC userinfo accessed with token missing openid scope") c.JSON(401, gin.H{ "error": "invalid_scope", }) @@ -457,7 +458,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { user, err := controller.oidc.GetUserinfo(c, entry.Sub) if err != nil { - tlog.App.Err(err).Msg("Failed to get user entry") + controller.log.App.Error().Err(err).Msg("Failed to get user info") c.JSON(401, gin.H{ "error": "server_error", }) @@ -468,7 +469,7 @@ func (controller *OIDCController) Userinfo(c *gin.Context) { } func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) { - tlog.App.Error().Err(err).Msg(reason) + controller.log.App.Warn().Err(err).Str("reason", reason).Msg("Authorization error") if callback != "" { errorQueries := CallbackError{ diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index 7cd01969..b4bdc534 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -11,7 +11,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/google/go-querystring/query" @@ -50,23 +50,27 @@ type ProxyContext struct { ProxyType ProxyType } -type ProxyControllerConfig struct { - AppURL string -} - type ProxyController struct { - config ProxyControllerConfig - router *gin.RouterGroup - acls *service.AccessControlsService - auth *service.AuthService + log *logger.Logger + runtime model.RuntimeConfig + router *gin.RouterGroup + acls *service.AccessControlsService + auth *service.AuthService } -func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, acls *service.AccessControlsService, auth *service.AuthService) *ProxyController { +func NewProxyController( + log *logger.Logger, + runtime model.RuntimeConfig, + router *gin.RouterGroup, + acls *service.AccessControlsService, + auth *service.AuthService, +) *ProxyController { return &ProxyController{ - config: config, - router: router, - acls: acls, - auth: auth, + log: log, + runtime: runtime, + router: router, + acls: acls, + auth: auth, } } @@ -80,7 +84,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { proxyCtx, err := controller.getProxyContext(c) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to get proxy context") + controller.log.App.Error().Err(err).Msg("Failed to get proxy context from request") c.JSON(400, gin.H{ "status": 400, "message": "Bad request", @@ -88,19 +92,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { return } - tlog.App.Trace().Interface("ctx", proxyCtx).Msg("Got proxy context") - // Get acls acls, err := controller.acls.GetAccessControls(proxyCtx.Host) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get access controls for resource") + controller.log.App.Error().Err(err).Msg("Failed to get ACLs for resource") controller.handleError(c, proxyCtx) return } - tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource") - clientIP := c.ClientIP() if controller.auth.IsBypassedIP(clientIP, acls) { @@ -115,13 +115,13 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource") + controller.log.App.Error().Err(err).Msg("Failed to determine if authentication is enabled for resource") controller.handleError(c, proxyCtx) return } if !authEnabled { - tlog.App.Debug().Msg("Authentication disabled for resource, allowing access") + controller.log.App.Debug().Msg("Authentication is disabled for this resource, allowing access without authentication") controller.setHeaders(c, acls) c.JSON(200, gin.H{ "status": 200, @@ -137,12 +137,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -160,26 +160,24 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { userContext, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated") + controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") userContext = &model.UserContext{ Authenticated: false, } } - tlog.App.Trace().Interface("context", userContext).Msg("User context from request") - if userContext.Authenticated { userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) if !userAllowed { - tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") + controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not allowed to access resource") queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } @@ -190,7 +188,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { queries.Set("username", userContext.GetUsername()) } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -215,7 +213,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { } if !groupOK { - tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") + controller.log.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User is not in the required group to access resource") queries, err := query.Values(UnauthorizedQuery{ Resource: strings.Split(proxyCtx.Host, ".")[0], @@ -223,7 +221,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode unauthorized query") + controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.handleError(c, proxyCtx) return } @@ -234,7 +232,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { queries.Set("username", userContext.GetUsername()) } - redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -277,12 +275,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to encode redirect URI query") + controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") controller.handleError(c, proxyCtx) return } - redirectURL := fmt.Sprintf("%s/login?%s", controller.config.AppURL, queries.Encode()) + redirectURL := fmt.Sprintf("%s/login?%s", controller.runtime.AppURL, queries.Encode()) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -306,20 +304,19 @@ func (controller *ProxyController) setHeaders(c *gin.Context, acls *model.App) { headers := utils.ParseHeaders(acls.Response.Headers) for key, value := range headers { - tlog.App.Debug().Str("header", key).Msg("Setting header") c.Header(key, value) } basicPassword := utils.GetSecret(acls.Response.BasicAuth.Password, acls.Response.BasicAuth.PasswordFile) if acls.Response.BasicAuth.Username != "" && basicPassword != "" { - tlog.App.Debug().Str("username", acls.Response.BasicAuth.Username).Msg("Setting basic auth header") + controller.log.App.Debug().Msg("Setting basic auth header for response") c.Header("Authorization", fmt.Sprintf("Basic %s", utils.EncodeBasicAuth(acls.Response.BasicAuth.Username, basicPassword))) } } func (controller *ProxyController) handleError(c *gin.Context, proxyCtx ProxyContext) { - redirectURL := fmt.Sprintf("%s/error", controller.config.AppURL) + redirectURL := fmt.Sprintf("%s/error", controller.runtime.AppURL) if !controller.useBrowserResponse(proxyCtx) { c.Header("x-tinyauth-location", redirectURL) @@ -520,7 +517,7 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext return ProxyContext{}, err } - tlog.App.Debug().Msgf("Proxy: %v", req.Proxy) + controller.log.App.Debug().Msgf("Determined proxy type: %v", proxy) authModules := controller.determineAuthModules(proxy) @@ -531,13 +528,13 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext var ctx ProxyContext for _, module := range authModules { - tlog.App.Debug().Msgf("Trying auth module: %v", module) + controller.log.App.Debug().Msgf("Trying to get context from auth module %v", module) ctx, err = controller.getContextFromAuthModule(c, module) if err == nil { - tlog.App.Debug().Msgf("Auth module %v succeeded", module) + controller.log.App.Debug().Msgf("Successfully got context from auth module %v", module) break } - tlog.App.Debug().Err(err).Msgf("Auth module %v failed", module) + controller.log.App.Debug().Msgf("Failed to get context from auth module %v: %v", module, err) } if err != nil { @@ -549,9 +546,9 @@ func (controller *ProxyController) getProxyContext(c *gin.Context) (ProxyContext isBrowser := BrowserUserAgentRegex.MatchString(userAgent) if isBrowser { - tlog.App.Debug().Msg("Request identified as coming from a browser") + controller.log.App.Debug().Msg("Request identified as coming from a browser client") } else { - tlog.App.Debug().Msg("Request identified as coming from a non-browser client") + controller.log.App.Debug().Msg("Request identified as coming from a non-browser client") } ctx.IsBrowser = isBrowser diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go index 98d3b23c..b0fa3d70 100644 --- a/internal/controller/resources_controller.go +++ b/internal/controller/resources_controller.go @@ -4,21 +4,20 @@ import ( "net/http" "github.com/gin-gonic/gin" + "github.com/tinyauthapp/tinyauth/internal/model" ) -type ResourcesControllerConfig struct { - Path string - Enabled bool -} - type ResourcesController struct { - config ResourcesControllerConfig + config model.Config router *gin.RouterGroup fileServer http.Handler } -func NewResourcesController(config ResourcesControllerConfig, router *gin.RouterGroup) *ResourcesController { - fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Path))) +func NewResourcesController( + config model.Config, + router *gin.RouterGroup, +) *ResourcesController { + fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path))) return &ResourcesController{ config: config, @@ -32,14 +31,14 @@ func (controller *ResourcesController) SetupRoutes() { } func (controller *ResourcesController) resourcesHandler(c *gin.Context) { - if controller.config.Path == "" { + if controller.config.Resources.Path == "" { c.JSON(404, gin.H{ "status": 404, "message": "Resources not found", }) return } - if !controller.config.Enabled { + if !controller.config.Resources.Enabled { c.JSON(403, gin.H{ "status": 403, "message": "Resources are disabled", diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index cb6d5e6f..a7a1f948 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -10,7 +10,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" "github.com/pquerna/otp/totp" @@ -25,22 +25,24 @@ type TotpRequest struct { Code string `json:"code"` } -type UserControllerConfig struct { - CookieDomain string - SessionCookieName string -} - type UserController struct { - config UserControllerConfig - router *gin.RouterGroup - auth *service.AuthService + log *logger.Logger + runtime model.RuntimeConfig + router *gin.RouterGroup + auth *service.AuthService } -func NewUserController(config UserControllerConfig, router *gin.RouterGroup, auth *service.AuthService) *UserController { +func NewUserController( + log *logger.Logger, + runtimeConfig model.RuntimeConfig, + router *gin.RouterGroup, + auth *service.AuthService, +) *UserController { return &UserController{ - config: config, - router: router, - auth: auth, + log: log, + runtime: runtimeConfig, + router: router, + auth: auth, } } @@ -56,7 +58,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { err := c.ShouldBindJSON(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind JSON") + controller.log.App.Error().Err(err).Msg("Failed to bind JSON") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -64,13 +66,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { return } - tlog.App.Debug().Str("username", req.Username).Msg("Login attempt") + controller.log.App.Debug().Str("username", req.Username).Msg("Login attempt") isLocked, remaining := controller.auth.IsAccountLocked(req.Username) if isLocked { - tlog.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") - tlog.AuditLoginFailure(c, req.Username, "username", "account locked") + controller.log.App.Warn().Str("username", req.Username).Msg("Account is locked due to too many failed login attempts") + controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "account locked") c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.JSON(429, gin.H{ @@ -84,16 +86,16 @@ func (controller *UserController) loginHandler(c *gin.Context) { if err != nil { if errors.Is(err, service.ErrUserNotFound) { - tlog.App.Warn().Str("username", req.Username).Msg("User not found") + controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt") controller.auth.RecordLoginAttempt(req.Username, false) - tlog.AuditLoginFailure(c, req.Username, "username", "user not found") + controller.log.AuditLoginFailure(req.Username, "unkown", c.ClientIP(), "user not found") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", }) return } - tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user") + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user during login attempt") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -102,9 +104,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { } if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { - tlog.App.Warn().Err(err).Str("username", req.Username).Msg("Failed to verify password") + controller.log.App.Warn().Str("username", req.Username).Msg("Invalid password during login attempt") controller.auth.RecordLoginAttempt(req.Username, false) - tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") + if search.Type == model.UserLocal { + controller.log.AuditLoginFailure(req.Username, "local", c.ClientIP(), "invalid password") + } else { + controller.log.AuditLoginFailure(req.Username, "ldap", c.ClientIP(), "invalid password") + } c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -118,7 +124,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { localUser = controller.auth.GetLocalUser(req.Username) if localUser == nil { - tlog.App.Warn().Str("username", req.Username).Msg("User disappeared during login") + controller.log.App.Error().Str("username", req.Username).Msg("Local user not found after successful password verification") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -127,7 +133,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { } if localUser.TOTPSecret != "" { - tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") + controller.log.App.Debug().Str("username", req.Username).Msg("TOTP required for user, creating pending TOTP session") name := localUser.Attributes.Name if name == "" { @@ -136,7 +142,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { email := localUser.Attributes.Email if email == "" { - email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain) + email = utils.CompileUserEmail(localUser.Username, controller.runtime.CookieDomain) } cookie, err := controller.auth.CreateSession(c, repository.Session{ @@ -148,7 +154,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { }) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create pending TOTP session") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -170,7 +176,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { sessionCookie := repository.Session{ Username: req.Username, Name: utils.Capitalize(req.Username), - Email: utils.CompileUserEmail(req.Username, controller.config.CookieDomain), + Email: utils.CompileUserEmail(req.Username, controller.runtime.CookieDomain), Provider: "local", } @@ -187,12 +193,10 @@ func (controller *UserController) loginHandler(c *gin.Context) { sessionCookie.Provider = "ldap" } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", req.Username).Msg("Failed to create session cookie after successful login") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -202,8 +206,13 @@ func (controller *UserController) loginHandler(c *gin.Context) { http.SetCookie(c.Writer, cookie) - tlog.App.Info().Str("username", req.Username).Msg("Login successful") - tlog.AuditLoginSuccess(c, req.Username, "username") + controller.log.App.Info().Str("username", req.Username).Msg("Login successful") + + if search.Type == model.UserLocal { + controller.log.AuditLoginSuccess(req.Username, "local", c.ClientIP()) + } else { + controller.log.AuditLoginSuccess(req.Username, "ldap", c.ClientIP()) + } controller.auth.RecordLoginAttempt(req.Username, true) @@ -214,20 +223,20 @@ func (controller *UserController) loginHandler(c *gin.Context) { } func (controller *UserController) logoutHandler(c *gin.Context) { - tlog.App.Debug().Msg("Logout request received") + controller.log.App.Debug().Msg("Logout attempt") - uuid, err := c.Cookie(controller.config.SessionCookieName) + uuid, err := c.Cookie(controller.runtime.SessionCookieName) if err != nil { if errors.Is(err, http.ErrNoCookie) { - tlog.App.Warn().Msg("No session cookie found on logout request") + controller.log.App.Warn().Msg("Logout attempt without session cookie, treating as successful logout") c.JSON(200, gin.H{ "status": 200, "message": "Logout successful", }) return } - tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout") + controller.log.App.Error().Err(err).Msg("Error retrieving session cookie on logout") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -238,7 +247,7 @@ func (controller *UserController) logoutHandler(c *gin.Context) { cookie, err := controller.auth.DeleteSession(c, uuid) if err != nil { - tlog.App.Error().Err(err).Msg("Error deleting session on logout") + controller.log.App.Error().Err(err).Msg("Error deleting session on logout") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -249,10 +258,10 @@ func (controller *UserController) logoutHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err == nil { - tlog.AuditLogout(c, context.GetUsername(), context.GetProviderID()) + controller.log.AuditLogout(context.GetUsername(), context.GetProviderID(), c.ClientIP()) } else { - tlog.App.Warn().Err(err).Msg("Failed to get user context for logout audit, proceeding without username") - tlog.AuditLogout(c, "unknown", "unknown") + controller.log.App.Warn().Err(err).Msg("Failed to get user context during logout, logging audit with unknown user") + controller.log.AuditLogout("unknown", "unknown", c.ClientIP()) } http.SetCookie(c.Writer, cookie) @@ -268,7 +277,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { err := c.ShouldBindJSON(&req) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to bind JSON") + controller.log.App.Error().Err(err).Msg("Failed to bind JSON for TOTP verification") c.JSON(400, gin.H{ "status": 400, "message": "Bad Request", @@ -279,7 +288,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to get user context") + controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", @@ -288,7 +297,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { } if !context.TOTPPending() { - tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("TOTP verification attempt without pending TOTP session") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -296,12 +305,13 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") + controller.log.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) if isLocked { - tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") + controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "account locked") c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.JSON(429, gin.H{ @@ -314,7 +324,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { user := controller.auth.GetLocalUser(context.GetUsername()) if user == nil { - tlog.App.Error().Str("username", context.GetUsername()).Msg("User not found in TOTP handler") + controller.log.App.Error().Str("username", context.GetUsername()).Msg("Local user not found during TOTP verification") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -325,9 +335,9 @@ func (controller *UserController) totpHandler(c *gin.Context) { ok := totp.Validate(req.Code, user.TOTPSecret) if !ok { - tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code") + controller.log.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code during verification attempt") controller.auth.RecordLoginAttempt(context.GetUsername(), false) - tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code") + controller.log.AuditLoginFailure(context.GetUsername(), "local", c.ClientIP(), "invalid TOTP code") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", @@ -335,15 +345,15 @@ func (controller *UserController) totpHandler(c *gin.Context) { return } - uuid, err := c.Cookie(controller.config.SessionCookieName) + uuid, err := c.Cookie(controller.runtime.SessionCookieName) if err == nil { _, err = controller.auth.DeleteSession(c, uuid) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete pending TOTP session") + controller.log.App.Error().Err(err).Msg("Failed to delete pending TOTP session after successful verification") } } else { - tlog.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, proceeding without deleting it") + controller.log.App.Warn().Err(err).Msg("Failed to retrieve session cookie for pending TOTP session, cannot delete it") } controller.auth.RecordLoginAttempt(context.GetUsername(), true) @@ -351,7 +361,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { sessionCookie := repository.Session{ Username: user.Username, Name: utils.Capitalize(user.Username), - Email: utils.CompileUserEmail(user.Username, controller.config.CookieDomain), + Email: utils.CompileUserEmail(user.Username, controller.runtime.CookieDomain), Provider: "local", } @@ -362,8 +372,6 @@ func (controller *UserController) totpHandler(c *gin.Context) { sessionCookie.Email = user.Attributes.Email } - tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") - cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { @@ -377,8 +385,8 @@ func (controller *UserController) totpHandler(c *gin.Context) { http.SetCookie(c.Writer, cookie) - tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful") - tlog.AuditLoginSuccess(c, context.GetUsername(), "totp") + controller.log.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful, login complete") + controller.log.AuditLoginSuccess(context.GetUsername(), "local", c.ClientIP()) c.JSON(200, gin.H{ "status": 200, diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go index f31a9ed7..951fdac2 100644 --- a/internal/controller/well_known_controller.go +++ b/internal/controller/well_known_controller.go @@ -26,25 +26,21 @@ type OpenIDConnectConfiguration struct { RequestObjectSigningAlgValuesSupported []string `json:"request_object_signing_alg_values_supported"` } -type WellKnownControllerConfig struct{} - type WellKnownController struct { - config WellKnownControllerConfig - engine *gin.Engine + router *gin.RouterGroup oidc *service.OIDCService } -func NewWellKnownController(config WellKnownControllerConfig, oidc *service.OIDCService, engine *gin.Engine) *WellKnownController { +func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { return &WellKnownController{ - config: config, oidc: oidc, - engine: engine, + router: router, } } func (controller *WellKnownController) SetupRoutes() { - controller.engine.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) - controller.engine.GET("/.well-known/jwks.json", controller.JWKS) + controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) + controller.router.GET("/.well-known/jwks.json", controller.JWKS) } func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { diff --git a/internal/model/runtime.go b/internal/model/runtime.go new file mode 100644 index 00000000..72eab370 --- /dev/null +++ b/internal/model/runtime.go @@ -0,0 +1,30 @@ +package model + +type RuntimeConfig struct { + AppURL string + UUID string + CookieDomain string + SessionCookieName string + CSRFCookieName string + RedirectCookieName string + OAuthSessionCookieName string + LocalUsers []LocalUser + OAuthProviders map[string]OAuthServiceConfig + OAuthWhitelist []string + ConfiguredProviders []Provider + OIDCClients []OIDCClientConfig + LabelProvider LabelProvider +} + +type Provider struct { + Name string `json:"name"` + ID string `json:"id"` + OAuth bool `json:"oauth"` +} + +type LabelProvider int + +const ( + LabelProviderDocker LabelProvider = iota + LabelProviderKubernetes +) diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index c16c5a25..d31ae6b7 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -7,13 +7,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) -type LabelProvider int - -const ( - LabelProviderDocker LabelProvider = iota - LabelProviderKubernetes -) - type LabelProviderImpl interface { GetLabels(appDomain string) (*model.App, error) } diff --git a/internal/utils/tlog/log_audit.go b/internal/utils/tlog/log_audit.go deleted file mode 100644 index 115d41fe..00000000 --- a/internal/utils/tlog/log_audit.go +++ /dev/null @@ -1,39 +0,0 @@ -package tlog - -import "github.com/gin-gonic/gin" - -// functions here use CallerSkipFrame to ensure correct caller info is logged - -func AuditLoginSuccess(c *gin.Context, username, provider string) { - Audit.Info(). - CallerSkipFrame(1). - Str("event", "login"). - Str("result", "success"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Send() -} - -func AuditLoginFailure(c *gin.Context, username, provider string, reason string) { - Audit.Warn(). - CallerSkipFrame(1). - Str("event", "login"). - Str("result", "failure"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Str("reason", reason). - Send() -} - -func AuditLogout(c *gin.Context, username, provider string) { - Audit.Info(). - CallerSkipFrame(1). - Str("event", "logout"). - Str("result", "success"). - Str("username", username). - Str("provider", provider). - Str("ip", c.ClientIP()). - Send() -} diff --git a/internal/utils/tlog/log_wrapper.go b/internal/utils/tlog/log_wrapper.go deleted file mode 100644 index ffdfcf91..00000000 --- a/internal/utils/tlog/log_wrapper.go +++ /dev/null @@ -1,97 +0,0 @@ -package tlog - -import ( - "os" - "strings" - "time" - - "github.com/rs/zerolog" - "github.com/rs/zerolog/log" - "github.com/tinyauthapp/tinyauth/internal/model" -) - -type Logger struct { - Audit zerolog.Logger - HTTP zerolog.Logger - App zerolog.Logger -} - -var ( - Audit zerolog.Logger - HTTP zerolog.Logger - App zerolog.Logger -) - -func NewLogger(cfg model.LogConfig) *Logger { - baseLogger := log.With(). - Timestamp(). - Caller(). - Logger(). - Level(parseLogLevel(cfg.Level)) - - if !cfg.Json { - baseLogger = baseLogger.Output(zerolog.ConsoleWriter{ - Out: os.Stderr, - TimeFormat: time.RFC3339, - }) - } - - return &Logger{ - Audit: createLogger("audit", cfg.Streams.Audit, baseLogger), - HTTP: createLogger("http", cfg.Streams.HTTP, baseLogger), - App: createLogger("app", cfg.Streams.App, baseLogger), - } -} - -func NewSimpleLogger() *Logger { - return NewLogger(model.LogConfig{ - Level: "info", - Json: false, - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: true}, - App: model.LogStreamConfig{Enabled: true}, - Audit: model.LogStreamConfig{Enabled: false}, - }, - }) -} - -func NewTestLogger() *Logger { - return NewLogger(model.LogConfig{ - Level: "trace", - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: true}, - App: model.LogStreamConfig{Enabled: true}, - Audit: model.LogStreamConfig{Enabled: true}, - }, - }) -} - -func (l *Logger) Init() { - Audit = l.Audit - HTTP = l.HTTP - App = l.App -} - -func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger { - if !streamCfg.Enabled { - return zerolog.Nop() - } - subLogger := baseLogger.With().Str("log_stream", component).Logger() - // override level if specified, otherwise use base level - if streamCfg.Level != "" { - subLogger = subLogger.Level(parseLogLevel(streamCfg.Level)) - } - return subLogger -} - -func parseLogLevel(level string) zerolog.Level { - if level == "" { - return zerolog.InfoLevel - } - parsedLevel, err := zerolog.ParseLevel(strings.ToLower(level)) - if err != nil { - log.Warn().Err(err).Str("level", level).Msg("Invalid log level, defaulting to info") - parsedLevel = zerolog.InfoLevel - } - return parsedLevel -} diff --git a/internal/utils/tlog/log_wrapper_test.go b/internal/utils/tlog/log_wrapper_test.go deleted file mode 100644 index 41609f53..00000000 --- a/internal/utils/tlog/log_wrapper_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package tlog_test - -import ( - "bytes" - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - - "github.com/rs/zerolog" -) - -func TestNewLogger(t *testing.T) { - cfg := model.LogConfig{ - Level: "debug", - Json: true, - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: true, Level: "info"}, - App: model.LogStreamConfig{Enabled: true, Level: ""}, - Audit: model.LogStreamConfig{Enabled: false, Level: ""}, - }, - } - - logger := tlog.NewLogger(cfg) - - assert.NotNil(t, logger) - assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel()) - assert.Equal(t, zerolog.DebugLevel, logger.App.GetLevel()) - assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) -} - -func TestNewSimpleLogger(t *testing.T) { - logger := tlog.NewSimpleLogger() - assert.NotNil(t, logger) - assert.Equal(t, zerolog.InfoLevel, logger.HTTP.GetLevel()) - assert.Equal(t, zerolog.InfoLevel, logger.App.GetLevel()) - assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) -} - -func TestLoggerInit(t *testing.T) { - logger := tlog.NewSimpleLogger() - logger.Init() - - assert.NotEqual(t, zerolog.Disabled, tlog.App.GetLevel()) -} - -func TestLoggerWithDisabledStreams(t *testing.T) { - cfg := model.LogConfig{ - Level: "info", - Json: false, - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: false}, - App: model.LogStreamConfig{Enabled: false}, - Audit: model.LogStreamConfig{Enabled: false}, - }, - } - - logger := tlog.NewLogger(cfg) - - assert.Equal(t, zerolog.Disabled, logger.HTTP.GetLevel()) - assert.Equal(t, zerolog.Disabled, logger.App.GetLevel()) - assert.Equal(t, zerolog.Disabled, logger.Audit.GetLevel()) -} - -func TestLogStreamField(t *testing.T) { - var buf bytes.Buffer - - cfg := model.LogConfig{ - Level: "info", - Json: true, - Streams: model.LogStreams{ - HTTP: model.LogStreamConfig{Enabled: true}, - App: model.LogStreamConfig{Enabled: true}, - Audit: model.LogStreamConfig{Enabled: true}, - }, - } - - logger := tlog.NewLogger(cfg) - - // Override output for HTTP logger to capture output - logger.HTTP = logger.HTTP.Output(&buf) - - logger.HTTP.Info().Msg("test message") - - var logEntry map[string]interface{} - err := json.Unmarshal(buf.Bytes(), &logEntry) - assert.NoError(t, err) - - assert.Equal(t, "http", logEntry["log_stream"]) - assert.Equal(t, "test message", logEntry["message"]) -} From 55b53c77bf05b5e32a8749b3b5d85348ac15d487 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 8 May 2026 16:42:49 +0300 Subject: [PATCH 04/22] refactor: rework logging and config in middlewares --- internal/middleware/context_middleware.go | 46 ++++++++++++----------- internal/middleware/ui_middleware.go | 3 -- internal/middleware/zerolog_middleware.go | 14 ++++--- 3 files changed, 33 insertions(+), 30 deletions(-) diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 88e96462..211f931c 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -10,7 +10,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/gin-gonic/gin" ) @@ -35,22 +35,24 @@ var ( } ) -type ContextMiddlewareConfig struct { - CookieDomain string - SessionCookieName string -} - type ContextMiddleware struct { - config ContextMiddlewareConfig - auth *service.AuthService - broker *service.OAuthBrokerService + log *logger.Logger + runtime model.RuntimeConfig + auth *service.AuthService + broker *service.OAuthBrokerService } -func NewContextMiddleware(config ContextMiddlewareConfig, auth *service.AuthService, broker *service.OAuthBrokerService) *ContextMiddleware { +func NewContextMiddleware( + log *logger.Logger, + runtime model.RuntimeConfig, + auth *service.AuthService, + broker *service.OAuthBrokerService, +) *ContextMiddleware { return &ContextMiddleware{ - config: config, - auth: auth, - broker: broker, + log: log, + runtime: runtime, + auth: auth, + broker: broker, } } @@ -65,7 +67,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return } - uuid, err := c.Cookie(m.config.SessionCookieName) + uuid, err := c.Cookie(m.runtime.SessionCookieName) if err == nil { userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid) @@ -75,12 +77,12 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { http.SetCookie(c.Writer, cookie) } - tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername()) + m.log.App.Debug().Msgf("Authenticated user %s via session cookie", userContext.GetUsername()) c.Set("context", userContext) c.Next() return } else { - tlog.App.Error().Msgf("Error authenticating session cookie: %v", err) + m.log.App.Error().Msgf("Error authenticating session cookie: %v", err) } } @@ -90,7 +92,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { userContext, headers, err := m.basicAuth(username, password) if err != nil { - tlog.App.Error().Msgf("Error authenticating basic auth: %v", err) + m.log.App.Error().Msgf("Error authenticating basic auth: %v", err) c.Next() return } @@ -141,7 +143,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model } if userContext.Local.Attributes.Email == "" { - userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain) + userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.runtime.CookieDomain) } case model.ProviderLDAP: search, err := m.auth.SearchUser(userContext.LDAP.Username) @@ -162,7 +164,7 @@ func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model userContext.LDAP.Groups = user.Groups userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username) - userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain) + userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.runtime.CookieDomain) case model.ProviderOAuth: _, exists := m.broker.GetService(userContext.OAuth.ID) @@ -191,7 +193,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model. locked, remaining := m.auth.IsAccountLocked(username) if locked { - tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) + m.log.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", username, remaining) headers["x-tinyauth-lock-locked"] = "true" headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) return nil, headers, nil @@ -224,7 +226,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model. BaseContext: model.BaseContext{ Username: user.Username, Name: utils.Capitalize(user.Username), - Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain), + Email: utils.CompileUserEmail(user.Username, m.runtime.CookieDomain), }, Attributes: user.Attributes, } @@ -240,7 +242,7 @@ func (m *ContextMiddleware) basicAuth(username string, password string) (*model. BaseContext: model.BaseContext{ Username: username, Name: utils.Capitalize(username), - Email: utils.CompileUserEmail(username, m.config.CookieDomain), + Email: utils.CompileUserEmail(username, m.runtime.CookieDomain), }, Groups: user.Groups, } diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 96553b07..67b05b86 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -9,7 +9,6 @@ import ( "time" "github.com/tinyauthapp/tinyauth/internal/assets" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/gin-gonic/gin" ) @@ -40,8 +39,6 @@ func (m *UIMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { path := strings.TrimPrefix(c.Request.URL.Path, "/") - tlog.App.Debug().Str("path", path).Msg("path") - switch strings.SplitN(path, "/", 2)[0] { case "api", "resources", ".well-known": c.Next() diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index d75e3a72..070da695 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -5,7 +5,7 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) // See context middleware for explanation of why we have to do this @@ -17,10 +17,14 @@ var ( } ) -type ZerologMiddleware struct{} +type ZerologMiddleware struct { + log *logger.Logger +} -func NewZerologMiddleware() *ZerologMiddleware { - return &ZerologMiddleware{} +func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware { + return &ZerologMiddleware{ + log: log, + } } func (m *ZerologMiddleware) Init() error { @@ -50,7 +54,7 @@ func (m *ZerologMiddleware) Middleware() gin.HandlerFunc { latency := time.Since(tStart).String() - subLogger := tlog.HTTP.With().Str("method", method). + subLogger := m.log.HTTP.With().Str("method", method). Str("path", path). Str("address", address). Str("client_ip", clientIP). From e214d6d8d45b5c45d811f900998e4b4021c27cb2 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 8 May 2026 17:18:39 +0300 Subject: [PATCH 05/22] refactor: rework logging and cancellation in services --- internal/bootstrap/router_bootstrap.go | 7 +- internal/bootstrap/service_bootstrap.go | 49 ++---- internal/controller/user_controller.go | 2 +- internal/model/runtime.go | 8 - internal/service/access_controls_service.go | 19 ++- internal/service/auth_service.go | 172 ++++++++++---------- internal/service/docker_service.go | 47 ++++-- internal/service/kubernetes_service.go | 51 +++--- internal/service/ldap_service.go | 80 +++++---- internal/service/oauth_broker_service.go | 13 +- internal/service/oidc_service.go | 149 +++++++++-------- internal/utils/app_utils.go | 3 - 12 files changed, 310 insertions(+), 290 deletions(-) diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 2250fb19..47d3461e 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -29,10 +29,7 @@ func (app *BootstrapApp) setupRouter() error { } } - contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ - CookieDomain: app.runtime.CookieDomain, - SessionCookieName: app.runtime.SessionCookieName, - }, app.services.authService, app.services.oauthBrokerService) + contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) err := contextMiddleware.Init() @@ -52,7 +49,7 @@ func (app *BootstrapApp) setupRouter() error { engine.Use(uiMiddleware.Middleware()) - zerologMiddleware := middleware.NewZerologMiddleware() + zerologMiddleware := middleware.NewZerologMiddleware(app.log) err = zerologMiddleware.Init() diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 9f44540d..6d79d801 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -4,21 +4,11 @@ import ( "fmt" "os" - "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/service" ) func (app *BootstrapApp) setupServices() error { - ldapService := service.NewLdapService(service.LdapServiceConfig{ - Address: app.config.LDAP.Address, - BindDN: app.config.LDAP.BindDN, - BindPassword: app.config.LDAP.BindPassword, - BaseDN: app.config.LDAP.BaseDN, - Insecure: app.config.LDAP.Insecure, - SearchFilter: app.config.LDAP.SearchFilter, - AuthCert: app.config.LDAP.AuthCert, - AuthKey: app.config.LDAP.AuthKey, - }) + ldapService := service.NewLdapService(app.log, app.config, app.ctx) err := ldapService.Init() @@ -32,10 +22,12 @@ func (app *BootstrapApp) setupServices() error { useKubernetes := app.config.LabelProvider == "kubernetes" || (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") + var labelProvider service.LabelProviderImpl + if useKubernetes { app.log.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService := service.NewKubernetesService() + kubernetesService := service.NewKubernetesService(app.log, app.ctx) err = kubernetesService.Init() @@ -44,11 +36,11 @@ func (app *BootstrapApp) setupServices() error { } app.services.kubernetesService = kubernetesService - app.runtime.LabelProvider = model.LabelProviderKubernetes + labelProvider = kubernetesService } else { app.log.App.Debug().Msg("Using Docker label provider") - dockerService := service.NewDockerService() + dockerService := service.NewDockerService(app.log, app.ctx) err = dockerService.Init() @@ -57,10 +49,10 @@ func (app *BootstrapApp) setupServices() error { } app.services.dockerService = dockerService - app.runtime.LabelProvider = model.LabelProviderDocker + labelProvider = dockerService } - accessControlsService := service.NewAccessControlsService(app.runtime.LabelProvider, app.config.Apps) + accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps) err = accessControlsService.Init() @@ -70,7 +62,7 @@ func (app *BootstrapApp) setupServices() error { app.services.accessControlService = accessControlsService - oauthBrokerService := service.NewOAuthBrokerService(app.runtime.OAuthProviders) + oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders) err = oauthBrokerService.Init() @@ -80,20 +72,7 @@ func (app *BootstrapApp) setupServices() error { app.services.oauthBrokerService = oauthBrokerService - authService := service.NewAuthService(service.AuthServiceConfig{ - LocalUsers: &app.runtime.LocalUsers, - OauthWhitelist: app.runtime.OAuthWhitelist, - SessionExpiry: app.config.Auth.SessionExpiry, - SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, - SecureCookie: app.config.Auth.SecureCookie, - CookieDomain: app.runtime.CookieDomain, - LoginTimeout: app.config.Auth.LoginTimeout, - LoginMaxRetries: app.config.Auth.LoginMaxRetries, - SessionCookieName: app.runtime.SessionCookieName, - IP: app.config.Auth.IP, - LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, - SubdomainsEnabled: app.config.Auth.SubdomainsEnabled, - }, app.services.ldapService, app.queries, app.services.oauthBrokerService) + authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.services.ldapService, app.queries, app.services.oauthBrokerService) err = authService.Init() @@ -103,13 +82,7 @@ func (app *BootstrapApp) setupServices() error { app.services.authService = authService - oidcService := service.NewOIDCService(service.OIDCServiceConfig{ - Clients: app.config.OIDC.Clients, - PrivateKeyPath: app.config.OIDC.PrivateKeyPath, - PublicKeyPath: app.config.OIDC.PublicKeyPath, - Issuer: app.config.AppURL, - SessionExpiry: app.config.Auth.SessionExpiry, - }, app.queries) + oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx) err = oidcService.Init() diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index a7a1f948..b405bb03 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -375,7 +375,7 @@ func (controller *UserController) totpHandler(c *gin.Context) { cookie, err := controller.auth.CreateSession(c, sessionCookie) if err != nil { - tlog.App.Error().Err(err).Msg("Failed to create session cookie") + controller.log.App.Error().Err(err).Str("username", context.GetUsername()).Msg("Failed to create session cookie after successful TOTP verification") c.JSON(500, gin.H{ "status": 500, "message": "Internal Server Error", diff --git a/internal/model/runtime.go b/internal/model/runtime.go index 72eab370..9bd81770 100644 --- a/internal/model/runtime.go +++ b/internal/model/runtime.go @@ -13,7 +13,6 @@ type RuntimeConfig struct { OAuthWhitelist []string ConfiguredProviders []Provider OIDCClients []OIDCClientConfig - LabelProvider LabelProvider } type Provider struct { @@ -21,10 +20,3 @@ type Provider struct { ID string `json:"id"` OAuth bool `json:"oauth"` } - -type LabelProvider int - -const ( - LabelProviderDocker LabelProvider = iota - LabelProviderKubernetes -) diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index d31ae6b7..9bfe834d 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -4,7 +4,7 @@ import ( "strings" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) type LabelProviderImpl interface { @@ -12,12 +12,17 @@ type LabelProviderImpl interface { } type AccessControlsService struct { - labelProvider LabelProvider + log *logger.Logger + labelProvider LabelProviderImpl static map[string]model.App } -func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService { +func NewAccessControlsService( + log *logger.Logger, + labelProvider LabelProviderImpl, + static map[string]model.App) *AccessControlsService { return &AccessControlsService{ + log: log, labelProvider: labelProvider, static: static, } @@ -31,13 +36,13 @@ func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { var appAcls *model.App for app, config := range acls.static { if config.Config.Domain == domain { - tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") + acls.log.App.Debug().Str("name", app).Msg("Found matching container by domain") appAcls = &config break // If we find a match by domain, we can stop searching } if strings.SplitN(domain, ".", 2)[0] == app { - tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") + acls.log.App.Debug().Str("name", app).Msg("Found matching container by app name") appAcls = &config break // If we find a match by app name, we can stop searching } @@ -50,11 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, app := acls.lookupStaticACLs(domain) if app != nil { - tlog.App.Debug().Msg("Using ACls from static configuration") + acls.log.App.Debug().Msg("Using static ACLs for app") return app, nil } // Fallback to label provider - tlog.App.Debug().Msg("Falling back to label provider for ACLs") + acls.log.App.Debug().Msg("Using label provider for app") return acls.labelProvider.GetLabels(domain) } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 16c53fe0..8b891c34 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -14,7 +14,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "slices" @@ -72,39 +72,40 @@ type Lockdown struct { ActiveUntil time.Time } -type AuthServiceConfig struct { - LocalUsers *[]model.LocalUser - OauthWhitelist []string - SessionExpiry int - SessionMaxLifetime int - SecureCookie bool - CookieDomain string - LoginTimeout int - LoginMaxRetries int - SessionCookieName string - IP model.IPConfig - LDAPGroupsCacheTTL int - SubdomainsEnabled bool -} - type AuthService struct { - config AuthServiceConfig + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + context context.Context + + ldap *LdapService + queries *repository.Queries + oauthBroker *OAuthBrokerService + loginAttempts map[string]*LoginAttempt ldapGroupsCache map[string]*LdapGroupsCache oauthPendingSessions map[string]*OAuthPendingSession oauthMutex sync.RWMutex loginMutex sync.RWMutex ldapGroupsMutex sync.RWMutex - ldap *LdapService - queries *repository.Queries - oauthBroker *OAuthBrokerService lockdown *Lockdown lockdownCtx context.Context lockdownCancelFunc context.CancelFunc } -func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { +func NewAuthService( + log *logger.Logger, + config model.Config, + runtime model.RuntimeConfig, + context context.Context, + ldap *LdapService, + queries *repository.Queries, + oauthBroker *OAuthBrokerService, +) *AuthService { return &AuthService{ + log: log, + runtime: runtime, + context: context, config: config, loginAttempts: make(map[string]*LoginAttempt), ldapGroupsCache: make(map[string]*LdapGroupsCache), @@ -173,10 +174,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str } func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { - if auth.config.LocalUsers == nil { + if auth.runtime.LocalUsers == nil { return nil } - for _, user := range *auth.config.LocalUsers { + for _, user := range auth.runtime.LocalUsers { if user.Username == username { return &user } @@ -209,7 +210,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { auth.ldapGroupsMutex.Lock() auth.ldapGroupsCache[userDN] = &LdapGroupsCache{ Groups: groups, - Expires: time.Now().Add(time.Duration(auth.config.LDAPGroupsCacheTTL) * time.Second), + Expires: time.Now().Add(time.Duration(auth.config.LDAP.GroupCacheTTL) * time.Second), } auth.ldapGroupsMutex.Unlock() @@ -228,7 +229,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { return true, remaining } - if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { + if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { return false, 0 } @@ -246,7 +247,7 @@ func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { } func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { - if auth.config.LoginMaxRetries <= 0 || auth.config.LoginTimeout <= 0 { + if auth.config.Auth.LoginMaxRetries <= 0 || auth.config.Auth.LoginTimeout <= 0 { return } @@ -277,14 +278,14 @@ func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { attempt.FailedAttempts++ - if attempt.FailedAttempts >= auth.config.LoginMaxRetries { - attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second) - tlog.App.Warn().Str("identifier", identifier).Int("timeout", auth.config.LoginTimeout).Msg("Account locked due to too many failed login attempts") + if attempt.FailedAttempts >= auth.config.Auth.LoginMaxRetries { + attempt.LockedUntil = time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second) + auth.log.App.Warn().Str("identifier", identifier).Int("failedAttempts", attempt.FailedAttempts).Msg("Account locked due to too many failed login attempts") } } func (auth *AuthService) IsEmailWhitelisted(email string) bool { - return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) + return utils.CheckFilter(strings.Join(auth.runtime.OAuthWhitelist, ","), email) } func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { @@ -299,7 +300,7 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess if data.TotpPending { expiry = 3600 } else { - expiry = auth.config.SessionExpiry + expiry = auth.config.Auth.SessionExpiry } expiresAt := time.Now().Add(time.Duration(expiry) * time.Second) @@ -325,13 +326,13 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: expiresAt, MaxAge: int(time.Until(expiresAt).Seconds()), - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -348,8 +349,8 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http var refreshThreshold int64 - if auth.config.SessionExpiry <= int(time.Hour.Seconds()) { - refreshThreshold = int64(auth.config.SessionExpiry / 2) + if auth.config.Auth.SessionExpiry <= int(time.Hour.Seconds()) { + refreshThreshold = int64(auth.config.Auth.SessionExpiry / 2) } else { refreshThreshold = int64(time.Hour.Seconds()) } @@ -378,13 +379,13 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: session.UUID, Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second), MaxAge: int(newExpiry - currentTime), - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -395,7 +396,7 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http. err := auth.queries.DeleteSession(ctx, uuid) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") + auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") } err = auth.queries.DeleteSession(ctx, uuid) @@ -405,13 +406,13 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http. } return &http.Cookie{ - Name: auth.config.SessionCookieName, + Name: auth.runtime.SessionCookieName, Value: "", Path: "/", - Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), + Domain: fmt.Sprintf(".%s", auth.runtime.CookieDomain), Expires: time.Now(), MaxAge: -1, - Secure: auth.config.SecureCookie, + Secure: auth.config.Auth.SecureCookie, HttpOnly: true, SameSite: http.SameSiteLaxMode, }, nil @@ -429,8 +430,8 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito currentTime := time.Now().Unix() - if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { - if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { + if auth.config.Auth.SessionMaxLifetime != 0 && session.CreatedAt != 0 { + if currentTime-session.CreatedAt > int64(auth.config.Auth.SessionMaxLifetime) { err = auth.queries.DeleteSession(ctx, uuid) if err != nil { return nil, fmt.Errorf("failed to delete expired session: %w", err) @@ -451,7 +452,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito } func (auth *AuthService) LocalAuthConfigured() bool { - return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0 + return len(auth.runtime.LocalUsers) > 0 } func (auth *AuthService) LDAPAuthConfigured() bool { @@ -464,18 +465,18 @@ func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext } if context.Provider == model.ProviderOAuth { - tlog.App.Debug().Msg("Checking OAuth whitelist") + auth.log.App.Debug().Msg("User is an OAuth user, checking OAuth whitelist") return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) } if acls.Users.Block != "" { - tlog.App.Debug().Msg("Checking blocked users") + auth.log.App.Debug().Msg("Checking users block list") if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { return false } } - tlog.App.Debug().Msg("Checking users") + auth.log.App.Debug().Msg("Checking users allow list") return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) } @@ -485,23 +486,23 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex } if !context.IsOAuth() { - tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") + auth.log.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") return false } if _, ok := model.OverrideProviders[context.OAuth.ID]; ok { - tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check") + auth.log.App.Debug().Str("provider", context.OAuth.ID).Msg("Provider override detected, skipping group check") return true } for _, userGroup := range context.OAuth.Groups { if utils.CheckFilter(acls.OAuth.Groups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") + auth.log.App.Trace().Str("group", userGroup).Str("required", acls.OAuth.Groups).Msg("User group matched") return true } } - tlog.App.Debug().Msg("No groups matched") + auth.log.App.Debug().Msg("No groups matched") return false } @@ -511,18 +512,18 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext } if !context.IsLDAP() { - tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") + auth.log.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check") return false } for _, userGroup := range context.LDAP.Groups { if utils.CheckFilter(acls.LDAP.Groups, strings.TrimSpace(userGroup)) { - tlog.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") + auth.log.App.Trace().Str("group", userGroup).Str("required", acls.LDAP.Groups).Msg("User group matched") return true } } - tlog.App.Debug().Msg("No groups matched") + auth.log.App.Debug().Msg("No groups matched") return false } @@ -566,17 +567,17 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { } // Merge the global and app IP filter - blockedIps := append(auth.config.IP.Block, acls.IP.Block...) - allowedIPs := append(auth.config.IP.Allow, acls.IP.Allow...) + blockedIps := append(auth.config.Auth.IP.Block, acls.IP.Block...) + allowedIPs := append(auth.config.Auth.IP.Allow, acls.IP.Allow...) for _, blocked := range blockedIps { res, err := utils.FilterIP(blocked, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") + auth.log.App.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in blocked list, denying access") + auth.log.App.Debug().Str("ip", ip).Str("item", blocked).Msg("IP is in block list, denying access") return false } } @@ -584,21 +585,21 @@ func (auth *AuthService) CheckIP(ip string, acls *model.App) bool { for _, allowed := range allowedIPs { res, err := utils.FilterIP(allowed, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") + auth.log.App.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allowed list, allowing access") + auth.log.App.Debug().Str("ip", ip).Str("item", allowed).Msg("IP is in allow list, allowing access") return true } } if len(allowedIPs) > 0 { - tlog.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in allow list, denying access") return false } - tlog.App.Debug().Str("ip", ip).Msg("IP not in allow or block list, allowing by default") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in any block or allow list, allowing access by default") return true } @@ -610,16 +611,16 @@ func (auth *AuthService) IsBypassedIP(ip string, acls *model.App) bool { for _, bypassed := range acls.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) if err != nil { - tlog.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") + auth.log.App.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list") continue } if res { - tlog.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, allowing access") + auth.log.App.Debug().Str("ip", ip).Str("item", bypassed).Msg("IP is in bypass list, skipping authentication") return true } } - tlog.App.Debug().Str("ip", ip).Msg("IP not in bypass list, continuing with authentication") + auth.log.App.Debug().Str("ip", ip).Msg("IP not in bypass list, proceeding with authentication") return false } @@ -726,18 +727,23 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() { ticker := time.NewTicker(30 * time.Minute) defer ticker.Stop() - for range ticker.C { - auth.oauthMutex.Lock() + for { + select { + case <-ticker.C: + auth.oauthMutex.Lock() - now := time.Now() + now := time.Now() - for sessionId, session := range auth.oauthPendingSessions { - if now.After(session.ExpiresAt) { - delete(auth.oauthPendingSessions, sessionId) + for sessionId, session := range auth.oauthPendingSessions { + if now.After(session.ExpiresAt) { + delete(auth.oauthPendingSessions, sessionId) + } } - } - auth.oauthMutex.Unlock() + auth.oauthMutex.Unlock() + case <-auth.context.Done(): + return + } } } @@ -806,11 +812,11 @@ func (auth *AuthService) lockdownMode() { auth.loginMutex.Lock() - tlog.App.Warn().Msg("Multiple login attempts detected, possibly DDOS attack. Activating temporary lockdown.") + auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode") auth.lockdown = &Lockdown{ Active: true, - ActiveUntil: time.Now().Add(time.Duration(auth.config.LoginTimeout) * time.Second), + ActiveUntil: time.Now().Add(time.Duration(auth.config.Auth.LoginTimeout) * time.Second), } // At this point all login attemps will also expire so, @@ -827,11 +833,14 @@ func (auth *AuthService) lockdownMode() { // Timer expired, end lockdown case <-ctx.Done(): // Context cancelled, end lockdown + case <-auth.context.Done(): + // Service is shutting down, end lockdown } auth.loginMutex.Lock() - tlog.App.Info().Msg("Lockdown period ended, resuming normal operation") + auth.log.App.Info().Msg("Exiting lockdown mode") + auth.lockdown = nil auth.loginMutex.Unlock() } @@ -845,10 +854,3 @@ func (auth *AuthService) ClearRateLimitsTestingOnly() { } auth.loginMutex.Unlock() } - -func (auth *AuthService) getCookieDomain() string { - if auth.config.SubdomainsEnabled { - return "." + auth.config.CookieDomain - } - return auth.config.CookieDomain -} diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index c5f95dd4..763e26fb 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -6,20 +6,28 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" container "github.com/docker/docker/api/types/container" "github.com/docker/docker/client" ) type DockerService struct { - client *client.Client - context context.Context + log *logger.Logger + client *client.Client + context context.Context + isConnected bool } -func NewDockerService() *DockerService { - return &DockerService{} +func NewDockerService( + log *logger.Logger, + context context.Context, +) *DockerService { + return &DockerService{ + log: log, + context: context, + } } func (docker *DockerService) Init() error { @@ -28,16 +36,14 @@ func (docker *DockerService) Init() error { return err } - ctx := context.Background() - client.NegotiateAPIVersion(ctx) + client.NegotiateAPIVersion(docker.context) docker.client = client - docker.context = ctx _, err = docker.client.Ping(docker.context) if err != nil { - tlog.App.Debug().Err(err).Msg("Docker not connected") + docker.log.App.Debug().Err(err).Msg("Docker not connected") docker.isConnected = false docker.client = nil docker.context = nil @@ -45,7 +51,9 @@ func (docker *DockerService) Init() error { } docker.isConnected = true - tlog.App.Debug().Msg("Docker connected") + docker.log.App.Debug().Msg("Docker connected successfully") + + go docker.watchAndClose() return nil } @@ -60,7 +68,7 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { if !docker.isConnected { - tlog.App.Debug().Msg("Docker not connected, returning empty labels") + docker.log.App.Debug().Msg("Docker service not connected, returning empty labels") return nil, nil } @@ -82,17 +90,28 @@ func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { for appName, appLabels := range labels.Apps { if appLabels.Config.Domain == appDomain { - tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") + docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") return &appLabels, nil } if strings.SplitN(appDomain, ".", 2)[0] == appName { - tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") + docker.log.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") return &appLabels, nil } } } - tlog.App.Debug().Msg("No matching container found, returning empty labels") + docker.log.App.Debug().Str("domain", appDomain).Msg("No matching container found for domain") return nil, nil } + +func (docker *DockerService) watchAndClose() { + <-docker.context.Done() + docker.log.App.Debug().Msg("Closing Docker client") + if docker.client != nil { + err := docker.client.Close() + if err != nil { + docker.log.App.Error().Err(err).Msg("Error closing Docker client") + } + } +} diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 9c5ad427..acba24e4 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -9,7 +9,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -36,8 +36,10 @@ type ingressApp struct { } type KubernetesService struct { + log *logger.Logger + ctx context.Context + client dynamic.Interface - ctx context.Context cancel context.CancelFunc started bool mu sync.RWMutex @@ -46,8 +48,13 @@ type KubernetesService struct { appNameIndex map[string]ingressAppKey } -func NewKubernetesService() *KubernetesService { +func NewKubernetesService( + log *logger.Logger, + context context.Context, +) *KubernetesService { return &KubernetesService{ + log: log, + ctx: context, ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), @@ -133,7 +140,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { } labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") if err != nil { - tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations") + k.log.App.Warn().Err(err).Str("namespace", namespace).Str("name", name).Msg("Failed to decode ingress labels, skipping") k.removeIngress(namespace, name) return } @@ -161,13 +168,13 @@ func (k *KubernetesService) resyncGVR(gvr schema.GroupVersionResource) error { list, err := k.client.Resource(gvr).List(ctx, metav1.ListOptions{}) if err != nil { - tlog.App.Debug().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list ingresses during resync") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to list resources for resync") return err } for i := range list.Items { k.updateFromItem(&list.Items[i]) } - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resynced ingress cache") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Int("count", len(list.Items)).Msg("Resync complete") return nil } @@ -181,14 +188,14 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch. return false case event, ok := <-w.ResultChan(): if !ok { - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting in 5 seconds") + k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Watcher channel closed, restarting watcher") w.Stop() time.Sleep(5 * time.Second) return true } item, ok := event.Object.(*unstructured.Unstructured) if !ok { - tlog.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Failed to cast watched object") + k.log.App.Warn().Str("api", gvr.GroupVersion().String()).Msg("Received unexpected event object, skipping") continue } switch event.Type { @@ -199,7 +206,7 @@ func (k *KubernetesService) runWatcher(gvr schema.GroupVersionResource, w watch. } case <-resyncTicker.C: if err := k.resyncGVR(gvr); err != nil { - tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed during watcher run") } } } @@ -210,29 +217,29 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { defer resyncTicker.Stop() if err := k.resyncGVR(gvr); err != nil { - tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, retrying in 30 seconds") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Initial resync failed, will retry") time.Sleep(30 * time.Second) } for { select { case <-k.ctx.Done(): - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Stopping watcher") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher") return case <-resyncTicker.C: if err := k.resyncGVR(gvr); err != nil { - tlog.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Periodic resync failed, will retry") } default: ctx, cancel := context.WithCancel(k.ctx) watcher, err := k.client.Resource(gvr).Watch(ctx, metav1.ListOptions{}) if err != nil { - tlog.App.Error().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to start watcher, will retry") cancel() time.Sleep(10 * time.Second) continue } - tlog.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Watcher started successfully") if !k.runWatcher(gvr, watcher, resyncTicker) { cancel() return @@ -257,7 +264,7 @@ func (k *KubernetesService) Init() error { } k.client = client - k.ctx, k.cancel = context.WithCancel(context.Background()) + k.ctx, k.cancel = context.WithCancel(k.ctx) gvr := schema.GroupVersionResource{ Group: "networking.k8s.io", @@ -269,38 +276,38 @@ func (k *KubernetesService) Init() error { defer accessCancel() _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) if err != nil { - tlog.App.Warn().Err(err).Msg("Insufficient permissions for networking.k8s.io/v1 Ingress, Kubernetes label provider will not work") + k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") k.started = false return nil } - tlog.App.Debug().Msg("networking.k8s.io/v1 Ingress API accessible") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") go k.watchGVR(gvr) k.started = true - tlog.App.Info().Msg("Kubernetes label provider initialized") + k.log.App.Debug().Msg("Kubernetes label provider started successfully") return nil } func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { if !k.started { - tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels") + k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") return nil, nil } // First check cache app := k.getByDomain(appDomain) if app != nil { - tlog.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") + k.log.App.Debug().Str("domain", appDomain).Msg("Found labels in cache by domain") return app, nil } appName := strings.SplitN(appDomain, ".", 2)[0] app = k.getByAppName(appName) if app != nil { - tlog.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") + k.log.App.Debug().Str("domain", appDomain).Str("appName", appName).Msg("Found labels in cache by app name") return app, nil } - tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found") + k.log.App.Debug().Str("domain", appDomain).Msg("No labels found for domain") return nil, nil } diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 0963ebf5..d356cc75 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -9,31 +9,30 @@ import ( "github.com/cenkalti/backoff/v5" ldapgo "github.com/go-ldap/ldap/v3" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type LdapServiceConfig struct { - Address string - BindDN string - BindPassword string - BaseDN string - Insecure bool - SearchFilter string - AuthCert string - AuthKey string -} - type LdapService struct { - config LdapServiceConfig + log *logger.Logger + config model.Config + context context.Context + conn *ldapgo.Conn mutex sync.RWMutex cert *tls.Certificate isConfigured bool } -func NewLdapService(config LdapServiceConfig) *LdapService { +func NewLdapService( + log *logger.Logger, + config model.Config, + context context.Context, +) *LdapService { return &LdapService{ - config: config, + log: log, + config: config, + context: context, } } @@ -57,7 +56,7 @@ func (ldap *LdapService) Unconfigure() error { } func (ldap *LdapService) Init() error { - if ldap.config.Address == "" { + if ldap.config.LDAP.Address == "" { ldap.isConfigured = false return nil } @@ -65,13 +64,13 @@ func (ldap *LdapService) Init() error { ldap.isConfigured = true // Check whether authentication with client certificate is possible - if ldap.config.AuthCert != "" && ldap.config.AuthKey != "" { - cert, err := tls.LoadX509KeyPair(ldap.config.AuthCert, ldap.config.AuthKey) + if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" { + cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey) if err != nil { return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) } ldap.cert = &cert - tlog.App.Info().Msg("Using LDAP with mTLS authentication") + ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully") // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` /* @@ -90,15 +89,24 @@ func (ldap *LdapService) Init() error { } go func() { - for range time.Tick(time.Duration(5) * time.Minute) { - err := ldap.heartbeat() - if err != nil { - tlog.App.Error().Err(err).Msg("LDAP connection heartbeat failed") - if reconnectErr := ldap.reconnect(); reconnectErr != nil { - tlog.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") - continue + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := ldap.heartbeat() + if err != nil { + ldap.log.App.Warn().Err(err).Msg("LDAP connection heartbeat failed, attempting to reconnect") + if reconnectErr := ldap.reconnect(); reconnectErr != nil { + ldap.log.App.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server") + continue + } + ldap.log.App.Info().Msg("Successfully reconnected to LDAP server") } - tlog.App.Info().Msg("Successfully reconnected to LDAP server") + case <-ldap.context.Done(): + ldap.log.App.Debug().Msg("LDAP service context cancelled, stopping heartbeat") + return } } }() @@ -120,13 +128,13 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { // 2. conn.StartTLS(tlsConfig) // 3. conn.externalBind() if ldap.cert != nil { - conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{*ldap.cert}, })) } else { - conn, err = ldapgo.DialURL(ldap.config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ - InsecureSkipVerify: ldap.config.Insecure, + conn, err = ldapgo.DialURL(ldap.config.LDAP.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: ldap.config.LDAP.Insecure, MinVersion: tls.VersionTLS12, })) } @@ -146,10 +154,10 @@ func (ldap *LdapService) connect() (*ldapgo.Conn, error) { func (ldap *LdapService) GetUserDN(username string) (string, error) { // Escape the username to prevent LDAP injection escapedUsername := ldapgo.EscapeFilter(username) - filter := fmt.Sprintf(ldap.config.SearchFilter, escapedUsername) + filter := fmt.Sprintf(ldap.config.LDAP.SearchFilter, escapedUsername) searchRequest := ldapgo.NewSearchRequest( - ldap.config.BaseDN, + ldap.config.LDAP.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, filter, []string{"dn"}, @@ -176,7 +184,7 @@ func (ldap *LdapService) GetUserGroups(userDN string) ([]string, error) { escapedUserDN := ldapgo.EscapeFilter(userDN) searchRequest := ldapgo.NewSearchRequest( - ldap.config.BaseDN, + ldap.config.LDAP.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, fmt.Sprintf("(&(objectclass=groupOfUniqueNames)(uniquemember=%s))", escapedUserDN), []string{"dn"}, @@ -224,7 +232,7 @@ func (ldap *LdapService) BindService(rebind bool) error { if ldap.cert != nil { return ldap.conn.ExternalBind() } - return ldap.conn.Bind(ldap.config.BindDN, ldap.config.BindPassword) + return ldap.conn.Bind(ldap.config.LDAP.BindDN, ldap.config.LDAP.BindPassword) } func (ldap *LdapService) Bind(userDN string, password string) error { @@ -238,7 +246,7 @@ func (ldap *LdapService) Bind(userDN string, password string) error { } func (ldap *LdapService) heartbeat() error { - tlog.App.Debug().Msg("Performing LDAP connection heartbeat") + ldap.log.App.Debug().Msg("Performing LDAP connection heartbeat") searchRequest := ldapgo.NewSearchRequest( "", @@ -260,7 +268,7 @@ func (ldap *LdapService) heartbeat() error { } func (ldap *LdapService) reconnect() error { - tlog.App.Info().Msg("Reconnecting to LDAP server") + ldap.log.App.Info().Msg("Attempting to reconnect to LDAP server") exp := backoff.NewExponentialBackOff() exp.InitialInterval = 500 * time.Millisecond diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 15823c47..c3bfec9c 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -2,7 +2,7 @@ package service import ( "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "slices" @@ -19,6 +19,8 @@ type OAuthServiceImpl interface { } type OAuthBrokerService struct { + log *logger.Logger + services map[string]OAuthServiceImpl configs map[string]model.OAuthServiceConfig } @@ -28,7 +30,10 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ "google": newGoogleOAuthService, } -func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService { +func NewOAuthBrokerService( + log *logger.Logger, + configs map[string]model.OAuthServiceConfig, +) *OAuthBrokerService { return &OAuthBrokerService{ services: make(map[string]OAuthServiceImpl), configs: configs, @@ -39,10 +44,10 @@ func (broker *OAuthBrokerService) Init() error { for name, cfg := range broker.configs { if presetFunc, exists := presets[name]; exists { broker.services[name] = presetFunc(cfg) - tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") + broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") } else { broker.services[name] = NewOAuthService(cfg, name) - tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") + broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") } } return nil diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 1e1c1986..da69eb96 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -25,7 +25,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) var ( @@ -111,17 +111,13 @@ type AuthorizeRequest struct { CodeChallengeMethod string `json:"code_challenge_method"` } -type OIDCServiceConfig struct { - Clients map[string]model.OIDCClientConfig - PrivateKeyPath string - PublicKeyPath string - Issuer string - SessionExpiry int -} - type OIDCService struct { - config OIDCServiceConfig - queries *repository.Queries + log *logger.Logger + config model.Config + runtime model.RuntimeConfig + queries *repository.Queries + context context.Context + clients map[string]model.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey @@ -129,10 +125,18 @@ type OIDCService struct { isConfigured bool } -func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { +func NewOIDCService( + log *logger.Logger, + config model.Config, + runtime model.RuntimeConfig, + queries *repository.Queries, + context context.Context) *OIDCService { return &OIDCService{ + log: log, config: config, + runtime: runtime, queries: queries, + context: context, } } @@ -142,7 +146,7 @@ func (service *OIDCService) IsConfigured() bool { func (service *OIDCService) Init() error { // If not configured, skip init - if len(service.config.Clients) == 0 { + if len(service.runtime.OIDCClients) == 0 { service.isConfigured = false return nil } @@ -150,7 +154,7 @@ func (service *OIDCService) Init() error { service.isConfigured = true // Ensure issuer is https - uissuer, err := url.Parse(service.config.Issuer) + uissuer, err := url.Parse(service.runtime.AppURL) if err != nil { return err @@ -163,14 +167,14 @@ func (service *OIDCService) Init() error { service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys - if strings.TrimSpace(service.config.PrivateKeyPath) == "" || - strings.TrimSpace(service.config.PublicKeyPath) == "" { + if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" || + strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" { return errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey - fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath) + fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return err @@ -189,8 +193,8 @@ func (service *OIDCService) Init() error { Type: "RSA PRIVATE KEY", Bytes: der, }) - tlog.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") - err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600) + service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") + err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600) if err != nil { return err } @@ -200,7 +204,7 @@ func (service *OIDCService) Init() error { if block == nil { return errors.New("failed to decode private key") } - tlog.App.Trace().Str("type", block.Type).Msg("Loaded private key") + service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key") privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return err @@ -208,7 +212,7 @@ func (service *OIDCService) Init() error { service.privateKey = privateKey } - fpublicKey, err := os.ReadFile(service.config.PublicKeyPath) + fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { return err @@ -224,8 +228,8 @@ func (service *OIDCService) Init() error { Type: "RSA PUBLIC KEY", Bytes: der, }) - tlog.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") - err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644) + service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") + err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644) if err != nil { return err } @@ -235,7 +239,7 @@ func (service *OIDCService) Init() error { if block == nil { return errors.New("failed to decode public key") } - tlog.App.Trace().Str("type", block.Type).Msg("Loaded public key") + service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key") switch block.Type { case "RSA PUBLIC KEY": publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) @@ -257,7 +261,7 @@ func (service *OIDCService) Init() error { // We will reorganize the client into a map with the client ID as the key service.clients = make(map[string]model.OIDCClientConfig) - for id, client := range service.config.Clients { + for id, client := range service.config.OIDC.Clients { client.ID = id if client.Name == "" { client.Name = utils.Capitalize(client.ID) @@ -273,9 +277,12 @@ func (service *OIDCService) Init() error { } client.ClientSecretFile = "" service.clients[id] = client - tlog.App.Info().Str("id", client.ID).Msg("Registered OIDC client") + service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") } + // Start cleanup routine + go service.cleanupRoutine() + return nil } @@ -307,7 +314,7 @@ func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error return errors.New("invalid_scope") } if !slices.Contains(SupportedScopes, scope) { - tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored") + service.log.App.Warn().Str("scope", scope).Msg("Requested unsupported scope") } } @@ -357,7 +364,7 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r entry.CodeChallenge = req.CodeChallenge } else { entry.CodeChallenge = service.hashAndEncodePKCE(req.CodeChallenge) - tlog.App.Warn().Msg("Received plain PKCE code challenge, it's recommended to use S256 for better security") + service.log.App.Warn().Msg("Using plain PKCE code challenge method is not recommended, consider switching to S256 for better security") } } @@ -449,7 +456,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { createdAt := time.Now().Unix() - expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + expiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() hasher := sha256.New() @@ -529,16 +536,16 @@ func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OID accessToken := utils.GenerateString(32) refreshToken := utils.GenerateString(32) - tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() + tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() // Refresh token lives double the time of an access token but can't be used to access userinfo - refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: refreshToken, TokenType: "Bearer", - ExpiresIn: int64(service.config.SessionExpiry), + ExpiresIn: int64(service.config.Auth.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(codeEntry.Scope, ",", " "), } @@ -598,14 +605,14 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri accessToken := utils.GenerateString(32) newRefreshToken := utils.GenerateString(32) - tokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() - refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry*2) * time.Second).Unix() + tokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry) * time.Second).Unix() + refreshTokenExpiresAt := time.Now().Add(time.Duration(service.config.Auth.SessionExpiry*2) * time.Second).Unix() tokenResponse := TokenResponse{ AccessToken: accessToken, RefreshToken: newRefreshToken, TokenType: "Bearer", - ExpiresIn: int64(service.config.SessionExpiry), + ExpiresIn: int64(service.config.Auth.SessionExpiry), IDToken: idToken, Scope: strings.ReplaceAll(entry.Scope, ",", " "), } @@ -748,56 +755,64 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er } // Cleanup routine - Resource heavy due to the linked tables -func (service *OIDCService) Cleanup() { - // We need a context for the routine - ctx := context.Background() +func (service *OIDCService) cleanupRoutine() { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() - for range ticker.C { - currentTime := time.Now().Unix() + for { + select { + case <-ticker.C: + service.log.App.Debug().Msg("Starting OIDC cleanup routine") - // For the OIDC tokens, if they are expired we delete the userinfo and codes - expiredTokens, err := service.queries.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ - TokenExpiresAt: currentTime, - RefreshTokenExpiresAt: currentTime, - }) + currentTime := time.Now().Unix() - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete expired tokens") - } + // For the OIDC tokens, if they are expired we delete the userinfo and codes + expiredTokens, err := service.queries.DeleteExpiredOidcTokens(service.context, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: currentTime, + RefreshTokenExpiresAt: currentTime, + }) - for _, expiredToken := range expiredTokens { - err := service.DeleteOldSession(ctx, expiredToken.Sub) if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete old session") + service.log.App.Warn().Err(err).Msg("Failed to delete expired tokens") } - } - // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything - expiredCodes, err := service.queries.DeleteExpiredOidcCodes(ctx, currentTime) - - if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete expired codes") - } + for _, expiredToken := range expiredTokens { + err := service.DeleteOldSession(service.context, expiredToken.Sub) + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete session for expired token") + } + } - for _, expiredCode := range expiredCodes { - token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) + // For expired codes, we need to get the sub, check if tokens are expired and if they are remove everything + expiredCodes, err := service.queries.DeleteExpiredOidcCodes(service.context, currentTime) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } - tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") + service.log.App.Warn().Err(err).Msg("Failed to delete expired codes") } - if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { - err := service.DeleteOldSession(ctx, expiredCode.Sub) + for _, expiredCode := range expiredCodes { + token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub) + if err != nil { - tlog.App.Warn().Err(err).Msg("Failed to delete session") + if errors.Is(err, sql.ErrNoRows) { + continue + } + service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") + } + + if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { + err := service.DeleteOldSession(service.context, expiredCode.Sub) + if err != nil { + service.log.App.Warn().Err(err).Msg("Failed to delete session for expired code") + } } } + + service.log.App.Debug().Msg("Finished OIDC cleanup routine") + case <-service.context.Done(): + service.log.App.Debug().Msg("OIDC cleanup routine context cancelled, stopping") + return } } } diff --git a/internal/utils/app_utils.go b/internal/utils/app_utils.go index d021c083..6413755b 100644 --- a/internal/utils/app_utils.go +++ b/internal/utils/app_utils.go @@ -7,8 +7,6 @@ import ( "net/url" "strings" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" - "github.com/weppos/publicsuffix-go/publicsuffix" ) @@ -28,7 +26,6 @@ func GetCookieDomain(u string) (string, error) { parts := strings.Split(host, ".") if len(parts) == 2 { - tlog.App.Warn().Msgf("Running on the root domain, cookies will be set for .%v", host) return host, nil } From 0958c3b864b337c30563a9a30cca923c119e7748 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 8 May 2026 17:22:21 +0300 Subject: [PATCH 06/22] refactor: rework cli logging --- cmd/tinyauth/create_user.go | 9 +++++---- cmd/tinyauth/generate_totp.go | 11 ++++++----- cmd/tinyauth/healthcheck.go | 9 +++++---- cmd/tinyauth/verify_user.go | 11 ++++++----- internal/bootstrap/app_bootstrap.go | 19 ++++++++++--------- 5 files changed, 32 insertions(+), 27 deletions(-) diff --git a/cmd/tinyauth/create_user.go b/cmd/tinyauth/create_user.go index ef5fe266..d7e9f97e 100644 --- a/cmd/tinyauth/create_user.go +++ b/cmd/tinyauth/create_user.go @@ -6,8 +6,8 @@ import ( "strings" "charm.land/huh/v2" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/paerser/cli" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "golang.org/x/crypto/bcrypt" ) @@ -40,7 +40,8 @@ func createUserCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -73,7 +74,7 @@ func createUserCmd() *cli.Command { return errors.New("username and password cannot be empty") } - tlog.App.Info().Str("username", tCfg.Username).Msg("Creating user") + log.App.Info().Str("username", tCfg.Username).Msg("Creating user") passwd, err := bcrypt.GenerateFromPassword([]byte(tCfg.Password), bcrypt.DefaultCost) if err != nil { @@ -86,7 +87,7 @@ func createUserCmd() *cli.Command { passwdStr = strings.ReplaceAll(passwdStr, "$", "$$") } - tlog.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") + log.App.Info().Str("user", fmt.Sprintf("%s:%s", tCfg.Username, passwdStr)).Msg("User created") return nil }, diff --git a/cmd/tinyauth/generate_totp.go b/cmd/tinyauth/generate_totp.go index 8819922e..8492f87b 100644 --- a/cmd/tinyauth/generate_totp.go +++ b/cmd/tinyauth/generate_totp.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "charm.land/huh/v2" "github.com/mdp/qrterminal/v3" @@ -40,7 +40,8 @@ func generateTotpCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -88,9 +89,9 @@ func generateTotpCmd() *cli.Command { secret := key.Secret() - tlog.App.Info().Str("secret", secret).Msg("Generated TOTP secret") + log.App.Info().Str("secret", secret).Msg("Generated TOTP secret") - tlog.App.Info().Msg("Generated QR code") + log.App.Info().Msg("Generated QR code") config := qrterminal.Config{ Level: qrterminal.L, @@ -109,7 +110,7 @@ func generateTotpCmd() *cli.Command { user.Password = strings.ReplaceAll(user.Password, "$", "$$") } - tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") + log.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") return nil }, diff --git a/cmd/tinyauth/healthcheck.go b/cmd/tinyauth/healthcheck.go index 649a68c7..921479a5 100644 --- a/cmd/tinyauth/healthcheck.go +++ b/cmd/tinyauth/healthcheck.go @@ -9,8 +9,8 @@ import ( "os" "time" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/paerser/cli" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) type healthzResponse struct { @@ -26,7 +26,8 @@ func healthcheckCmd() *cli.Command { Resources: nil, AllowArg: true, Run: func(args []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() srvAddr := os.Getenv("TINYAUTH_SERVER_ADDRESS") if srvAddr == "" { @@ -48,7 +49,7 @@ func healthcheckCmd() *cli.Command { return errors.New("Could not determine app URL") } - tlog.App.Info().Str("app_url", appUrl).Msg("Performing health check") + log.App.Info().Str("app_url", appUrl).Msg("Performing health check") client := http.Client{ Timeout: 30 * time.Second, @@ -86,7 +87,7 @@ func healthcheckCmd() *cli.Command { return fmt.Errorf("failed to decode response: %w", err) } - tlog.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") + log.App.Info().Interface("response", healthResp).Msg("Tinyauth is healthy") return nil }, diff --git a/cmd/tinyauth/verify_user.go b/cmd/tinyauth/verify_user.go index 5058b606..b0347f6f 100644 --- a/cmd/tinyauth/verify_user.go +++ b/cmd/tinyauth/verify_user.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" "charm.land/huh/v2" "github.com/pquerna/otp/totp" @@ -44,7 +44,8 @@ func verifyUserCmd() *cli.Command { Configuration: tCfg, Resources: loaders, Run: func(_ []string) error { - tlog.NewSimpleLogger().Init() + log := logger.NewLogger().WithSimpleConfig() + log.Init() if tCfg.Interactive { form := huh.NewForm( @@ -97,9 +98,9 @@ func verifyUserCmd() *cli.Command { if user.TOTPSecret == "" { if tCfg.Totp != "" { - tlog.App.Warn().Msg("User does not have TOTP secret") + log.App.Warn().Msg("User does not have TOTP secret") } - tlog.App.Info().Msg("User verified") + log.App.Info().Msg("User verified") return nil } @@ -109,7 +110,7 @@ func verifyUserCmd() *cli.Command { return fmt.Errorf("TOTP code incorrect") } - tlog.App.Info().Msg("User verified") + log.App.Info().Msg("User verified") return nil }, diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 268d0d30..92a89b39 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -249,17 +249,18 @@ func (app *BootstrapApp) Setup() error { }() // monitor cancellation and server errors - select { - case <-app.ctx.Done(): - app.log.App.Debug().Msg("Shutting down application") - return nil - case err := <-errChan: - if err != nil { - return fmt.Errorf("server error: %w", err) + for { + select { + case <-app.ctx.Done(): + app.log.App.Debug().Msg("Oh, seems like I got to shutdown, bye!") + app.db.Close() + return nil + case err := <-errChan: + if err != nil { + return fmt.Errorf("server error: %w", err) + } } } - - return nil } func (app *BootstrapApp) serveHTTP() error { From b73a9db061e2825ec8776c114107451f97d084cc Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 8 May 2026 17:43:20 +0300 Subject: [PATCH 07/22] fix: improve logging in routines --- internal/bootstrap/app_bootstrap.go | 4 +++- internal/bootstrap/router_bootstrap.go | 9 ++------- internal/service/auth_service.go | 6 ++++++ internal/service/ldap_service.go | 2 ++ internal/service/oauth_broker_service.go | 1 + internal/service/oidc_service.go | 4 ++-- internal/utils/logger/logger.go | 7 ++++--- internal/utils/logger/logger_test.go | 2 +- 8 files changed, 21 insertions(+), 14 deletions(-) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 92a89b39..bd052c6b 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -252,7 +252,7 @@ func (app *BootstrapApp) Setup() error { for { select { case <-app.ctx.Done(): - app.log.App.Debug().Msg("Oh, seems like I got to shutdown, bye!") + app.log.App.Info().Msg("Oh, seems like I got to shutdown, bye!") app.db.Close() return nil case err := <-errChan: @@ -410,6 +410,8 @@ func (app *BootstrapApp) dbCleanupRoutine() { if err != nil { app.log.App.Error().Err(err).Msg("Failed to delete expired sessions") } + + app.log.App.Debug().Msg("Database cleanup completed") case <-app.ctx.Done(): app.log.App.Debug().Msg("Stopping database cleanup routine") ticker.Stop() diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 47d3461e..ce739fc9 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -2,21 +2,16 @@ package bootstrap import ( "fmt" - "slices" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/middleware" - "github.com/tinyauthapp/tinyauth/internal/model" "github.com/gin-gonic/gin" ) -var DEV_MODES = []string{"main", "test", "development"} - func (app *BootstrapApp) setupRouter() error { - if !slices.Contains(DEV_MODES, model.Version) { - gin.SetMode(gin.ReleaseMode) - } + // we don't want gin debug mode + gin.SetMode(gin.ReleaseMode) engine := gin.New() engine.Use(gin.Recovery()) diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 8b891c34..60a205db 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -724,12 +724,16 @@ func (auth *AuthService) EndOAuthSession(sessionId string) { } func (auth *AuthService) CleanupOAuthSessionsRoutine() { + auth.log.App.Debug().Msg("Starting OAuth session cleanup routine") + ticker := time.NewTicker(30 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: + auth.log.App.Debug().Msg("Running OAuth session cleanup") + auth.oauthMutex.Lock() now := time.Now() @@ -741,7 +745,9 @@ func (auth *AuthService) CleanupOAuthSessionsRoutine() { } auth.oauthMutex.Unlock() + auth.log.App.Debug().Msg("OAuth session cleanup completed") case <-auth.context.Done(): + auth.log.App.Debug().Msg("Stopping OAuth session cleanup routine") return } } diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index d356cc75..c1a5d187 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -89,6 +89,8 @@ func (ldap *LdapService) Init() error { } go func() { + ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") + ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index c3bfec9c..8d693ad9 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -35,6 +35,7 @@ func NewOAuthBrokerService( configs map[string]model.OAuthServiceConfig, ) *OAuthBrokerService { return &OAuthBrokerService{ + log: log, services: make(map[string]OAuthServiceImpl), configs: configs, } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index da69eb96..38101fa7 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -756,14 +756,14 @@ func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) er // Cleanup routine - Resource heavy due to the linked tables func (service *OIDCService) cleanupRoutine() { - + service.log.App.Debug().Msg("Starting OIDC cleanup routine") ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - service.log.App.Debug().Msg("Starting OIDC cleanup routine") + service.log.App.Debug().Msg("Performing OIDC cleanup routine") currentTime := time.Now().Unix() diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go index d85af79e..24d93145 100644 --- a/internal/utils/logger/logger.go +++ b/internal/utils/logger/logger.go @@ -87,6 +87,10 @@ func (l *Logger) Init() { }) } + if base.GetLevel() == zerolog.TraceLevel || base.GetLevel() == zerolog.DebugLevel { + base = base.With().Caller().Logger() + } + l.base = base l.audit = l.createLogger("audit", l.config.Streams.Audit) l.HTTP = l.createLogger("http", l.config.Streams.HTTP) @@ -113,9 +117,6 @@ func (l *Logger) createLogger(component string, cfg model.LogStreamConfig) zerol if cfg.Level != "" { sub = sub.Level(l.parseLogLevel(cfg.Level)) } - if sub.GetLevel() == zerolog.DebugLevel { - sub = sub.With().Caller().Logger() - } return sub } diff --git a/internal/utils/logger/logger_test.go b/internal/utils/logger/logger_test.go index 66387a5f..395d348f 100644 --- a/internal/utils/logger/logger_test.go +++ b/internal/utils/logger/logger_test.go @@ -162,7 +162,7 @@ func TestLogger(t *testing.T) { l.AuditLoginFailure("test", "test", "test", "test") assert.NotEmpty(t, buf.String()) - assert.Equal(t, 119, buf.Len()) // it's the length of the test log entry + assert.Equal(t, 81, buf.Len()) // it's the length of the test log entry }, }, } From 71ddfbbdba20125df46b599da5b2a7447b0191b4 Mon Sep 17 00:00:00 2001 From: Stavros Date: Fri, 8 May 2026 18:08:27 +0300 Subject: [PATCH 08/22] feat: use sync groups for better cancellation --- internal/bootstrap/app_bootstrap.go | 46 +++++++++++++++---------- internal/bootstrap/service_bootstrap.go | 10 +++--- internal/service/auth_service.go | 5 ++- internal/service/docker_service.go | 6 +++- internal/service/kubernetes_service.go | 11 +++--- internal/service/ldap_service.go | 7 ++-- internal/service/oidc_service.go | 10 ++++-- 7 files changed, 61 insertions(+), 34 deletions(-) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index bd052c6b..48d57a9d 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -14,6 +14,7 @@ import ( "os/signal" "sort" "strings" + "sync" "syscall" "time" @@ -45,6 +46,7 @@ type BootstrapApp struct { queries *repository.Queries router *gin.Engine db *sql.DB + wg sync.WaitGroup } func NewBootstrapApp(config model.Config) *BootstrapApp { @@ -227,33 +229,39 @@ func (app *BootstrapApp) Setup() error { // start db cleanup routine app.log.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanupRoutine() + app.wg.Go(app.dbCleanupRoutine) // if analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { app.log.App.Debug().Msg("Starting heartbeat routine") - go app.heartbeatRoutine() + app.wg.Go(app.heartbeatRoutine) } // create err channel to listen for server errors errChan := make(chan error, 1) // serve unix - go func() { - errChan <- app.serveUnix() - }() + app.wg.Go(func() { + if err := app.serveUnix(); err != nil { + errChan <- err + } + }) // serve to http - go func() { - errChan <- app.serveHTTP() - }() + app.wg.Go(func() { + if err := app.serveHTTP(); err != nil { + errChan <- err + } + }) // monitor cancellation and server errors for { select { case <-app.ctx.Done(): - app.log.App.Info().Msg("Oh, seems like I got to shutdown, bye!") + app.wg.Wait() + app.log.App.Debug().Msg("Closing database") app.db.Close() + app.log.App.Info().Msg("Oh, it's time for me to go, bye!") return nil case err := <-errChan: if err != nil { @@ -275,14 +283,14 @@ func (app *BootstrapApp) serveHTTP() error { go func() { <-app.ctx.Done() - app.log.App.Debug().Msg("Shutting down server") + app.log.App.Debug().Msg("Shutting down http listener") server.Close() }() err := server.ListenAndServe() if err != nil && !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("failed to start server: %w", err) + return fmt.Errorf("failed to start http listener: %w", err) } return nil @@ -312,24 +320,26 @@ func (app *BootstrapApp) serveUnix() error { return fmt.Errorf("failed to create unix socket listner: %w", err) } + server := &http.Server{ + Handler: app.router.Handler(), + } + + defer server.Close() defer listener.Close() defer os.Remove(app.config.Server.SocketPath) go func() { <-app.ctx.Done() - app.log.App.Debug().Msg("Shutting down server") + app.log.App.Debug().Msg("Shutting down unix sokcet listener") + server.Close() listener.Close() os.Remove(app.config.Server.SocketPath) }() - server := &http.Server{ - Handler: app.router.Handler(), - } - err = server.Serve(listener) - if err != nil && !errors.Is(err, net.ErrClosed) { - return fmt.Errorf("failed to start server: %w", err) + if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) { + return fmt.Errorf("failed to start unix socket listener: %w", err) } return nil diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 6d79d801..6692b038 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -8,7 +8,7 @@ import ( ) func (app *BootstrapApp) setupServices() error { - ldapService := service.NewLdapService(app.log, app.config, app.ctx) + ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) err := ldapService.Init() @@ -27,7 +27,7 @@ func (app *BootstrapApp) setupServices() error { if useKubernetes { app.log.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService := service.NewKubernetesService(app.log, app.ctx) + kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg) err = kubernetesService.Init() @@ -40,7 +40,7 @@ func (app *BootstrapApp) setupServices() error { } else { app.log.App.Debug().Msg("Using Docker label provider") - dockerService := service.NewDockerService(app.log, app.ctx) + dockerService := service.NewDockerService(app.log, app.ctx, &app.wg) err = dockerService.Init() @@ -72,7 +72,7 @@ func (app *BootstrapApp) setupServices() error { app.services.oauthBrokerService = oauthBrokerService - authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, app.services.ldapService, app.queries, app.services.oauthBrokerService) + authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService) err = authService.Init() @@ -82,7 +82,7 @@ func (app *BootstrapApp) setupServices() error { app.services.authService = authService - oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx) + oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) err = oidcService.Init() diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 60a205db..e47d31cb 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -77,6 +77,7 @@ type AuthService struct { config model.Config runtime model.RuntimeConfig context context.Context + wg *sync.WaitGroup ldap *LdapService queries *repository.Queries @@ -98,6 +99,7 @@ func NewAuthService( config model.Config, runtime model.RuntimeConfig, context context.Context, + wg *sync.WaitGroup, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService, @@ -106,6 +108,7 @@ func NewAuthService( log: log, runtime: runtime, context: context, + wg: wg, config: config, loginAttempts: make(map[string]*LoginAttempt), ldapGroupsCache: make(map[string]*LdapGroupsCache), @@ -117,7 +120,7 @@ func NewAuthService( } func (auth *AuthService) Init() error { - go auth.CleanupOAuthSessionsRoutine() + auth.wg.Go(auth.CleanupOAuthSessionsRoutine) return nil } diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 763e26fb..55579607 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -3,6 +3,7 @@ package service import ( "context" "strings" + "sync" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/decoders" @@ -16,6 +17,7 @@ type DockerService struct { log *logger.Logger client *client.Client context context.Context + wg *sync.WaitGroup isConnected bool } @@ -23,10 +25,12 @@ type DockerService struct { func NewDockerService( log *logger.Logger, context context.Context, + wg *sync.WaitGroup, ) *DockerService { return &DockerService{ log: log, context: context, + wg: wg, } } @@ -53,7 +57,7 @@ func (docker *DockerService) Init() error { docker.isConnected = true docker.log.App.Debug().Msg("Docker connected successfully") - go docker.watchAndClose() + docker.wg.Go(docker.watchAndClose) return nil } diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index acba24e4..1af6b4da 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -38,9 +38,9 @@ type ingressApp struct { type KubernetesService struct { log *logger.Logger ctx context.Context + wg *sync.WaitGroup client dynamic.Interface - cancel context.CancelFunc started bool mu sync.RWMutex ingressApps map[ingressKey][]ingressApp @@ -51,10 +51,12 @@ type KubernetesService struct { func NewKubernetesService( log *logger.Logger, context context.Context, + wg *sync.WaitGroup, ) *KubernetesService { return &KubernetesService{ log: log, ctx: context, + wg: wg, ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), @@ -264,8 +266,6 @@ func (k *KubernetesService) Init() error { } k.client = client - k.ctx, k.cancel = context.WithCancel(k.ctx) - gvr := schema.GroupVersionResource{ Group: "networking.k8s.io", Version: "v1", @@ -274,6 +274,7 @@ func (k *KubernetesService) Init() error { accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second) defer accessCancel() + _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) if err != nil { k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") @@ -282,7 +283,9 @@ func (k *KubernetesService) Init() error { } k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") - go k.watchGVR(gvr) + k.wg.Go(func() { + k.watchGVR(gvr) + }) k.started = true k.log.App.Debug().Msg("Kubernetes label provider started successfully") diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index c1a5d187..35d3d887 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -17,6 +17,7 @@ type LdapService struct { log *logger.Logger config model.Config context context.Context + wg *sync.WaitGroup conn *ldapgo.Conn mutex sync.RWMutex @@ -28,11 +29,13 @@ func NewLdapService( log *logger.Logger, config model.Config, context context.Context, + wg *sync.WaitGroup, ) *LdapService { return &LdapService{ log: log, config: config, context: context, + wg: wg, } } @@ -88,7 +91,7 @@ func (ldap *LdapService) Init() error { return fmt.Errorf("failed to connect to LDAP server: %w", err) } - go func() { + ldap.wg.Go(func() { ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ticker := time.NewTicker(5 * time.Minute) @@ -111,7 +114,7 @@ func (ldap *LdapService) Init() error { return } } - }() + }) return nil } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 38101fa7..7d4d8d71 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -16,6 +16,7 @@ import ( "net/url" "os" "strings" + "sync" "time" "slices" @@ -117,6 +118,7 @@ type OIDCService struct { runtime model.RuntimeConfig queries *repository.Queries context context.Context + wg *sync.WaitGroup clients map[string]model.OIDCClientConfig privateKey *rsa.PrivateKey @@ -130,13 +132,15 @@ func NewOIDCService( config model.Config, runtime model.RuntimeConfig, queries *repository.Queries, - context context.Context) *OIDCService { + context context.Context, + wg *sync.WaitGroup) *OIDCService { return &OIDCService{ log: log, config: config, runtime: runtime, queries: queries, context: context, + wg: wg, } } @@ -281,7 +285,7 @@ func (service *OIDCService) Init() error { } // Start cleanup routine - go service.cleanupRoutine() + service.wg.Go(service.cleanupRoutine) return nil } @@ -811,7 +815,7 @@ func (service *OIDCService) cleanupRoutine() { service.log.App.Debug().Msg("Finished OIDC cleanup routine") case <-service.context.Done(): - service.log.App.Debug().Msg("OIDC cleanup routine context cancelled, stopping") + service.log.App.Debug().Msg("Stopping OIDC cleanup routine") return } } From 8c8d56f87c86428e77417901cad05576e4185d2d Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 12:24:10 +0300 Subject: [PATCH 09/22] refactor: simplify middleware, controller and service init --- internal/bootstrap/router_bootstrap.go | 56 ++------ internal/bootstrap/service_bootstrap.go | 42 +----- internal/controller/context_controller.go | 18 ++- internal/controller/health_controller.go | 13 +- internal/controller/oauth_controller.go | 10 +- internal/controller/oidc_controller.go | 24 ++-- internal/controller/proxy_controller.go | 12 +- internal/controller/resources_controller.go | 10 +- internal/controller/user_controller.go | 10 +- internal/controller/well_known_controller.go | 32 +++-- internal/middleware/context_middleware.go | 6 +- internal/middleware/ui_middleware.go | 10 +- internal/middleware/zerolog_middleware.go | 4 - internal/model/context.go | 6 +- internal/service/access_controls_service.go | 18 +-- internal/service/auth_service.go | 23 ++-- internal/service/docker_service.go | 41 +++--- internal/service/kubernetes_service.go | 90 ++++++------- internal/service/ldap_service.go | 66 +++------ internal/service/oauth_broker_service.go | 22 +-- internal/service/oauth_presets.go | 10 +- internal/service/oauth_service.go | 7 +- internal/service/oidc_service.go | 134 +++++++++---------- 23 files changed, 273 insertions(+), 391 deletions(-) diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index ce739fc9..90a20c3b 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -25,18 +25,9 @@ func (app *BootstrapApp) setupRouter() error { } contextMiddleware := middleware.NewContextMiddleware(app.log, app.runtime, app.services.authService, app.services.oauthBrokerService) - - err := contextMiddleware.Init() - - if err != nil { - return fmt.Errorf("failed to initialize context middleware: %w", err) - } - engine.Use(contextMiddleware.Middleware()) - uiMiddleware := middleware.NewUIMiddleware() - - err = uiMiddleware.Init() + uiMiddleware, err := middleware.NewUIMiddleware() if err != nil { return fmt.Errorf("failed to initialize UI middleware: %w", err) @@ -46,47 +37,18 @@ func (app *BootstrapApp) setupRouter() error { zerologMiddleware := middleware.NewZerologMiddleware(app.log) - err = zerologMiddleware.Init() - - if err != nil { - return fmt.Errorf("failed to initialize zerolog middleware: %w", err) - } - engine.Use(zerologMiddleware.Middleware()) apiRouter := engine.Group("/api") - contextController := controller.NewContextController(app.log, app.config, app.runtime, apiRouter) - - contextController.SetupRoutes() - - oauthController := controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) - - oauthController.SetupRoutes() - - oidcController := controller.NewOIDCController(app.log, app.services.oidcService, apiRouter) - - oidcController.SetupRoutes() - - proxyController := controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) - - proxyController.SetupRoutes() - - userController := controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) - - userController.SetupRoutes() - - resourcesController := controller.NewResourcesController(app.config, &engine.RouterGroup) - - resourcesController.SetupRoutes() - - healthController := controller.NewHealthController(apiRouter) - - healthController.SetupRoutes() - - wellknownController := controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) - - wellknownController.SetupRoutes() + controller.NewContextController(app.log, app.config, app.runtime, apiRouter) + controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) + controller.NewOIDCController(app.log, app.services.oidcService, apiRouter) + controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) + controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) + controller.NewResourcesController(app.config, &engine.RouterGroup) + controller.NewHealthController(apiRouter) + controller.NewWellKnownController(app.services.oidcService, &engine.RouterGroup) app.router = engine return nil diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 6692b038..1e850437 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -8,13 +8,10 @@ import ( ) func (app *BootstrapApp) setupServices() error { - ldapService := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) - - err := ldapService.Init() + ldapService, err := service.NewLdapService(app.log, app.config, app.ctx, &app.wg) if err != nil { app.log.App.Warn().Err(err).Msg("Failed to initialize LDAP connection, will continue without it") - ldapService.Unconfigure() } app.services.ldapService = ldapService @@ -27,9 +24,7 @@ func (app *BootstrapApp) setupServices() error { if useKubernetes { app.log.App.Debug().Msg("Using Kubernetes label provider") - kubernetesService := service.NewKubernetesService(app.log, app.ctx, &app.wg) - - err = kubernetesService.Init() + kubernetesService, err := service.NewKubernetesService(app.log, app.ctx, &app.wg) if err != nil { return fmt.Errorf("failed to initialize kubernetes service: %w", err) @@ -40,9 +35,7 @@ func (app *BootstrapApp) setupServices() error { } else { app.log.App.Debug().Msg("Using Docker label provider") - dockerService := service.NewDockerService(app.log, app.ctx, &app.wg) - - err = dockerService.Init() + dockerService, err := service.NewDockerService(app.log, app.ctx, &app.wg) if err != nil { return fmt.Errorf("failed to initialize docker service: %w", err) @@ -52,39 +45,16 @@ func (app *BootstrapApp) setupServices() error { labelProvider = dockerService } - accessControlsService := service.NewAccessControlsService(app.log, labelProvider, app.config.Apps) - - err = accessControlsService.Init() - - if err != nil { - return fmt.Errorf("failed to initialize access controls service: %w", err) - } - + accessControlsService := service.NewAccessControlsService(app.log, &labelProvider, app.config.Apps) app.services.accessControlService = accessControlsService - oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders) - - err = oauthBrokerService.Init() - - if err != nil { - return fmt.Errorf("failed to initialize oauth broker service: %w", err) - } - + oauthBrokerService := service.NewOAuthBrokerService(app.log, app.runtime.OAuthProviders, app.ctx) app.services.oauthBrokerService = oauthBrokerService authService := service.NewAuthService(app.log, app.config, app.runtime, app.ctx, &app.wg, app.services.ldapService, app.queries, app.services.oauthBrokerService) - - err = authService.Init() - - if err != nil { - return fmt.Errorf("failed to initialize auth service: %w", err) - } - app.services.authService = authService - oidcService := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) - - err = oidcService.Init() + oidcService, err := service.NewOIDCService(app.log, app.config, app.runtime, app.queries, app.ctx, &app.wg) if err != nil { return fmt.Errorf("failed to initialize oidc service: %w", err) diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 491cb0b8..22ba0ffd 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -40,7 +40,6 @@ type ContextController struct { log *logger.Logger config model.Config runtime model.RuntimeConfig - router *gin.RouterGroup } func NewContextController( @@ -49,22 +48,21 @@ func NewContextController( runtimeConfig model.RuntimeConfig, router *gin.RouterGroup, ) *ContextController { - if !config.UI.WarningsEnabled { - log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") - } - - return &ContextController{ + controller := &ContextController{ log: log, config: config, runtime: runtimeConfig, - router: router, } -} -func (controller *ContextController) SetupRoutes() { - contextGroup := controller.router.Group("/context") + if !config.UI.WarningsEnabled { + log.App.Warn().Msg("UI warnings are disabled. This may lead to security issues if you are not careful. Make sure to enable warnings in production environments.") + } + + contextGroup := router.Group("/context") contextGroup.GET("/user", controller.userContextHandler) contextGroup.GET("/app", controller.appContextHandler) + + return controller } func (controller *ContextController) userContextHandler(c *gin.Context) { diff --git a/internal/controller/health_controller.go b/internal/controller/health_controller.go index 1b9adbf9..8e84e62b 100644 --- a/internal/controller/health_controller.go +++ b/internal/controller/health_controller.go @@ -3,18 +3,15 @@ package controller import "github.com/gin-gonic/gin" type HealthController struct { - router *gin.RouterGroup } func NewHealthController(router *gin.RouterGroup) *HealthController { - return &HealthController{ - router: router, - } -} + controller := &HealthController{} + + router.GET("/healthz", controller.healthHandler) + router.HEAD("/healthz", controller.healthHandler) -func (controller *HealthController) SetupRoutes() { - controller.router.GET("/healthz", controller.healthHandler) - controller.router.HEAD("/healthz", controller.healthHandler) + return controller } func (controller *HealthController) healthHandler(c *gin.Context) { diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 902ee3de..803a4c04 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -24,7 +24,6 @@ type OAuthController struct { log *logger.Logger config model.Config runtime model.RuntimeConfig - router *gin.RouterGroup auth *service.AuthService } @@ -35,19 +34,18 @@ func NewOAuthController( router *gin.RouterGroup, auth *service.AuthService, ) *OAuthController { - return &OAuthController{ + controller := &OAuthController{ log: log, config: config, runtime: runtimeConfig, - router: router, auth: auth, } -} -func (controller *OAuthController) SetupRoutes() { - oauthGroup := controller.router.Group("/oauth") + oauthGroup := router.Group("/oauth") oauthGroup.GET("/url/:provider", controller.oauthURLHandler) oauthGroup.GET("/callback/:provider", controller.oauthCallbackHandler) + + return controller } func (controller *OAuthController) oauthURLHandler(c *gin.Context) { diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index e5a139c9..7e735159 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -17,9 +17,8 @@ import ( ) type OIDCController struct { - log *logger.Logger - router *gin.RouterGroup - oidc *service.OIDCService + log *logger.Logger + oidc *service.OIDCService } type AuthorizeCallback struct { @@ -60,20 +59,19 @@ func NewOIDCController( log *logger.Logger, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController { - return &OIDCController{ - log: log, - oidc: oidcService, - router: router, + controller := &OIDCController{ + log: log, + oidc: oidcService, } -} -func (controller *OIDCController) SetupRoutes() { - oidcGroup := controller.router.Group("/oidc") + oidcGroup := router.Group("/oidc") oidcGroup.GET("/clients/:id", controller.GetClientInfo) oidcGroup.POST("/authorize", controller.Authorize) oidcGroup.POST("/token", controller.Token) oidcGroup.GET("/userinfo", controller.Userinfo) oidcGroup.POST("/userinfo", controller.Userinfo) + + return controller } func (controller *OIDCController) GetClientInfo(c *gin.Context) { @@ -108,7 +106,7 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) { } func (controller *OIDCController) Authorize(c *gin.Context) { - if !controller.oidc.IsConfigured() { + if controller.oidc == nil { controller.authorizeError(c, errors.New("err_oidc_not_configured"), "OIDC not configured", "This instance is not configured for OIDC", "", "", "") return } @@ -198,7 +196,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) { } func (controller *OIDCController) Token(c *gin.Context) { - if !controller.oidc.IsConfigured() { + if controller.oidc == nil { controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") c.JSON(404, gin.H{ "error": "not_found", @@ -374,7 +372,7 @@ func (controller *OIDCController) Token(c *gin.Context) { } func (controller *OIDCController) Userinfo(c *gin.Context) { - if !controller.oidc.IsConfigured() { + if controller.oidc == nil { controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") c.JSON(404, gin.H{ "error": "not_found", diff --git a/internal/controller/proxy_controller.go b/internal/controller/proxy_controller.go index b4bdc534..40969b83 100644 --- a/internal/controller/proxy_controller.go +++ b/internal/controller/proxy_controller.go @@ -53,7 +53,6 @@ type ProxyContext struct { type ProxyController struct { log *logger.Logger runtime model.RuntimeConfig - router *gin.RouterGroup acls *service.AccessControlsService auth *service.AuthService } @@ -65,18 +64,17 @@ func NewProxyController( acls *service.AccessControlsService, auth *service.AuthService, ) *ProxyController { - return &ProxyController{ + controller := &ProxyController{ log: log, runtime: runtime, - router: router, acls: acls, auth: auth, } -} -func (controller *ProxyController) SetupRoutes() { - proxyGroup := controller.router.Group("/auth") + proxyGroup := router.Group("/auth") proxyGroup.Any("/:proxy", controller.proxyHandler) + + return controller } func (controller *ProxyController) proxyHandler(c *gin.Context) { @@ -160,7 +158,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) { userContext, err := new(model.UserContext).NewFromGin(c) if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") + controller.log.App.Debug().Err(err).Msg("Failed to create user context from request, treating as unauthenticated") userContext = &model.UserContext{ Authenticated: false, } diff --git a/internal/controller/resources_controller.go b/internal/controller/resources_controller.go index b0fa3d70..54af733d 100644 --- a/internal/controller/resources_controller.go +++ b/internal/controller/resources_controller.go @@ -9,7 +9,6 @@ import ( type ResourcesController struct { config model.Config - router *gin.RouterGroup fileServer http.Handler } @@ -19,15 +18,14 @@ func NewResourcesController( ) *ResourcesController { fileServer := http.StripPrefix("/resources", http.FileServer(http.Dir(config.Resources.Path))) - return &ResourcesController{ + controller := &ResourcesController{ config: config, - router: router, fileServer: fileServer, } -} -func (controller *ResourcesController) SetupRoutes() { - controller.router.GET("/resources/*resource", controller.resourcesHandler) + router.GET("/resources/*resource", controller.resourcesHandler) + + return controller } func (controller *ResourcesController) resourcesHandler(c *gin.Context) { diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index b405bb03..f186ec0d 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -28,7 +28,6 @@ type TotpRequest struct { type UserController struct { log *logger.Logger runtime model.RuntimeConfig - router *gin.RouterGroup auth *service.AuthService } @@ -38,19 +37,18 @@ func NewUserController( router *gin.RouterGroup, auth *service.AuthService, ) *UserController { - return &UserController{ + controller := &UserController{ log: log, runtime: runtimeConfig, - router: router, auth: auth, } -} -func (controller *UserController) SetupRoutes() { - userGroup := controller.router.Group("/user") + userGroup := router.Group("/user") userGroup.POST("/login", controller.loginHandler) userGroup.POST("/logout", controller.logoutHandler) userGroup.POST("/totp", controller.totpHandler) + + return controller } func (controller *UserController) loginHandler(c *gin.Context) { diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go index 951fdac2..a00876be 100644 --- a/internal/controller/well_known_controller.go +++ b/internal/controller/well_known_controller.go @@ -27,23 +27,29 @@ type OpenIDConnectConfiguration struct { } type WellKnownController struct { - router *gin.RouterGroup - oidc *service.OIDCService + oidc *service.OIDCService } func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) *WellKnownController { - return &WellKnownController{ - oidc: oidc, - router: router, + controller := &WellKnownController{ + oidc: oidc, } -} -func (controller *WellKnownController) SetupRoutes() { - controller.router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) - controller.router.GET("/.well-known/jwks.json", controller.JWKS) + router.GET("/.well-known/openid-configuration", controller.OpenIDConnectConfiguration) + router.GET("/.well-known/jwks.json", controller.JWKS) + + return controller } func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { + if controller.oidc == nil { + c.JSON(500, gin.H{ + "status": "500", + "message": "OIDC service not configured", + }) + return + } + issuer := controller.oidc.GetIssuer() c.JSON(200, OpenIDConnectConfiguration{ Issuer: issuer, @@ -65,6 +71,14 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context } func (controller *WellKnownController) JWKS(c *gin.Context) { + if controller.oidc == nil { + c.JSON(500, gin.H{ + "status": "500", + "message": "OIDC service not configured", + }) + return + } + jwks, err := controller.oidc.GetJWK() if err != nil { diff --git a/internal/middleware/context_middleware.go b/internal/middleware/context_middleware.go index 211f931c..6e6bbe56 100644 --- a/internal/middleware/context_middleware.go +++ b/internal/middleware/context_middleware.go @@ -56,10 +56,6 @@ func NewContextMiddleware( } } -func (m *ContextMiddleware) Init() error { - return nil -} - func (m *ContextMiddleware) Middleware() gin.HandlerFunc { return func(c *gin.Context) { if m.isIgnorePath(c.Request.Method + " " + c.Request.URL.Path) { @@ -82,7 +78,7 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc { c.Next() return } else { - m.log.App.Error().Msgf("Error authenticating session cookie: %v", err) + m.log.App.Debug().Msgf("Error authenticating session cookie: %v", err) } } diff --git a/internal/middleware/ui_middleware.go b/internal/middleware/ui_middleware.go index 67b05b86..2b8d6b8a 100644 --- a/internal/middleware/ui_middleware.go +++ b/internal/middleware/ui_middleware.go @@ -18,21 +18,19 @@ type UIMiddleware struct { uiFileServer http.Handler } -func NewUIMiddleware() *UIMiddleware { - return &UIMiddleware{} -} +func NewUIMiddleware() (*UIMiddleware, error) { + m := &UIMiddleware{} -func (m *UIMiddleware) Init() error { ui, err := fs.Sub(assets.FrontendAssets, "dist") if err != nil { - return err + return nil, fmt.Errorf("failed to load ui assets: %w", err) } m.uiFs = ui m.uiFileServer = http.FileServerFS(ui) - return nil + return m, nil } func (m *UIMiddleware) Middleware() gin.HandlerFunc { diff --git a/internal/middleware/zerolog_middleware.go b/internal/middleware/zerolog_middleware.go index 070da695..9870a70a 100644 --- a/internal/middleware/zerolog_middleware.go +++ b/internal/middleware/zerolog_middleware.go @@ -27,10 +27,6 @@ func NewZerologMiddleware(log *logger.Logger) *ZerologMiddleware { } } -func (m *ZerologMiddleware) Init() error { - return nil -} - func (m *ZerologMiddleware) logPath(path string) bool { for _, prefix := range loggerSkipPathsPrefix { if strings.HasPrefix(path, prefix) { diff --git a/internal/model/context.go b/internal/model/context.go index 7384ebe8..c459a620 100644 --- a/internal/model/context.go +++ b/internal/model/context.go @@ -8,6 +8,10 @@ import ( "github.com/tinyauthapp/tinyauth/internal/repository" ) +var ( + ErrUserContextNotFound = errors.New("user context not found") +) + type ProviderType int const ( @@ -74,7 +78,7 @@ func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) { userContextValue, exists := ginctx.Get("context") if !exists { - return nil, errors.New("failed to get user context") + return nil, ErrUserContextNotFound } userContext, ok := userContextValue.(*UserContext) diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index 9bfe834d..f6e3cbd2 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -13,13 +13,13 @@ type LabelProviderImpl interface { type AccessControlsService struct { log *logger.Logger - labelProvider LabelProviderImpl + labelProvider *LabelProviderImpl static map[string]model.App } func NewAccessControlsService( log *logger.Logger, - labelProvider LabelProviderImpl, + labelProvider *LabelProviderImpl, static map[string]model.App) *AccessControlsService { return &AccessControlsService{ log: log, @@ -28,10 +28,6 @@ func NewAccessControlsService( } } -func (acls *AccessControlsService) Init() error { - return nil // No initialization needed -} - func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App { var appAcls *model.App for app, config := range acls.static { @@ -59,7 +55,11 @@ func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, return app, nil } - // Fallback to label provider - acls.log.App.Debug().Msg("Using label provider for app") - return acls.labelProvider.GetLabels(domain) + // If we have a label provider configured, try to get ACLs from it + if acls.labelProvider != nil { + return (*acls.labelProvider).GetLabels(domain) + } + + // no labels + return nil, nil } diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index e47d31cb..ed882438 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -77,7 +77,6 @@ type AuthService struct { config model.Config runtime model.RuntimeConfig context context.Context - wg *sync.WaitGroup ldap *LdapService queries *repository.Queries @@ -98,17 +97,16 @@ func NewAuthService( log *logger.Logger, config model.Config, runtime model.RuntimeConfig, - context context.Context, + ctx context.Context, wg *sync.WaitGroup, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService, ) *AuthService { - return &AuthService{ + service := &AuthService{ log: log, runtime: runtime, - context: context, - wg: wg, + context: ctx, config: config, loginAttempts: make(map[string]*LoginAttempt), ldapGroupsCache: make(map[string]*LdapGroupsCache), @@ -117,11 +115,10 @@ func NewAuthService( queries: queries, oauthBroker: oauthBroker, } -} -func (auth *AuthService) Init() error { - auth.wg.Go(auth.CleanupOAuthSessionsRoutine) - return nil + wg.Go(service.CleanupOAuthSessionsRoutine) + + return service } func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { @@ -132,7 +129,7 @@ func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) }, nil } - if auth.ldap.IsConfigured() { + if auth.ldap != nil { userDN, err := auth.ldap.GetUserDN(username) if err != nil { @@ -157,7 +154,7 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str } return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) case model.UserLDAP: - if auth.ldap.IsConfigured() { + if auth.ldap != nil { err := auth.ldap.Bind(search.Username, password) if err != nil { return fmt.Errorf("failed to bind to ldap user: %w", err) @@ -189,7 +186,7 @@ func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { } func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { - if !auth.ldap.IsConfigured() { + if auth.ldap == nil { return nil, errors.New("ldap service not configured") } @@ -459,7 +456,7 @@ func (auth *AuthService) LocalAuthConfigured() bool { } func (auth *AuthService) LDAPAuthConfigured() bool { - return auth.ldap.IsConfigured() + return auth.ldap != nil } func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { diff --git a/internal/service/docker_service.go b/internal/service/docker_service.go index 55579607..9d077c53 100644 --- a/internal/service/docker_service.go +++ b/internal/service/docker_service.go @@ -17,49 +17,42 @@ type DockerService struct { log *logger.Logger client *client.Client context context.Context - wg *sync.WaitGroup isConnected bool } func NewDockerService( log *logger.Logger, - context context.Context, + ctx context.Context, wg *sync.WaitGroup, -) *DockerService { - return &DockerService{ - log: log, - context: context, - wg: wg, - } -} +) (*DockerService, error) { -func (docker *DockerService) Init() error { client, err := client.NewClientWithOpts(client.FromEnv) if err != nil { - return err + return nil, err } - client.NegotiateAPIVersion(docker.context) - - docker.client = client + client.NegotiateAPIVersion(ctx) - _, err = docker.client.Ping(docker.context) + _, err = client.Ping(ctx) if err != nil { - docker.log.App.Debug().Err(err).Msg("Docker not connected") - docker.isConnected = false - docker.client = nil - docker.context = nil - return nil + log.App.Debug().Err(err).Msg("Docker not connected") + return nil, nil + } + + service := &DockerService{ + log: log, + client: client, + context: ctx, } - docker.isConnected = true - docker.log.App.Debug().Msg("Docker connected successfully") + service.isConnected = true + service.log.App.Debug().Msg("Docker connected successfully") - docker.wg.Go(docker.watchAndClose) + wg.Go(service.watchAndClose) - return nil + return service, nil } func (docker *DockerService) getContainers() ([]container.Summary, error) { diff --git a/internal/service/kubernetes_service.go b/internal/service/kubernetes_service.go index 1af6b4da..8976cb54 100644 --- a/internal/service/kubernetes_service.go +++ b/internal/service/kubernetes_service.go @@ -38,7 +38,6 @@ type ingressApp struct { type KubernetesService struct { log *logger.Logger ctx context.Context - wg *sync.WaitGroup client dynamic.Interface started bool @@ -50,17 +49,53 @@ type KubernetesService struct { func NewKubernetesService( log *logger.Logger, - context context.Context, + ctx context.Context, wg *sync.WaitGroup, -) *KubernetesService { - return &KubernetesService{ +) (*KubernetesService, error) { + cfg, err := rest.InClusterConfig() + if err != nil { + return nil, fmt.Errorf("failed to get in-cluster kubernetes config: %w", err) + } + + client, err := dynamic.NewForConfig(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create kubernetes client: %w", err) + } + + gvr := schema.GroupVersionResource{ + Group: "networking.k8s.io", + Version: "v1", + Resource: "ingresses", + } + + accessCtx, accessCancel := context.WithTimeout(ctx, 5*time.Second) + defer accessCancel() + + _, err = client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) + if err != nil { + log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") + return nil, fmt.Errorf("failed to access ingress api: %w", err) + } + + log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") + + service := &KubernetesService{ log: log, - ctx: context, - wg: wg, + ctx: ctx, + client: client, ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), } + + wg.Go(func() { + service.watchGVR(gvr) + }) + + service.started = true + log.App.Debug().Msg("Kubernetes label provider started successfully") + + return service, nil } func (k *KubernetesService) addIngressApps(namespace, name string, apps []ingressApp) { @@ -226,7 +261,7 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { for { select { case <-k.ctx.Done(): - k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Context cancelled, stopping watcher") + k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Shutting down kubernetes watcher") return case <-resyncTicker.C: if err := k.resyncGVR(gvr); err != nil { @@ -251,47 +286,6 @@ func (k *KubernetesService) watchGVR(gvr schema.GroupVersionResource) { } } -func (k *KubernetesService) Init() error { - var cfg *rest.Config - var err error - - cfg, err = rest.InClusterConfig() - if err != nil { - return fmt.Errorf("failed to get in-cluster Kubernetes config: %w", err) - } - - client, err := dynamic.NewForConfig(cfg) - if err != nil { - return fmt.Errorf("failed to create Kubernetes client: %w", err) - } - - k.client = client - gvr := schema.GroupVersionResource{ - Group: "networking.k8s.io", - Version: "v1", - Resource: "ingresses", - } - - accessCtx, accessCancel := context.WithTimeout(k.ctx, 5*time.Second) - defer accessCancel() - - _, err = k.client.Resource(gvr).List(accessCtx, metav1.ListOptions{Limit: 1}) - if err != nil { - k.log.App.Warn().Err(err).Str("api", gvr.GroupVersion().String()).Msg("Failed to access Ingress API, Kubernetes label provider will be disabled") - k.started = false - return nil - } - - k.log.App.Debug().Str("api", gvr.GroupVersion().String()).Msg("Successfully accessed Ingress API, starting watcher") - k.wg.Go(func() { - k.watchGVR(gvr) - }) - - k.started = true - k.log.App.Debug().Msg("Kubernetes label provider started successfully") - return nil -} - func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { if !k.started { k.log.App.Debug().Str("domain", appDomain).Msg("Kubernetes label provider not started, skipping") diff --git a/internal/service/ldap_service.go b/internal/service/ldap_service.go index 35d3d887..9c031206 100644 --- a/internal/service/ldap_service.go +++ b/internal/service/ldap_service.go @@ -17,63 +17,39 @@ type LdapService struct { log *logger.Logger config model.Config context context.Context - wg *sync.WaitGroup - conn *ldapgo.Conn - mutex sync.RWMutex - cert *tls.Certificate - isConfigured bool + conn *ldapgo.Conn + mutex sync.RWMutex + cert *tls.Certificate } func NewLdapService( log *logger.Logger, config model.Config, - context context.Context, + ctx context.Context, wg *sync.WaitGroup, -) *LdapService { - return &LdapService{ +) (*LdapService, error) { + if config.LDAP.Address == "" { + return nil, nil + } + + ldap := &LdapService{ log: log, config: config, - context: context, - wg: wg, + context: ctx, } -} - -func (ldap *LdapService) IsConfigured() bool { - return ldap.isConfigured -} -func (ldap *LdapService) Unconfigure() error { - if !ldap.isConfigured { - return nil - } + // Check whether authentication with client certificate is possible + if config.LDAP.AuthCert != "" && config.LDAP.AuthKey != "" { + cert, err := tls.LoadX509KeyPair(config.LDAP.AuthCert, config.LDAP.AuthKey) - if ldap.conn != nil { - if err := ldap.conn.Close(); err != nil { - return fmt.Errorf("failed to close LDAP connection: %w", err) + if err != nil { + return nil, fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) } - } - - ldap.isConfigured = false - return nil -} -func (ldap *LdapService) Init() error { - if ldap.config.LDAP.Address == "" { - ldap.isConfigured = false - return nil - } + log.App.Info().Msg("LDAP mTLS authentication configured successfully") - ldap.isConfigured = true - - // Check whether authentication with client certificate is possible - if ldap.config.LDAP.AuthCert != "" && ldap.config.LDAP.AuthKey != "" { - cert, err := tls.LoadX509KeyPair(ldap.config.LDAP.AuthCert, ldap.config.LDAP.AuthKey) - if err != nil { - return fmt.Errorf("failed to initialize LDAP with mTLS authentication: %w", err) - } ldap.cert = &cert - ldap.log.App.Info().Msg("LDAP mTLS authentication configured successfully") // TODO: Add optional extra CA certificates, instead of `InsecureSkipVerify` /* @@ -86,12 +62,14 @@ func (ldap *LdapService) Init() error { } */ } + _, err := ldap.connect() + if err != nil { - return fmt.Errorf("failed to connect to LDAP server: %w", err) + return nil, fmt.Errorf("failed to connect to ldap server: %w", err) } - ldap.wg.Go(func() { + wg.Go(func() { ldap.log.App.Debug().Msg("Starting LDAP connection heartbeat routine") ticker := time.NewTicker(5 * time.Minute) @@ -116,7 +94,7 @@ func (ldap *LdapService) Init() error { } }) - return nil + return ldap, nil } func (ldap *LdapService) connect() (*ldapgo.Conn, error) { diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 8d693ad9..fdb5e1e0 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -1,6 +1,8 @@ package service import ( + "context" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/logger" @@ -25,7 +27,7 @@ type OAuthBrokerService struct { configs map[string]model.OAuthServiceConfig } -var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ +var presets = map[string]func(config model.OAuthServiceConfig, ctx context.Context) *OAuthService{ "github": newGitHubOAuthService, "google": newGoogleOAuthService, } @@ -33,25 +35,25 @@ var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ func NewOAuthBrokerService( log *logger.Logger, configs map[string]model.OAuthServiceConfig, + ctx context.Context, ) *OAuthBrokerService { - return &OAuthBrokerService{ + service := &OAuthBrokerService{ log: log, services: make(map[string]OAuthServiceImpl), configs: configs, } -} -func (broker *OAuthBrokerService) Init() error { - for name, cfg := range broker.configs { + for name, cfg := range configs { if presetFunc, exists := presets[name]; exists { - broker.services[name] = presetFunc(cfg) - broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") + service.services[name] = presetFunc(cfg, ctx) + service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") } else { - broker.services[name] = NewOAuthService(cfg, name) - broker.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") + service.services[name] = NewOAuthService(cfg, name, ctx) + service.log.App.Debug().Str("service", name).Msg("Loaded OAuth service from custom config") } } - return nil + + return service } func (broker *OAuthBrokerService) GetConfiguredServices() []string { diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go index ef21fa60..d620d54d 100644 --- a/internal/service/oauth_presets.go +++ b/internal/service/oauth_presets.go @@ -1,23 +1,25 @@ package service import ( + "context" + "github.com/tinyauthapp/tinyauth/internal/model" "golang.org/x/oauth2/endpoints" ) -func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService { +func newGoogleOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { scopes := []string{"openid", "email", "profile"} config.Scopes = scopes config.AuthURL = endpoints.Google.AuthURL config.TokenURL = endpoints.Google.TokenURL config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" - return NewOAuthService(config, "google") + return NewOAuthService(config, "google", ctx) } -func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService { +func newGitHubOAuthService(config model.OAuthServiceConfig, ctx context.Context) *OAuthService { scopes := []string{"read:user", "user:email"} config.Scopes = scopes config.AuthURL = endpoints.GitHub.AuthURL config.TokenURL = endpoints.GitHub.TokenURL - return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor) + return NewOAuthService(config, "github", ctx).WithUserinfoExtractor(githubExtractor) } diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 11b0be9c..0def3143 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -20,7 +20,7 @@ type OAuthService struct { id string } -func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { +func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Context) *OAuthService { httpClient := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ @@ -29,8 +29,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { }, }, } - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + vctx := context.WithValue(ctx, oauth2.HTTPClient, httpClient) return &OAuthService{ serviceCfg: config, @@ -44,7 +43,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { TokenURL: config.TokenURL, }, }, - ctx: ctx, + ctx: vctx, userinfoExtractor: defaultExtractor, id: id, } diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 7d4d8d71..02c33199 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -118,13 +118,11 @@ type OIDCService struct { runtime model.RuntimeConfig queries *repository.Queries context context.Context - wg *sync.WaitGroup - clients map[string]model.OIDCClientConfig - privateKey *rsa.PrivateKey - publicKey crypto.PublicKey - issuer string - isConfigured bool + clients map[string]model.OIDCClientConfig + privateKey *rsa.PrivateKey + publicKey crypto.PublicKey + issuer string } func NewOIDCService( @@ -132,162 +130,156 @@ func NewOIDCService( config model.Config, runtime model.RuntimeConfig, queries *repository.Queries, - context context.Context, - wg *sync.WaitGroup) *OIDCService { - return &OIDCService{ - log: log, - config: config, - runtime: runtime, - queries: queries, - context: context, - wg: wg, - } -} - -func (service *OIDCService) IsConfigured() bool { - return service.isConfigured -} - -func (service *OIDCService) Init() error { + ctx context.Context, + wg *sync.WaitGroup) (*OIDCService, error) { // If not configured, skip init - if len(service.runtime.OIDCClients) == 0 { - service.isConfigured = false - return nil + if len(runtime.OIDCClients) == 0 { + return nil, nil } - service.isConfigured = true - // Ensure issuer is https - uissuer, err := url.Parse(service.runtime.AppURL) + uissuer, err := url.Parse(runtime.AppURL) if err != nil { - return err + return nil, fmt.Errorf("failed to parse app url: %w", err) } if uissuer.Scheme != "https" { - return errors.New("issuer must be https") + return nil, errors.New("issuer must be https") } - service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) + issuer := fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host) // Create/load private and public keys - if strings.TrimSpace(service.config.OIDC.PrivateKeyPath) == "" || - strings.TrimSpace(service.config.OIDC.PublicKeyPath) == "" { - return errors.New("private key path and public key path are required") + if strings.TrimSpace(config.OIDC.PrivateKeyPath) == "" || + strings.TrimSpace(config.OIDC.PublicKeyPath) == "" { + return nil, errors.New("private key path and public key path are required") } var privateKey *rsa.PrivateKey - fprivateKey, err := os.ReadFile(service.config.OIDC.PrivateKeyPath) + fprivateKey, err := os.ReadFile(config.OIDC.PrivateKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { - return err + return nil, err } if errors.Is(err, os.ErrNotExist) { privateKey, err = rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return err + return nil, fmt.Errorf("failed to generate private key: %w", err) } der := x509.MarshalPKCS1PrivateKey(privateKey) if der == nil { - return errors.New("failed to marshal private key") + return nil, errors.New("failed to marshal private key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: der, }) - service.log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") - err = os.WriteFile(service.config.OIDC.PrivateKeyPath, encoded, 0600) + log.App.Trace().Str("type", "RSA PRIVATE KEY").Msg("Generated private RSA key") + err = os.WriteFile(config.OIDC.PrivateKeyPath, encoded, 0600) if err != nil { - return err + return nil, fmt.Errorf("failed to write private key to file: %w", err) } - service.privateKey = privateKey } else { block, _ := pem.Decode(fprivateKey) if block == nil { - return errors.New("failed to decode private key") + return nil, errors.New("failed to decode private key") } - service.log.App.Trace().Str("type", block.Type).Msg("Loaded private key") + log.App.Trace().Str("type", block.Type).Msg("Loaded private key") privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse private key: %w", err) } - service.privateKey = privateKey } - fpublicKey, err := os.ReadFile(service.config.OIDC.PublicKeyPath) + var publicKey crypto.PublicKey + + fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath) if err != nil && !errors.Is(err, os.ErrNotExist) { - return err + return nil, fmt.Errorf("failed to read public key: %w", err) } if errors.Is(err, os.ErrNotExist) { - publicKey := service.privateKey.Public() + publicKey = privateKey.Public() der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey)) if der == nil { - return errors.New("failed to marshal public key") + return nil, errors.New("failed to marshal public key") } encoded := pem.EncodeToMemory(&pem.Block{ Type: "RSA PUBLIC KEY", Bytes: der, }) - service.log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") - err = os.WriteFile(service.config.OIDC.PublicKeyPath, encoded, 0644) + log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key") + err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644) if err != nil { - return err + return nil, err } - service.publicKey = publicKey } else { block, _ := pem.Decode(fpublicKey) if block == nil { - return errors.New("failed to decode public key") + return nil, errors.New("failed to decode public key") } - service.log.App.Trace().Str("type", block.Type).Msg("Loaded public key") + log.App.Trace().Str("type", block.Type).Msg("Loaded public key") switch block.Type { case "RSA PUBLIC KEY": - publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) + publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - service.publicKey = publicKey case "PUBLIC KEY": publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) if err != nil { - return err + return nil, fmt.Errorf("failed to parse public key: %w", err) } - service.publicKey = publicKey.(crypto.PublicKey) + publicKey = publicKey.(crypto.PublicKey) default: - return fmt.Errorf("unsupported public key type: %s", block.Type) + return nil, fmt.Errorf("unsupported public key type: %s", block.Type) } } // We will reorganize the client into a map with the client ID as the key - service.clients = make(map[string]model.OIDCClientConfig) + clients := make(map[string]model.OIDCClientConfig) - for id, client := range service.config.OIDC.Clients { + for id, client := range config.OIDC.Clients { client.ID = id if client.Name == "" { client.Name = utils.Capitalize(client.ID) } - service.clients[client.ClientID] = client + clients[client.ClientID] = client } // Load the client secrets from files if they exist - for id, client := range service.clients { + for id, client := range clients { secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile) if secret != "" { client.ClientSecret = secret } client.ClientSecretFile = "" - service.clients[id] = client - service.log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") + clients[id] = client + log.App.Debug().Str("clientId", client.ClientID).Msg("Loaded OIDC client configuration") + } + + // Initialize the service + service := &OIDCService{ + log: log, + config: config, + runtime: runtime, + queries: queries, + context: ctx, + + clients: clients, + privateKey: privateKey, + publicKey: publicKey, + issuer: issuer, } // Start cleanup routine - service.wg.Go(service.cleanupRoutine) + wg.Go(service.cleanupRoutine) - return nil + return service, nil } func (service *OIDCService) GetIssuer() string { From 9fccb630977dc1d19818291ef260eb04d72e1da9 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 13:17:35 +0300 Subject: [PATCH 10/22] tests: fix controller tests --- internal/bootstrap/db_bootstrap.go | 4 + internal/controller/context_controller.go | 2 +- .../controller/context_controller_test.go | 46 +++----- internal/controller/controller_test.go | 106 ++++++++++++++++++ internal/controller/health_controller_test.go | 7 +- internal/controller/oidc_controller_test.go | 44 +++----- internal/controller/proxy_controller_test.go | 67 +++-------- .../controller/resources_controller_test.go | 22 ++-- internal/controller/user_controller_test.go | 81 +++---------- .../controller/well_known_controller_test.go | 56 ++++----- 10 files changed, 202 insertions(+), 233 deletions(-) create mode 100644 internal/controller/controller_test.go diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 4644036b..f9554ddc 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -56,3 +56,7 @@ func (app *BootstrapApp) SetupDatabase() error { app.db = db return nil } + +func (app *BootstrapApp) GetDB() *sql.DB { + return app.db +} diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 22ba0ffd..8d9f5fa2 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -95,7 +95,7 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { } func (controller *ContextController) appContextHandler(c *gin.Context) { - appUrl, err := url.Parse(controller.config.AppURL) + appUrl, err := url.Parse(controller.runtime.AppURL) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to parse app URL") diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 12a8e22b..162fd166 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -11,27 +11,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestContextController(t *testing.T) { - tlog.NewTestLogger().Init() - controllerConfig := controller.ContextControllerConfig{ - Providers: []controller.Provider{ - { - Name: "Local", - ID: "local", - OAuth: false, - }, - }, - Title: "Tinyauth", - AppURL: "https://tinyauth.example.com", - CookieDomain: "example.com", - ForgotPasswordMessage: "foo", - BackgroundImage: "/background.jpg", - OAuthAutoRedirect: "none", - WarningsEnabled: true, - } + log := logger.NewLogger().WithTestConfig() + log.Init() + + cfg, runtime := createTestConfigs(t) tests := []struct { description string @@ -47,14 +34,14 @@ func TestContextController(t *testing.T) { expectedAppContextResponse := controller.AppContextResponse{ Status: 200, Message: "Success", - Providers: controllerConfig.Providers, - Title: controllerConfig.Title, - AppURL: controllerConfig.AppURL, - CookieDomain: controllerConfig.CookieDomain, - ForgotPasswordMessage: controllerConfig.ForgotPasswordMessage, - BackgroundImage: controllerConfig.BackgroundImage, - OAuthAutoRedirect: controllerConfig.OAuthAutoRedirect, - WarningsEnabled: controllerConfig.WarningsEnabled, + Providers: runtime.ConfiguredProviders, + Title: cfg.UI.Title, + AppURL: runtime.AppURL, + CookieDomain: runtime.CookieDomain, + ForgotPasswordMessage: cfg.UI.ForgotPasswordMessage, + BackgroundImage: cfg.UI.BackgroundImage, + OAuthAutoRedirect: cfg.OAuth.AutoRedirect, + WarningsEnabled: cfg.UI.WarningsEnabled, } bytes, err := json.Marshal(expectedAppContextResponse) assert.NoError(t, err) @@ -86,7 +73,7 @@ func TestContextController(t *testing.T) { BaseContext: model.BaseContext{ Username: "johndoe", Name: "John Doe", - Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), + Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), }, }, }) @@ -100,7 +87,7 @@ func TestContextController(t *testing.T) { IsLoggedIn: true, Username: "johndoe", Name: "John Doe", - Email: utils.CompileUserEmail("johndoe", controllerConfig.CookieDomain), + Email: utils.CompileUserEmail("johndoe", runtime.CookieDomain), Provider: "local", } bytes, err := json.Marshal(expectedUserContextResponse) @@ -121,8 +108,7 @@ func TestContextController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - contextController := controller.NewContextController(controllerConfig, group) - contextController.SetupRoutes() + controller.NewContextController(log, cfg, runtime, group) recorder := httptest.NewRecorder() diff --git a/internal/controller/controller_test.go b/internal/controller/controller_test.go new file mode 100644 index 00000000..675f345f --- /dev/null +++ b/internal/controller/controller_test.go @@ -0,0 +1,106 @@ +package controller_test + +import ( + "path" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "golang.org/x/crypto/bcrypt" +) + +var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK" + +func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { + tempDir := t.TempDir() + + config := model.Config{ + UI: model.UIConfig{ + Title: "Tinyauth Test", + ForgotPasswordMessage: "foo", + BackgroundImage: "/background.jpg", + WarningsEnabled: true, + }, + OAuth: model.OAuthConfig{ + AutoRedirect: "none", + }, + OIDC: model.OIDCConfig{ + Clients: map[string]model.OIDCClientConfig{ + "test": { + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + TrustedRedirectURIs: []string{"https://test.example.com/callback"}, + Name: "Test Client", + }, + }, + PrivateKeyPath: path.Join(tempDir, "key.pem"), + PublicKeyPath: path.Join(tempDir, "key.pub"), + }, + Auth: model.AuthConfig{ + SessionExpiry: 10, + LoginTimeout: 10, + LoginMaxRetries: 3, + }, + Database: model.DatabaseConfig{ + Path: path.Join(tempDir, "test.db"), + }, + Resources: model.ResourcesConfig{ + Enabled: true, + Path: path.Join(tempDir, "resources"), + }, + } + + passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + require.NoError(t, err) + + runtime := model.RuntimeConfig{ + ConfiguredProviders: []model.Provider{ + { + Name: "Local", + ID: "local", + OAuth: false, + }, + }, + LocalUsers: []model.LocalUser{ + { + Username: "testuser", + Password: string(passwd), + }, + { + Username: "totpuser", + Password: string(passwd), + TOTPSecret: testingTOTPSecret, + }, + { + Username: "attruser", + Password: string(passwd), + Attributes: model.UserAttributes{ + Name: "Alice Smith", + Email: "alice@example.com", + }, + }, + { + Username: "attrtotpuser", + Password: string(passwd), + TOTPSecret: testingTOTPSecret, + Attributes: model.UserAttributes{ + Name: "Bob Jones", + Email: "bob@example.com", + }, + }, + }, + CookieDomain: "example.com", + AppURL: "https://tinyauth.example.com", + SessionCookieName: "tinyauth-session", + OIDCClients: func() []model.OIDCClientConfig { + var clients []model.OIDCClientConfig + for id, client := range config.OIDC.Clients { + client.ID = id + clients = append(clients, client) + } + return clients + }(), + } + + return config, runtime +} diff --git a/internal/controller/health_controller_test.go b/internal/controller/health_controller_test.go index d1bed3b6..2d8ef349 100644 --- a/internal/controller/health_controller_test.go +++ b/internal/controller/health_controller_test.go @@ -7,13 +7,11 @@ import ( "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" + "github.com/tinyauthapp/tinyauth/internal/controller" ) func TestHealthController(t *testing.T) { - tlog.NewTestLogger().Init() tests := []struct { description string path string @@ -56,8 +54,7 @@ func TestHealthController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - healthController := controller.NewHealthController(group) - healthController.SetupRoutes() + controller.NewHealthController(group) recorder := httptest.NewRecorder() diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 150540fc..9d552bbe 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -1,13 +1,14 @@ package controller_test import ( + "context" "crypto/sha256" "encoding/base64" "encoding/json" "net/http/httptest" "net/url" - "path" "strings" + "sync" "testing" "github.com/gin-gonic/gin" @@ -19,29 +20,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestOIDCController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - oidcServiceCfg := service.OIDCServiceConfig{ - Clients: map[string]model.OIDCClientConfig{ - "test": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - TrustedRedirectURIs: []string{"https://test.example.com/callback"}, - Name: "Test Client", - }, - }, - PrivateKeyPath: path.Join(tempDir, "key.pem"), - PublicKeyPath: path.Join(tempDir, "key.pub"), - Issuer: "https://tinyauth.example.com", - SessionExpiry: 500, - } - - controllerCfg := controller.OIDCControllerConfig{} + cfg, runtime := createTestConfigs(t) simpleCtx := func(c *gin.Context) { c.Set("context", &model.UserContext{ @@ -852,14 +838,16 @@ func TestOIDCController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(model.Config{}) + app := bootstrap.NewBootstrapApp(cfg) - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) - err = oidcService.Init() + queries := repository.New(app.GetDB()) + + wg := &sync.WaitGroup{} + + oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, context.TODO(), wg) require.NoError(t, err) for _, test := range tests { @@ -873,8 +861,7 @@ func TestOIDCController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - oidcController := controller.NewOIDCController(controllerCfg, oidcService, group) - oidcController.SetupRoutes() + controller.NewOIDCController(log, oidcService, group) recorder := httptest.NewRecorder() @@ -883,7 +870,6 @@ func TestOIDCController(t *testing.T) { } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 7b2e3202..c1603d14 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -1,8 +1,9 @@ package controller_test import ( + "context" "net/http/httptest" - "path" + "sync" "testing" "github.com/gin-gonic/gin" @@ -13,35 +14,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestProxyController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() + log := logger.NewLogger().WithTestConfig() + log.Init() - authServiceCfg := service.AuthServiceConfig{ - LocalUsers: &[]model.LocalUser{ - { - Username: "testuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - }, - { - Username: "totpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - }, - }, - SessionExpiry: 10, // 10 seconds, useful for testing - CookieDomain: "example.com", - LoginTimeout: 10, // 10 seconds, useful for testing - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - } - - controllerCfg := controller.ProxyControllerConfig{ - AppURL: "https://tinyauth.example.com", - } + cfg, runtime := createTestConfigs(t) acls := map[string]model.App{ "app_path_allow": { @@ -398,32 +378,19 @@ func TestProxyController(t *testing.T) { }, } - oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - - app := bootstrap.NewBootstrapApp(model.Config{}) + app := bootstrap.NewBootstrapApp(cfg) - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) - docker := service.NewDockerService() - err = docker.Init() - require.NoError(t, err) - - ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() - require.NoError(t, err) - - broker := service.NewOAuthBrokerService(oauthBrokerCfgs) - err = broker.Init() - require.NoError(t, err) - - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) - err = authService.Init() - require.NoError(t, err) + wg := &sync.WaitGroup{} + ctx := context.TODO() - aclsService := service.NewAccessControlsService(docker, acls) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) + aclsService := service.NewAccessControlsService(log, nil, acls) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -438,15 +405,13 @@ func TestProxyController(t *testing.T) { recorder := httptest.NewRecorder() - proxyController := controller.NewProxyController(controllerCfg, group, aclsService, authService) - proxyController.SetupRoutes() + controller.NewProxyController(log, runtime, group, aclsService, authService) test.run(t, router, recorder) }) } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go index a1996be3..8c8554d3 100644 --- a/internal/controller/resources_controller_test.go +++ b/internal/controller/resources_controller_test.go @@ -3,26 +3,19 @@ package controller_test import ( "net/http/httptest" "os" - "path" + "path/filepath" "testing" "github.com/gin-gonic/gin" - "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/controller" ) func TestResourcesController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() - - resourcesControllerCfg := controller.ResourcesControllerConfig{ - Path: path.Join(tempDir, "resources"), - Enabled: true, - } + cfg, _ := createTestConfigs(t) - err := os.Mkdir(resourcesControllerCfg.Path, 0777) + err := os.MkdirAll(cfg.Resources.Path, 0777) require.NoError(t, err) type testCase struct { @@ -61,11 +54,11 @@ func TestResourcesController(t *testing.T) { }, } - testFilePath := resourcesControllerCfg.Path + "/testfile.txt" + testFilePath := cfg.Resources.Path + "/testfile.txt" err = os.WriteFile(testFilePath, []byte("This is a test file."), 0777) require.NoError(t, err) - testFilePathParent := tempDir + "/somefile.txt" + testFilePathParent := filepath.Dir(cfg.Resources.Path) + "/somefile.txt" err = os.WriteFile(testFilePathParent, []byte("This file should not be accessible."), 0777) require.NoError(t, err) @@ -75,8 +68,7 @@ func TestResourcesController(t *testing.T) { group := router.Group("/") gin.SetMode(gin.TestMode) - resourcesController := controller.NewResourcesController(resourcesControllerCfg, group) - resourcesController.SetupRoutes() + controller.NewResourcesController(cfg, group) recorder := httptest.NewRecorder() test.run(t, router, recorder) diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 4863c16e..bfe232fc 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -5,8 +5,8 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "path" "strings" + "sync" "testing" "time" @@ -19,53 +19,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestUserController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() - - authServiceCfg := service.AuthServiceConfig{ - LocalUsers: &[]model.LocalUser{ - { - Username: "testuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - }, - { - Username: "totpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - }, - { - Username: "attruser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - Attributes: model.UserAttributes{ - Name: "Alice Smith", - Email: "alice@example.com", - }, - }, - { - Username: "attrtotpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - Attributes: model.UserAttributes{ - Name: "Bob Jones", - Email: "bob@example.com", - }, - }, - }, - SessionExpiry: 10, // 10 seconds, useful for testing - CookieDomain: "example.com", - LoginTimeout: 10, // 10 seconds, useful for testing - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - } + log := logger.NewLogger().WithTestConfig() + log.Init() - userControllerCfg := controller.UserControllerConfig{ - CookieDomain: "example.com", - SessionCookieName: "tinyauth-session", - } + cfg, runtime := createTestConfigs(t) totpCtx := func(c *gin.Context) { c.Set("context", &model.UserContext{ @@ -111,14 +72,12 @@ func TestUserController(t *testing.T) { }) } - oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - - app := bootstrap.NewBootstrapApp(model.Config{}) + app := bootstrap.NewBootstrapApp(cfg) - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) type testCase struct { description string @@ -456,21 +415,11 @@ func TestUserController(t *testing.T) { }, } - docker := service.NewDockerService() - err = docker.Init() - require.NoError(t, err) - - ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() - require.NoError(t, err) - - broker := service.NewOAuthBrokerService(oauthBrokerCfgs) - err = broker.Init() - require.NoError(t, err) + ctx := context.TODO() + wg := &sync.WaitGroup{} - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) - err = authService.Init() - require.NoError(t, err) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) beforeEach := func() { // Clear failed login attempts before each test @@ -489,8 +438,7 @@ func TestUserController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - userController := controller.NewUserController(userControllerCfg, group, authService) - userController.SetupRoutes() + controller.NewUserController(log, runtime, group, authService) recorder := httptest.NewRecorder() @@ -499,7 +447,6 @@ func TestUserController(t *testing.T) { } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 7dcf2bdc..9d6c7483 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -1,10 +1,11 @@ package controller_test import ( + "context" "encoding/json" "fmt" "net/http/httptest" - "path" + "sync" "testing" "github.com/gin-gonic/gin" @@ -12,30 +13,16 @@ import ( "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" - "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestWellKnownController(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() - - oidcServiceCfg := service.OIDCServiceConfig{ - Clients: map[string]model.OIDCClientConfig{ - "test": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - TrustedRedirectURIs: []string{"https://test.example.com/callback"}, - Name: "Test Client", - }, - }, - PrivateKeyPath: path.Join(tempDir, "key.pem"), - PublicKeyPath: path.Join(tempDir, "key.pub"), - Issuer: "https://tinyauth.example.com", - SessionExpiry: 500, - } + log := logger.NewLogger().WithTestConfig() + log.Init() + + cfg, runtime := createTestConfigs(t) type testCase struct { description string @@ -56,11 +43,11 @@ func TestWellKnownController(t *testing.T) { assert.NoError(t, err) expected := controller.OpenIDConnectConfiguration{ - Issuer: oidcServiceCfg.Issuer, - AuthorizationEndpoint: fmt.Sprintf("%s/authorize", oidcServiceCfg.Issuer), - TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", oidcServiceCfg.Issuer), - UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", oidcServiceCfg.Issuer), - JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", oidcServiceCfg.Issuer), + Issuer: runtime.AppURL, + AuthorizationEndpoint: fmt.Sprintf("%s/authorize", runtime.AppURL), + TokenEndpoint: fmt.Sprintf("%s/api/oidc/token", runtime.AppURL), + UserinfoEndpoint: fmt.Sprintf("%s/api/oidc/userinfo", runtime.AppURL), + JwksUri: fmt.Sprintf("%s/.well-known/jwks.json", runtime.AppURL), ScopesSupported: service.SupportedScopes, ResponseTypesSupported: service.SupportedResponseTypes, GrantTypesSupported: service.SupportedGrantTypes, @@ -101,16 +88,17 @@ func TestWellKnownController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(model.Config{}) + ctx := context.TODO() + wg := &sync.WaitGroup{} + + app := bootstrap.NewBootstrapApp(cfg) - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + err := app.SetupDatabase() require.NoError(t, err) - queries := repository.New(db) + queries := repository.New(app.GetDB()) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) - err = oidcService.Init() - require.NoError(t, err) + oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg) for _, test := range tests { t.Run(test.description, func(t *testing.T) { @@ -119,15 +107,13 @@ func TestWellKnownController(t *testing.T) { recorder := httptest.NewRecorder() - wellKnownController := controller.NewWellKnownController(controller.WellKnownControllerConfig{}, oidcService, router) - wellKnownController.SetupRoutes() + controller.NewWellKnownController(oidcService, &router.RouterGroup) test.run(t, router, recorder) }) } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } From c7e9fade039b07e1e83daa688b937d5b73bf8162 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 13:28:22 +0300 Subject: [PATCH 11/22] tests: use require instead of assert where previous step is required --- .../controller/context_controller_test.go | 9 +- internal/controller/health_controller_test.go | 7 +- internal/controller/oidc_controller_test.go | 100 +++++++++--------- internal/controller/user_controller_test.go | 28 ++--- 4 files changed, 73 insertions(+), 71 deletions(-) diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 162fd166..4d65e8a5 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils" @@ -44,7 +45,7 @@ func TestContextController(t *testing.T) { WarningsEnabled: cfg.UI.WarningsEnabled, } bytes, err := json.Marshal(expectedAppContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -58,7 +59,7 @@ func TestContextController(t *testing.T) { Message: "Unauthorized", } bytes, err := json.Marshal(expectedUserContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -91,7 +92,7 @@ func TestContextController(t *testing.T) { Provider: "local", } bytes, err := json.Marshal(expectedUserContextResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -113,7 +114,7 @@ func TestContextController(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest("GET", test.path, nil) - assert.NoError(t, err) + require.NoError(t, err) router.ServeHTTP(recorder, request) diff --git a/internal/controller/health_controller_test.go b/internal/controller/health_controller_test.go index 2d8ef349..7576d518 100644 --- a/internal/controller/health_controller_test.go +++ b/internal/controller/health_controller_test.go @@ -8,6 +8,7 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/controller" ) @@ -28,7 +29,7 @@ func TestHealthController(t *testing.T) { "message": "Healthy", } bytes, err := json.Marshal(expectedHealthResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -42,7 +43,7 @@ func TestHealthController(t *testing.T) { "message": "Healthy", } bytes, err := json.Marshal(expectedHealthResponse) - assert.NoError(t, err) + require.NoError(t, err) return string(bytes) }(), }, @@ -59,7 +60,7 @@ func TestHealthController(t *testing.T) { recorder := httptest.NewRecorder() request, err := http.NewRequest(test.method, test.path, nil) - assert.NoError(t, err) + require.NoError(t, err) router.ServeHTTP(recorder, request) diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 9d552bbe..59cabf8b 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -89,7 +89,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") }, @@ -109,7 +109,7 @@ func TestOIDCController(t *testing.T) { Nonce: "some-nonce", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -117,7 +117,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") }, @@ -137,7 +137,7 @@ func TestOIDCController(t *testing.T) { Nonce: "some-nonce", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -146,11 +146,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -169,7 +169,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -177,7 +177,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, res["error"], "unsupported_grant_type") }, @@ -192,7 +192,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -230,7 +230,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -253,11 +253,11 @@ func TestOIDCController(t *testing.T) { var authorizeRes map[string]any err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := authorizeRes["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -269,7 +269,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -292,7 +292,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := tokenRes["refresh_token"] assert.True(t, ok, "Expected refresh token in response") @@ -306,7 +306,7 @@ func TestOIDCController(t *testing.T) { ClientSecret: "some-client-secret", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -318,7 +318,7 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 200, recorder.Code) var refreshRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok = refreshRes["access_token"] assert.True(t, ok, "Expected access token in refresh response") @@ -339,11 +339,11 @@ func TestOIDCController(t *testing.T) { var authorizeRes map[string]any err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := authorizeRes["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -355,7 +355,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -375,7 +375,7 @@ func TestOIDCController(t *testing.T) { var secondRes map[string]any err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", secondRes["error"]) }, @@ -403,7 +403,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) accessToken := tokenRes["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -415,7 +415,7 @@ func TestOIDCController(t *testing.T) { var userInfoRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := userInfoRes["sub"] assert.True(t, ok, "Expected sub claim in userinfo response") @@ -435,7 +435,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -450,7 +450,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -465,7 +465,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -480,7 +480,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", res["error"]) }, }, @@ -495,7 +495,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -510,7 +510,7 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_request", res["error"]) }, }, @@ -527,7 +527,7 @@ func TestOIDCController(t *testing.T) { var tokenRes map[string]any err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) - assert.NoError(t, err) + require.NoError(t, err) accessToken := tokenRes["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -541,7 +541,7 @@ func TestOIDCController(t *testing.T) { var userInfoRes map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) - assert.NoError(t, err) + require.NoError(t, err) _, ok := userInfoRes["sub"] assert.True(t, ok, "Expected sub claim in userinfo response") @@ -565,7 +565,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -574,11 +574,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -595,7 +595,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -626,7 +626,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "S256", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -635,11 +635,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -656,7 +656,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -687,7 +687,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "S256", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -696,11 +696,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() assert.Equal(t, queryParams.Get("state"), "some-state") @@ -717,7 +717,7 @@ func TestOIDCController(t *testing.T) { CodeVerifier: "some-challenge-1", } reqBodyEncoded, err := query.Values(tokenReqBody) - assert.NoError(t, err) + require.NoError(t, err) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -748,7 +748,7 @@ func TestOIDCController(t *testing.T) { CodeChallengeMethod: "foo", } reqBodyBytes, err := json.Marshal(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req.Header.Set("Content-Type", "application/json") @@ -757,11 +757,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() error := queryParams.Get("error") @@ -780,11 +780,11 @@ func TestOIDCController(t *testing.T) { var res map[string]any err := json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) redirectURI := res["redirect_uri"].(string) url, err := url.Parse(redirectURI) - assert.NoError(t, err) + require.NoError(t, err) queryParams := url.Query() code := queryParams.Get("code") @@ -796,7 +796,7 @@ func TestOIDCController(t *testing.T) { RedirectURI: "https://test.example.com/callback", } reqBodyEncoded, err := query.Values(reqBody) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -807,7 +807,7 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 200, recorder.Code) err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) accessToken := res["access_token"].(string) assert.NotEmpty(t, accessToken) @@ -832,7 +832,7 @@ func TestOIDCController(t *testing.T) { assert.Equal(t, 401, recorder.Code) err = json.Unmarshal(recorder.Body.Bytes(), &res) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "invalid_grant", res["error"]) }, }, diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index bfe232fc..e834a8b5 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -95,7 +95,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -103,7 +103,7 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) cookie := recorder.Result().Cookies()[0] assert.Equal(t, "tinyauth-session", cookie.Name) @@ -123,7 +123,7 @@ func TestUserController(t *testing.T) { Password: "wrongpassword", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -144,7 +144,7 @@ func TestUserController(t *testing.T) { Password: "wrongpassword", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) for range 3 { recorder := httptest.NewRecorder() @@ -179,7 +179,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -190,12 +190,12 @@ func TestUserController(t *testing.T) { decodedBody := make(map[string]any) err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, decodedBody["totpPending"], true) // should set the session cookie - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) cookie := recorder.Result().Cookies()[0] assert.Equal(t, "tinyauth-session", cookie.Name) assert.True(t, cookie.HttpOnly) @@ -216,7 +216,7 @@ func TestUserController(t *testing.T) { Password: "password", } loginReqBody, err := json.Marshal(loginReq) - assert.NoError(t, err) + require.NoError(t, err) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req.Header.Set("Content-Type", "application/json") @@ -225,7 +225,7 @@ func TestUserController(t *testing.T) { assert.Equal(t, 200, recorder.Code) cookies := recorder.Result().Cookies() - assert.Len(t, cookies, 1) + require.Len(t, cookies, 1) cookie := cookies[0] assert.Equal(t, "tinyauth-session", cookie.Name) @@ -239,7 +239,7 @@ func TestUserController(t *testing.T) { assert.Equal(t, 200, recorder.Code) cookies = recorder.Result().Cookies() - assert.Len(t, cookies, 1) + require.Len(t, cookies, 1) cookie = cookies[0] assert.Equal(t, "tinyauth-session", cookie.Name) @@ -266,14 +266,14 @@ func TestUserController(t *testing.T) { require.NoError(t, err) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) - assert.NoError(t, err) + require.NoError(t, err) totpReq := controller.TotpRequest{ Code: code, } totpReqBody, err := json.Marshal(totpReq) - assert.NoError(t, err) + require.NoError(t, err) recorder = httptest.NewRecorder() req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) @@ -288,7 +288,7 @@ func TestUserController(t *testing.T) { router.ServeHTTP(recorder, req) assert.Equal(t, 200, recorder.Code) - assert.Len(t, recorder.Result().Cookies(), 1) + require.Len(t, recorder.Result().Cookies(), 1) // should set a new session cookie with totp pending removed totpCookie := recorder.Result().Cookies()[0] @@ -311,7 +311,7 @@ func TestUserController(t *testing.T) { } totpReqBody, err := json.Marshal(totpReq) - assert.NoError(t, err) + require.NoError(t, err) recorder = httptest.NewRecorder() req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) From a76141a99dd5d78babbc68b2b7d43c90862396d5 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 13:32:08 +0300 Subject: [PATCH 12/22] tests: fix middleware tests --- .../middleware/context_middleware_test.go | 61 +++------- internal/middleware/middleware_test.go | 108 ++++++++++++++++++ 2 files changed, 122 insertions(+), 47 deletions(-) create mode 100644 internal/middleware/middleware_test.go diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 5dfde3b4..167c200e 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -5,7 +5,7 @@ import ( "encoding/base64" "net/http" "net/http/httptest" - "path" + "sync" "testing" "time" @@ -17,36 +17,14 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" - "github.com/tinyauthapp/tinyauth/internal/utils/tlog" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestContextMiddleware(t *testing.T) { - tlog.NewTestLogger().Init() - tempDir := t.TempDir() - - authServiceCfg := service.AuthServiceConfig{ - LocalUsers: &[]model.LocalUser{ - { - Username: "testuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - }, - { - Username: "totpuser", - Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password - TOTPSecret: "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", - }, - }, - SessionExpiry: 10, // 10 seconds, useful for testing - CookieDomain: "example.com", - LoginTimeout: 10, // 10 seconds, useful for testing - LoginMaxRetries: 3, - SessionCookieName: "tinyauth-session", - } + log := logger.NewLogger().WithTestConfig() + log.Init() - middlewareCfg := middleware.ContextMiddlewareConfig{ - CookieDomain: "example.com", - SessionCookieName: "tinyauth-session", - } + cfg, runtime := createTestConfigs(t) basicAuthHeader := func(username, password string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) @@ -270,30 +248,20 @@ func TestContextMiddleware(t *testing.T) { }, } - oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - - app := bootstrap.NewBootstrapApp(model.Config{}) + ctx := context.TODO() + wg := &sync.WaitGroup{} - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) + app := bootstrap.NewBootstrapApp(cfg) - ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() + err := app.SetupDatabase() require.NoError(t, err) - broker := service.NewOAuthBrokerService(oauthBrokerCfgs) - err = broker.Init() - require.NoError(t, err) + queries := repository.New(app.GetDB()) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) - err = authService.Init() - require.NoError(t, err) + broker := service.NewOAuthBrokerService(log, map[string]model.OAuthServiceConfig{}, ctx) + authService := service.NewAuthService(log, cfg, runtime, ctx, wg, nil, queries, broker) - contextMiddleware := middleware.NewContextMiddleware(middlewareCfg, authService, broker) - err = contextMiddleware.Init() - require.NoError(t, err) + contextMiddleware := middleware.NewContextMiddleware(log, runtime, authService, broker) for _, test := range tests { authService.ClearRateLimitsTestingOnly() @@ -322,7 +290,6 @@ func TestContextMiddleware(t *testing.T) { } t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) + app.GetDB().Close() }) } diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go new file mode 100644 index 00000000..c00bf4e0 --- /dev/null +++ b/internal/middleware/middleware_test.go @@ -0,0 +1,108 @@ +package middleware_test + +import ( + "path" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tinyauthapp/tinyauth/internal/model" + "golang.org/x/crypto/bcrypt" +) + +// Note: This code is duplicated from controller_test.go + +var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK" + +func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { + tempDir := t.TempDir() + + config := model.Config{ + UI: model.UIConfig{ + Title: "Tinyauth Test", + ForgotPasswordMessage: "foo", + BackgroundImage: "/background.jpg", + WarningsEnabled: true, + }, + OAuth: model.OAuthConfig{ + AutoRedirect: "none", + }, + OIDC: model.OIDCConfig{ + Clients: map[string]model.OIDCClientConfig{ + "test": { + ClientID: "some-client-id", + ClientSecret: "some-client-secret", + TrustedRedirectURIs: []string{"https://test.example.com/callback"}, + Name: "Test Client", + }, + }, + PrivateKeyPath: path.Join(tempDir, "key.pem"), + PublicKeyPath: path.Join(tempDir, "key.pub"), + }, + Auth: model.AuthConfig{ + SessionExpiry: 10, + LoginTimeout: 10, + LoginMaxRetries: 3, + }, + Database: model.DatabaseConfig{ + Path: path.Join(tempDir, "test.db"), + }, + Resources: model.ResourcesConfig{ + Enabled: true, + Path: path.Join(tempDir, "resources"), + }, + } + + passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) + require.NoError(t, err) + + runtime := model.RuntimeConfig{ + ConfiguredProviders: []model.Provider{ + { + Name: "Local", + ID: "local", + OAuth: false, + }, + }, + LocalUsers: []model.LocalUser{ + { + Username: "testuser", + Password: string(passwd), + }, + { + Username: "totpuser", + Password: string(passwd), + TOTPSecret: testingTOTPSecret, + }, + { + Username: "attruser", + Password: string(passwd), + Attributes: model.UserAttributes{ + Name: "Alice Smith", + Email: "alice@example.com", + }, + }, + { + Username: "attrtotpuser", + Password: string(passwd), + TOTPSecret: testingTOTPSecret, + Attributes: model.UserAttributes{ + Name: "Bob Jones", + Email: "bob@example.com", + }, + }, + }, + CookieDomain: "example.com", + AppURL: "https://tinyauth.example.com", + SessionCookieName: "tinyauth-session", + OIDCClients: func() []model.OIDCClientConfig { + var clients []model.OIDCClientConfig + for id, client := range config.OIDC.Clients { + client.ID = id + clients = append(clients, client) + } + return clients + }(), + } + + return config, runtime +} From 74aca0f52101f641f8b14e81ca2b0453ced01e32 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 13:34:34 +0300 Subject: [PATCH 13/22] tests: fix service tests --- internal/service/oidc_service_test.go | 33 +++++++++++++++++++++------ 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/internal/service/oidc_service_test.go b/internal/service/oidc_service_test.go index 394df4be..bc24c9be 100644 --- a/internal/service/oidc_service_test.go +++ b/internal/service/oidc_service_test.go @@ -1,7 +1,9 @@ package service_test import ( + "context" "encoding/json" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -10,6 +12,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func newTestUser() repository.OidcUserinfo { @@ -48,13 +51,29 @@ func newTestUser() repository.OidcUserinfo { func TestCompileUserinfo(t *testing.T) { dir := t.TempDir() - svc := service.NewOIDCService(service.OIDCServiceConfig{ - PrivateKeyPath: dir + "/key.pem", - PublicKeyPath: dir + "/key.pub", - Issuer: "https://tinyauth.example.com", - SessionExpiry: 3600, - }, nil) - require.NoError(t, svc.Init()) + + cfg := model.Config{ + OIDC: model.OIDCConfig{ + PrivateKeyPath: dir + "/key.pem", + PublicKeyPath: dir + "/key.pub", + }, + Auth: model.AuthConfig{ + SessionExpiry: 3600, + }, + } + + runtime := model.RuntimeConfig{ + AppURL: "https://tinyauth.example.com", + } + + log := logger.NewLogger().WithTestConfig() + log.Init() + + ctx := context.TODO() + wg := &sync.WaitGroup{} + + svc, err := service.NewOIDCService(log, cfg, runtime, nil, ctx, wg) + require.NoError(t, err) type testCase struct { description string From 886f9a84d68229385775cd4e6679597580266211 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 13:38:04 +0300 Subject: [PATCH 14/22] tests: fix context tests --- internal/model/context_test.go | 2 +- internal/service/kubernetes_service_test.go | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/internal/model/context_test.go b/internal/model/context_test.go index 733805a7..79bc97b0 100644 --- a/internal/model/context_test.go +++ b/internal/model/context_test.go @@ -238,7 +238,7 @@ func TestContext(t *testing.T) { _, err := c.NewFromGin(newGinCtx(nil, false)) return err.Error() }, - expected: "failed to get user context", + expected: model.ErrUserContextNotFound.Error(), }, { description: "NewFromGin returns error when context value has wrong type", diff --git a/internal/service/kubernetes_service_test.go b/internal/service/kubernetes_service_test.go index c7b39ead..702fe0f8 100644 --- a/internal/service/kubernetes_service_test.go +++ b/internal/service/kubernetes_service_test.go @@ -8,9 +8,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) func TestKubernetesService(t *testing.T) { + log := logger.NewLogger().WithTestConfig() + log.Init() + type testCase struct { description string run func(t *testing.T, svc *KubernetesService) @@ -179,6 +183,7 @@ func TestKubernetesService(t *testing.T) { ingressApps: make(map[ingressKey][]ingressApp), domainIndex: make(map[string]ingressAppKey), appNameIndex: make(map[string]ingressAppKey), + log: log, } test.run(t, svc) }) From 02b48aa1657eca99f22a641f3e3989c51bab9e80 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 13:42:44 +0300 Subject: [PATCH 15/22] fix: fix typos --- internal/bootstrap/app_bootstrap.go | 10 +++++----- internal/bootstrap/service_bootstrap.go | 2 +- internal/controller/user_controller.go | 2 +- internal/model/context.go | 2 +- internal/service/access_controls_service.go | 6 +++--- internal/service/auth_service.go | 6 ------ internal/utils/logger/logger.go | 2 +- 7 files changed, 12 insertions(+), 18 deletions(-) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 48d57a9d..36afb79a 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -214,7 +214,7 @@ func (app *BootstrapApp) Setup() error { return errors.New("no authentication providers configured") } - for _, provider := range app.runtime.ConfiguredProviders { + for _, provider := range configuredProviders { app.log.App.Debug().Str("provider", provider.Name).Msg("Configured authentication provider") } @@ -238,7 +238,7 @@ func (app *BootstrapApp) Setup() error { } // create err channel to listen for server errors - errChan := make(chan error, 1) + errChan := make(chan error, 2) // serve unix app.wg.Go(func() { @@ -317,7 +317,7 @@ func (app *BootstrapApp) serveUnix() error { listener, err := net.Listen("unix", app.config.Server.SocketPath) if err != nil { - return fmt.Errorf("failed to create unix socket listner: %w", err) + return fmt.Errorf("failed to create unix socket listener: %w", err) } server := &http.Server{ @@ -330,7 +330,7 @@ func (app *BootstrapApp) serveUnix() error { go func() { <-app.ctx.Done() - app.log.App.Debug().Msg("Shutting down unix sokcet listener") + app.log.App.Debug().Msg("Shutting down unix socket listener") server.Close() listener.Close() os.Remove(app.config.Server.SocketPath) @@ -338,7 +338,7 @@ func (app *BootstrapApp) serveUnix() error { err = server.Serve(listener) - if err != nil && (!errors.Is(err, net.ErrClosed) || !errors.Is(err, http.ErrServerClosed)) { + if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("failed to start unix socket listener: %w", err) } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 1e850437..ef3ee591 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -19,7 +19,7 @@ func (app *BootstrapApp) setupServices() error { useKubernetes := app.config.LabelProvider == "kubernetes" || (app.config.LabelProvider == "auto" && os.Getenv("KUBERNETES_SERVICE_HOST") != "") - var labelProvider service.LabelProviderImpl + var labelProvider service.LabelProvider if useKubernetes { app.log.App.Debug().Msg("Using Kubernetes label provider") diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index f186ec0d..45a876bf 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -86,7 +86,7 @@ func (controller *UserController) loginHandler(c *gin.Context) { if errors.Is(err, service.ErrUserNotFound) { controller.log.App.Warn().Str("username", req.Username).Msg("User not found during login attempt") controller.auth.RecordLoginAttempt(req.Username, false) - controller.log.AuditLoginFailure(req.Username, "unkown", c.ClientIP(), "user not found") + controller.log.AuditLoginFailure(req.Username, "unknown", c.ClientIP(), "user not found") c.JSON(401, gin.H{ "status": 401, "message": "Unauthorized", diff --git a/internal/model/context.go b/internal/model/context.go index c459a620..b9e31bef 100644 --- a/internal/model/context.go +++ b/internal/model/context.go @@ -121,7 +121,7 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, Email: session.Email, }, } - // By default we assume an unkown name which is oauth + // By default we assume an unknown name which is oauth default: c.Provider = ProviderOAuth c.OAuth = &OAuthContext{ diff --git a/internal/service/access_controls_service.go b/internal/service/access_controls_service.go index f6e3cbd2..34700ea7 100644 --- a/internal/service/access_controls_service.go +++ b/internal/service/access_controls_service.go @@ -7,19 +7,19 @@ import ( "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) -type LabelProviderImpl interface { +type LabelProvider interface { GetLabels(appDomain string) (*model.App, error) } type AccessControlsService struct { log *logger.Logger - labelProvider *LabelProviderImpl + labelProvider *LabelProvider static map[string]model.App } func NewAccessControlsService( log *logger.Logger, - labelProvider *LabelProviderImpl, + labelProvider *LabelProvider, static map[string]model.App) *AccessControlsService { return &AccessControlsService{ log: log, diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index ed882438..a9139bb3 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -399,12 +399,6 @@ func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http. auth.log.App.Error().Err(err).Str("uuid", uuid).Msg("Failed to delete session from database") } - err = auth.queries.DeleteSession(ctx, uuid) - - if err != nil { - return nil, err - } - return &http.Cookie{ Name: auth.runtime.SessionCookieName, Value: "", diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go index 24d93145..b3ceda5e 100644 --- a/internal/utils/logger/logger.go +++ b/internal/utils/logger/logger.go @@ -33,7 +33,7 @@ func NewLogger() *Logger { App: model.LogStreamConfig{ Enabled: true, }, - // No reason to enabled audit by default since it will be surpressed by the log level + // No reason to enabled audit by default since it will be suppressed by the log level }, }, } From 4e760e83978a2ff96cd35aedd78dcb3e9365c2d8 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 13:52:49 +0300 Subject: [PATCH 16/22] feat: add option to enable or disable concurrent listeners --- internal/bootstrap/app_bootstrap.go | 43 +++++++++++++++++++++-------- internal/model/config.go | 20 ++++++++------ 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 36afb79a..7d983ff7 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -238,21 +238,42 @@ func (app *BootstrapApp) Setup() error { } // create err channel to listen for server errors - errChan := make(chan error, 2) + errChanLen := 0 + + runUnix := app.config.Server.SocketPath != "" + runHTTP := app.config.Server.SocketPath == "" || app.config.Server.ConcurrentListenersEnabled + + if runUnix { + errChanLen++ + } + + if runHTTP { + errChanLen++ + } + + errChan := make(chan error, errChanLen) + + if app.config.Server.ConcurrentListenersEnabled { + app.log.App.Info().Msg("Concurrent listeners enabled, will run on all available listeners") + } // serve unix - app.wg.Go(func() { - if err := app.serveUnix(); err != nil { - errChan <- err - } - }) + if runUnix { + app.wg.Go(func() { + if err := app.serveUnix(); err != nil { + errChan <- err + } + }) + } // serve to http - app.wg.Go(func() { - if err := app.serveHTTP(); err != nil { - errChan <- err - } - }) + if runHTTP { + app.wg.Go(func() { + if err := app.serveHTTP(); err != nil { + errChan <- err + } + }) + } // monitor cancellation and server errors for { diff --git a/internal/model/config.go b/internal/model/config.go index 95870e3d..f5376af2 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -14,8 +14,9 @@ func NewDefaultConfiguration() *Config { Path: "./resources", }, Server: ServerConfig{ - Port: 3000, - Address: "0.0.0.0", + Port: 3000, + Address: "0.0.0.0", + ConcurrentListenersEnabled: false, }, Auth: AuthConfig{ SubdomainsEnabled: true, @@ -95,9 +96,10 @@ type ResourcesConfig struct { } type ServerConfig struct { - Port int `description:"The port on which the server listens." yaml:"port"` - Address string `description:"The address on which the server listens." yaml:"address"` - SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` + Port int `description:"The port on which the server listens." yaml:"port"` + Address string `description:"The address on which the server listens." yaml:"address"` + SocketPath string `description:"The path to the Unix socket." yaml:"socketPath"` + ConcurrentListenersEnabled bool `description:"Enable listening on both TCP and Unix socket at the same time." yaml:"concurrentListenersEnabled"` } type AuthConfig struct { @@ -147,10 +149,10 @@ type IPConfig struct { } type OAuthConfig struct { - Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` - WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"` - AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` - Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` + Whitelist []string `description:"Comma-separated list of allowed OAuth domains." yaml:"whitelist"` + WhitelistFile string `description:"Path to the OAuth whitelist file." yaml:"whitelistFile"` + AutoRedirect string `description:"The OAuth provider to use for automatic redirection." yaml:"autoRedirect"` + Providers map[string]OAuthServiceConfig `description:"OAuth providers configuration." yaml:"providers"` } type OIDCConfig struct { From 3d9c81d7a01ebd3e9cad04c4a74467ce682266b6 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 13:56:28 +0300 Subject: [PATCH 17/22] fix: assign public key correctly in oidc server --- internal/controller/well_known_controller.go | 6 +++--- internal/service/oidc_service.go | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/internal/controller/well_known_controller.go b/internal/controller/well_known_controller.go index a00876be..8c71d890 100644 --- a/internal/controller/well_known_controller.go +++ b/internal/controller/well_known_controller.go @@ -44,7 +44,7 @@ func NewWellKnownController(oidc *service.OIDCService, router *gin.RouterGroup) func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context) { if controller.oidc == nil { c.JSON(500, gin.H{ - "status": "500", + "status": 500, "message": "OIDC service not configured", }) return @@ -73,7 +73,7 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context func (controller *WellKnownController) JWKS(c *gin.Context) { if controller.oidc == nil { c.JSON(500, gin.H{ - "status": "500", + "status": 500, "message": "OIDC service not configured", }) return @@ -83,7 +83,7 @@ func (controller *WellKnownController) JWKS(c *gin.Context) { if err != nil { c.JSON(500, gin.H{ - "status": "500", + "status": 500, "message": "failed to get JWK", }) return diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 02c33199..6b46ed9b 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -230,11 +230,10 @@ func NewOIDCService( return nil, fmt.Errorf("failed to parse public key: %w", err) } case "PUBLIC KEY": - publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + publicKey, err = x509.ParsePKIXPublicKey(block.Bytes) if err != nil { return nil, fmt.Errorf("failed to parse public key: %w", err) } - publicKey = publicKey.(crypto.PublicKey) default: return nil, fmt.Errorf("unsupported public key type: %s", block.Type) } From 548d97fa62134d30769d16937ee70b0c2db71466 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 14:00:58 +0300 Subject: [PATCH 18/22] tests: fix don't try to test logger with char size --- internal/utils/logger/logger_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/utils/logger/logger_test.go b/internal/utils/logger/logger_test.go index 395d348f..8288f0d6 100644 --- a/internal/utils/logger/logger_test.go +++ b/internal/utils/logger/logger_test.go @@ -159,10 +159,10 @@ func TestLogger(t *testing.T) { l.App.Info().Msg("test") - l.AuditLoginFailure("test", "test", "test", "test") + l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop") assert.NotEmpty(t, buf.String()) - assert.Equal(t, 81, buf.Len()) // it's the length of the test log entry + assert.NotContains(t, "test_nop", buf.String()) }, }, } From d5009070e3fc814c22a8652c4e4f5e16d04eb5b2 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 17:00:02 +0300 Subject: [PATCH 19/22] fix: coderabbit comments --- internal/bootstrap/app_bootstrap.go | 31 +++-- internal/bootstrap/db_bootstrap.go | 7 ++ internal/bootstrap/router_bootstrap.go | 2 +- .../controller/context_controller_test.go | 3 +- internal/controller/oauth_controller.go | 42 ++++--- internal/controller/oidc_controller.go | 38 ++++-- internal/controller/oidc_controller_test.go | 5 +- internal/controller/proxy_controller_test.go | 3 +- .../controller/resources_controller_test.go | 3 +- internal/controller/user_controller_test.go | 3 +- .../controller/well_known_controller_test.go | 4 +- .../middleware/context_middleware_test.go | 3 +- internal/middleware/middleware_test.go | 108 ------------------ internal/service/oidc_service.go | 4 +- .../controller_test.go => test/test.go} | 10 +- internal/utils/logger/logger.go | 2 +- internal/utils/logger/logger_test.go | 2 +- 17 files changed, 107 insertions(+), 163 deletions(-) delete mode 100644 internal/middleware/middleware_test.go rename internal/{controller/controller_test.go => test/test.go} (91%) diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 7d983ff7..8f9fdec0 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -102,7 +102,7 @@ func (app *BootstrapApp) Setup() error { app.runtime.OAuthWhitelist = oauthWhitelist - // Setup oauth providers + // setup oauth providers app.runtime.OAuthProviders = app.config.OAuth.Providers for id, provider := range app.runtime.OAuthProviders { @@ -168,6 +168,14 @@ func (app *BootstrapApp) Setup() error { return fmt.Errorf("failed to setup database: %w", err) } + // after this point, we start initializing dependencies so it's a good time to setup a defer + // to ensure that resources are cleaned up properly in case of an error during initialization + defer func() { + app.cancel() + app.wg.Wait() + app.db.Close() + }() + // queries queries := repository.New(app.db) app.queries = queries @@ -279,9 +287,6 @@ func (app *BootstrapApp) Setup() error { for { select { case <-app.ctx.Done(): - app.wg.Wait() - app.log.App.Debug().Msg("Closing database") - app.db.Close() app.log.App.Info().Msg("Oh, it's time for me to go, bye!") return nil case err := <-errChan: @@ -305,7 +310,7 @@ func (app *BootstrapApp) serveHTTP() error { go func() { <-app.ctx.Done() app.log.App.Debug().Msg("Shutting down http listener") - server.Close() + server.Shutdown(app.ctx) }() err := server.ListenAndServe() @@ -345,21 +350,23 @@ func (app *BootstrapApp) serveUnix() error { Handler: app.router.Handler(), } - defer server.Close() - defer listener.Close() - defer os.Remove(app.config.Server.SocketPath) + shutdown := func() { + server.Shutdown(app.ctx) + listener.Close() + os.Remove(app.config.Server.SocketPath) + } + + defer shutdown() go func() { <-app.ctx.Done() app.log.App.Debug().Msg("Shutting down unix socket listener") - server.Close() - listener.Close() - os.Remove(app.config.Server.SocketPath) + shutdown() }() err = server.Serve(listener) - if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { + if err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("failed to start unix socket listener: %w", err) } diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index f9554ddc..d8572c4c 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -27,6 +27,13 @@ func (app *BootstrapApp) SetupDatabase() error { return fmt.Errorf("failed to open database: %w", err) } + // Close the database if there is an error during migration + defer func() { + if err != nil { + db.Close() + } + }() + // Limit to 1 connection to sequence writes, this may need to be revisited in the future // if the sqlite connection starts being a bottleneck db.SetMaxOpenConns(1) diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index 90a20c3b..12a48bc0 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -43,7 +43,7 @@ func (app *BootstrapApp) setupRouter() error { controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) - controller.NewOIDCController(app.log, app.services.oidcService, apiRouter) + controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter) controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewResourcesController(app.config, &engine.RouterGroup) diff --git a/internal/controller/context_controller_test.go b/internal/controller/context_controller_test.go index 4d65e8a5..177f4744 100644 --- a/internal/controller/context_controller_test.go +++ b/internal/controller/context_controller_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -19,7 +20,7 @@ func TestContextController(t *testing.T) { log := logger.NewLogger().WithTestConfig() log.Init() - cfg, runtime := createTestConfigs(t) + cfg, runtime := test.CreateTestConfigs(t) tests := []struct { description string diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 803a4c04..1aec73ae 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -131,7 +131,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { if err != nil { controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -141,7 +141,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { if err != nil { controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -150,7 +150,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { state := c.Query("state") if state != oauthPendingSession.State { controller.log.App.Warn().Msg("OAuth state mismatch") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -159,15 +159,27 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { if err != nil { controller.log.App.Error().Err(err).Msg("Failed to exchange code for token") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) + if err != nil { + controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + + if user == nil { + controller.log.App.Warn().Msg("OAuth provider did not return user info") + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) + return + } + if user.Email == "" { controller.log.App.Warn().Msg("OAuth provider did not return an email") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -181,11 +193,11 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { if err != nil { controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode())) return } @@ -213,13 +225,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { if err != nil { controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } if svc.ID() != req.Provider { controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -239,7 +251,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { if err != nil { controller.log.App.Error().Err(err).Msg("Failed to create session cookie") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } @@ -252,10 +264,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { queries, err := query.Values(oauthPendingSession.CallbackParams) if err != nil { controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode())) return } @@ -266,15 +278,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { if err != nil { controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL)) return } - c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode())) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode())) return } - c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) + c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL) } func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { diff --git a/internal/controller/oidc_controller.go b/internal/controller/oidc_controller.go index 7e735159..142f0b40 100644 --- a/internal/controller/oidc_controller.go +++ b/internal/controller/oidc_controller.go @@ -17,8 +17,9 @@ import ( ) type OIDCController struct { - log *logger.Logger - oidc *service.OIDCService + log *logger.Logger + oidc *service.OIDCService + runtime model.RuntimeConfig } type AuthorizeCallback struct { @@ -58,10 +59,12 @@ type ClientCredentials struct { func NewOIDCController( log *logger.Logger, oidcService *service.OIDCService, + runtimeConfig model.RuntimeConfig, router *gin.RouterGroup) *OIDCController { controller := &OIDCController{ - log: log, - oidc: oidcService, + log: log, + oidc: oidcService, + runtime: runtimeConfig, } oidcGroup := router.Group("/oidc") @@ -75,6 +78,15 @@ func NewOIDCController( } func (controller *OIDCController) GetClientInfo(c *gin.Context) { + if controller.oidc == nil { + controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured") + c.JSON(500, gin.H{ + "status": 500, + "message": "OIDC not configured", + }) + return + } + var req ClientRequest err := c.BindUri(&req) @@ -198,8 +210,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) { if controller.oidc == nil { controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") - c.JSON(404, gin.H{ - "error": "not_found", + c.JSON(500, gin.H{ + "error": "server_error", }) return } @@ -374,8 +386,8 @@ func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) { if controller.oidc == nil { controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") - c.JSON(404, gin.H{ - "error": "not_found", + c.JSON(500, gin.H{ + "error": "server_error", }) return } @@ -507,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas return } + redirectUrl := "" + + if controller.oidc != nil { + redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()) + } else { + redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode()) + } + c.JSON(200, gin.H{ "status": 200, - "redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()), + "redirect_uri": redirectUrl, }) } diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 59cabf8b..9ece2073 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -20,6 +20,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -27,7 +28,7 @@ func TestOIDCController(t *testing.T) { log := logger.NewLogger().WithTestConfig() log.Init() - cfg, runtime := createTestConfigs(t) + cfg, runtime := test.CreateTestConfigs(t) simpleCtx := func(c *gin.Context) { c.Set("context", &model.UserContext{ @@ -861,7 +862,7 @@ func TestOIDCController(t *testing.T) { group := router.Group("/api") gin.SetMode(gin.TestMode) - controller.NewOIDCController(log, oidcService, group) + controller.NewOIDCController(log, oidcService, runtime, group) recorder := httptest.NewRecorder() diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index c1603d14..12c3c9f1 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -14,6 +14,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -21,7 +22,7 @@ func TestProxyController(t *testing.T) { log := logger.NewLogger().WithTestConfig() log.Init() - cfg, runtime := createTestConfigs(t) + cfg, runtime := test.CreateTestConfigs(t) acls := map[string]model.App{ "app_path_allow": { diff --git a/internal/controller/resources_controller_test.go b/internal/controller/resources_controller_test.go index 8c8554d3..68ce463d 100644 --- a/internal/controller/resources_controller_test.go +++ b/internal/controller/resources_controller_test.go @@ -10,10 +10,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tinyauthapp/tinyauth/internal/controller" + "github.com/tinyauthapp/tinyauth/internal/test" ) func TestResourcesController(t *testing.T) { - cfg, _ := createTestConfigs(t) + cfg, _ := test.CreateTestConfigs(t) err := os.MkdirAll(cfg.Resources.Path, 0777) require.NoError(t, err) diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index e834a8b5..10858175 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -19,6 +19,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -26,7 +27,7 @@ func TestUserController(t *testing.T) { log := logger.NewLogger().WithTestConfig() log.Init() - cfg, runtime := createTestConfigs(t) + cfg, runtime := test.CreateTestConfigs(t) totpCtx := func(c *gin.Context) { c.Set("context", &model.UserContext{ diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 9d6c7483..e2323da2 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -15,6 +15,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -22,7 +23,7 @@ func TestWellKnownController(t *testing.T) { log := logger.NewLogger().WithTestConfig() log.Init() - cfg, runtime := createTestConfigs(t) + cfg, runtime := test.CreateTestConfigs(t) type testCase struct { description string @@ -99,6 +100,7 @@ func TestWellKnownController(t *testing.T) { queries := repository.New(app.GetDB()) oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg) + require.NoError(t, err) for _, test := range tests { t.Run(test.description, func(t *testing.T) { diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 167c200e..03f9f553 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -17,6 +17,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" + "github.com/tinyauthapp/tinyauth/internal/test" "github.com/tinyauthapp/tinyauth/internal/utils/logger" ) @@ -24,7 +25,7 @@ func TestContextMiddleware(t *testing.T) { log := logger.NewLogger().WithTestConfig() log.Init() - cfg, runtime := createTestConfigs(t) + cfg, runtime := test.CreateTestConfigs(t) basicAuthHeader := func(username, password string) string { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go deleted file mode 100644 index c00bf4e0..00000000 --- a/internal/middleware/middleware_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package middleware_test - -import ( - "path" - "testing" - - "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/model" - "golang.org/x/crypto/bcrypt" -) - -// Note: This code is duplicated from controller_test.go - -var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK" - -func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { - tempDir := t.TempDir() - - config := model.Config{ - UI: model.UIConfig{ - Title: "Tinyauth Test", - ForgotPasswordMessage: "foo", - BackgroundImage: "/background.jpg", - WarningsEnabled: true, - }, - OAuth: model.OAuthConfig{ - AutoRedirect: "none", - }, - OIDC: model.OIDCConfig{ - Clients: map[string]model.OIDCClientConfig{ - "test": { - ClientID: "some-client-id", - ClientSecret: "some-client-secret", - TrustedRedirectURIs: []string{"https://test.example.com/callback"}, - Name: "Test Client", - }, - }, - PrivateKeyPath: path.Join(tempDir, "key.pem"), - PublicKeyPath: path.Join(tempDir, "key.pub"), - }, - Auth: model.AuthConfig{ - SessionExpiry: 10, - LoginTimeout: 10, - LoginMaxRetries: 3, - }, - Database: model.DatabaseConfig{ - Path: path.Join(tempDir, "test.db"), - }, - Resources: model.ResourcesConfig{ - Enabled: true, - Path: path.Join(tempDir, "resources"), - }, - } - - passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost) - require.NoError(t, err) - - runtime := model.RuntimeConfig{ - ConfiguredProviders: []model.Provider{ - { - Name: "Local", - ID: "local", - OAuth: false, - }, - }, - LocalUsers: []model.LocalUser{ - { - Username: "testuser", - Password: string(passwd), - }, - { - Username: "totpuser", - Password: string(passwd), - TOTPSecret: testingTOTPSecret, - }, - { - Username: "attruser", - Password: string(passwd), - Attributes: model.UserAttributes{ - Name: "Alice Smith", - Email: "alice@example.com", - }, - }, - { - Username: "attrtotpuser", - Password: string(passwd), - TOTPSecret: testingTOTPSecret, - Attributes: model.UserAttributes{ - Name: "Bob Jones", - Email: "bob@example.com", - }, - }, - }, - CookieDomain: "example.com", - AppURL: "https://tinyauth.example.com", - SessionCookieName: "tinyauth-session", - OIDCClients: func() []model.OIDCClientConfig { - var clients []model.OIDCClientConfig - for id, client := range config.OIDC.Clients { - client.ID = id - clients = append(clients, client) - } - return clients - }(), - } - - return config, runtime -} diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 6b46ed9b..92216451 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -790,10 +790,8 @@ func (service *OIDCService) cleanupRoutine() { token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - continue - } service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") + continue } if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { diff --git a/internal/controller/controller_test.go b/internal/test/test.go similarity index 91% rename from internal/controller/controller_test.go rename to internal/test/test.go index 675f345f..3ee17a47 100644 --- a/internal/controller/controller_test.go +++ b/internal/test/test.go @@ -1,4 +1,4 @@ -package controller_test +package test import ( "path" @@ -9,9 +9,9 @@ import ( "golang.org/x/crypto/bcrypt" ) -var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK" +var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK" -func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { +func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { tempDir := t.TempDir() config := model.Config{ @@ -69,7 +69,7 @@ func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { { Username: "totpuser", Password: string(passwd), - TOTPSecret: testingTOTPSecret, + TOTPSecret: TestingTOTPSecret, }, { Username: "attruser", @@ -82,7 +82,7 @@ func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { { Username: "attrtotpuser", Password: string(passwd), - TOTPSecret: testingTOTPSecret, + TOTPSecret: TestingTOTPSecret, Attributes: model.UserAttributes{ Name: "Bob Jones", Email: "bob@example.com", diff --git a/internal/utils/logger/logger.go b/internal/utils/logger/logger.go index b3ceda5e..af6b55ea 100644 --- a/internal/utils/logger/logger.go +++ b/internal/utils/logger/logger.go @@ -33,7 +33,7 @@ func NewLogger() *Logger { App: model.LogStreamConfig{ Enabled: true, }, - // No reason to enabled audit by default since it will be suppressed by the log level + // No reason to enable audit by default since it will be suppressed by the log level }, }, } diff --git a/internal/utils/logger/logger_test.go b/internal/utils/logger/logger_test.go index 8288f0d6..167e2337 100644 --- a/internal/utils/logger/logger_test.go +++ b/internal/utils/logger/logger_test.go @@ -162,7 +162,7 @@ func TestLogger(t *testing.T) { l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop") assert.NotEmpty(t, buf.String()) - assert.NotContains(t, "test_nop", buf.String()) + assert.NotContains(t, buf.String(), "test_nop") }, }, } From e739aa8fd0d29ce36cc9240de29ed434b74d9fb9 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sat, 9 May 2026 17:18:58 +0300 Subject: [PATCH 20/22] tests: use filepath join instead of path join --- internal/test/test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/test/test.go b/internal/test/test.go index 3ee17a47..73ff5d38 100644 --- a/internal/test/test.go +++ b/internal/test/test.go @@ -1,7 +1,7 @@ package test import ( - "path" + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -33,8 +33,8 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { Name: "Test Client", }, }, - PrivateKeyPath: path.Join(tempDir, "key.pem"), - PublicKeyPath: path.Join(tempDir, "key.pub"), + PrivateKeyPath: filepath.Join(tempDir, "key.pem"), + PublicKeyPath: filepath.Join(tempDir, "key.pub"), }, Auth: model.AuthConfig{ SessionExpiry: 10, @@ -42,11 +42,11 @@ func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { LoginMaxRetries: 3, }, Database: model.DatabaseConfig{ - Path: path.Join(tempDir, "test.db"), + Path: filepath.Join(tempDir, "test.db"), }, Resources: model.ResourcesConfig{ Enabled: true, - Path: path.Join(tempDir, "resources"), + Path: filepath.Join(tempDir, "resources"), }, } From 11b6155b9e8b1c2c6a20d599c23d8e0f4c14d163 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 10 May 2026 16:03:39 +0300 Subject: [PATCH 21/22] fix: ensure unix socket shutdown doesn't run twice --- internal/bootstrap/app_bootstrap.go | 3 +- lint.html | 105 ++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 2 deletions(-) create mode 100644 lint.html diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 8f9fdec0..3f491fa1 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -356,8 +356,6 @@ func (app *BootstrapApp) serveUnix() error { os.Remove(app.config.Server.SocketPath) } - defer shutdown() - go func() { <-app.ctx.Done() app.log.App.Debug().Msg("Shutting down unix socket listener") @@ -367,6 +365,7 @@ func (app *BootstrapApp) serveUnix() error { err = server.Serve(listener) if err != nil && !errors.Is(err, http.ErrServerClosed) { + shutdown() return fmt.Errorf("failed to start unix socket listener: %w", err) } diff --git a/lint.html b/lint.html new file mode 100644 index 00000000..f29902f4 --- /dev/null +++ b/lint.html @@ -0,0 +1,105 @@ + + + + + golangci-lint + + + + + + + + + + +
+
+
+
+
+ + + + \ No newline at end of file From d38784715d775f26bbf7a49cbdbb2e4da43d9b77 Mon Sep 17 00:00:00 2001 From: Stavros Date: Sun, 10 May 2026 16:04:22 +0300 Subject: [PATCH 22/22] chore: remove temp lint file --- lint.html | 105 ------------------------------------------------------ 1 file changed, 105 deletions(-) delete mode 100644 lint.html diff --git a/lint.html b/lint.html deleted file mode 100644 index f29902f4..00000000 --- a/lint.html +++ /dev/null @@ -1,105 +0,0 @@ - - - - - golangci-lint - - - - - - - - - - -
-
-
-
-
- - - - \ No newline at end of file