diff --git a/internal/db/gorm/token_store.go b/internal/db/gorm/token_store.go index 8f641c01..88938201 100644 --- a/internal/db/gorm/token_store.go +++ b/internal/db/gorm/token_store.go @@ -45,19 +45,19 @@ func (s *TokenStore) List(ctx context.Context) ([]APIToken, error) { return tokens, err } -// FindByPrefix looks up a non-revoked token by its prefix for auth middleware. -func (s *TokenStore) FindByPrefix(ctx context.Context, prefix string) (*APIToken, error) { - var token APIToken +// FindByPrefix looks up all non-revoked tokens matching the given prefix. +// Multiple tokens may share a prefix in the non-unique index, so callers +// must iterate over the returned slice and compare bcrypt hashes to find the +// matching token. +func (s *TokenStore) FindByPrefix(ctx context.Context, prefix string) ([]APIToken, error) { + var tokens []APIToken err := s.db.WithContext(ctx). Where("token_prefix = ? AND NOT revoked", prefix). - First(&token).Error - if err == gorm.ErrRecordNotFound { - return nil, nil - } + Find(&tokens).Error if err != nil { return nil, err } - return &token, nil + return tokens, nil } // Revoke marks a token as revoked. diff --git a/internal/worker/handlers_analytics.go b/internal/worker/handlers_analytics.go index d71f11c8..ec45fa0a 100644 --- a/internal/worker/handlers_analytics.go +++ b/internal/worker/handlers_analytics.go @@ -4,6 +4,7 @@ package worker import ( "fmt" "net/http" + "slices" "strconv" "time" @@ -57,7 +58,11 @@ func (s *Service) handleGetTrends(w http.ResponseWriter, r *http.Request) { } } - // Get observations for analysis (rough estimate of limit) + // Get observations for analysis. The multiplier of 50 is a heuristic: we + // assume at most ~50 observations are created per day on average. For + // deployments with higher throughput the analytics may be slightly + // incomplete; a future improvement would be to filter by created_at >= cutoff + // at the DB level rather than applying a fetch limit here. obs, err := s.observationStore.GetRecentObservations(r.Context(), project, days*50) if err != nil { log.Error().Err(err).Msg("get observations for trends failed") @@ -119,17 +124,13 @@ func (s *Service) handleGetTrends(w http.ResponseWriter, r *http.Request) { name string count int } - var topConcepts []conceptEntry + topConcepts := make([]conceptEntry, 0, len(conceptCounts)) for name, count := range conceptCounts { topConcepts = append(topConcepts, conceptEntry{name, count}) } - for i := 0; i < len(topConcepts) && i < 10; i++ { - for j := i + 1; j < len(topConcepts); j++ { - if topConcepts[j].count > topConcepts[i].count { - topConcepts[i], topConcepts[j] = topConcepts[j], topConcepts[i] - } - } - } + slices.SortFunc(topConcepts, func(a, b conceptEntry) int { + return b.count - a.count // descending by count + }) if len(topConcepts) > 10 { topConcepts = topConcepts[:10] } diff --git a/internal/worker/handlers_auth.go b/internal/worker/handlers_auth.go index 3fbaa6d8..f0ca38cb 100644 --- a/internal/worker/handlers_auth.go +++ b/internal/worker/handlers_auth.go @@ -9,11 +9,14 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "net/http" + "strings" "time" "github.com/go-chi/chi/v5" "github.com/rs/zerolog/log" + "gorm.io/gorm" "golang.org/x/crypto/bcrypt" ) @@ -98,6 +101,7 @@ func (s *Service) handleAuthLogin(w http.ResponseWriter, r *http.Request) { Path: "/", MaxAge: sessionMaxAge, HttpOnly: true, + Secure: r.TLS != nil, SameSite: http.SameSiteStrictMode, }) @@ -317,7 +321,11 @@ func (s *Service) handleRevokeToken(w http.ResponseWriter, r *http.Request) { if err := tokenStore.Revoke(r.Context(), id); err != nil { log.Error().Err(err).Str("token_id", id).Msg("auth: failed to revoke token") - http.Error(w, "not found", http.StatusNotFound) + if errors.Is(err, gorm.ErrRecordNotFound) { + http.Error(w, "not found", http.StatusNotFound) + } else { + http.Error(w, "internal error", http.StatusInternalServerError) + } return } @@ -358,12 +366,8 @@ func isDuplicateKeyError(err error) bool { // containsDuplicateKey checks error message for duplicate key indicators. func containsDuplicateKey(msg string) bool { for _, s := range []string{"duplicate key", "23505", "UNIQUE constraint"} { - if len(msg) >= len(s) { - for i := 0; i <= len(msg)-len(s); i++ { - if msg[i:i+len(s)] == s { - return true - } - } + if strings.Contains(msg, s) { + return true } } return false diff --git a/internal/worker/handlers_maintenance.go b/internal/worker/handlers_maintenance.go index e3e0b84e..da28f7f3 100644 --- a/internal/worker/handlers_maintenance.go +++ b/internal/worker/handlers_maintenance.go @@ -2,6 +2,7 @@ package worker import ( + "context" "encoding/json" "net/http" @@ -85,7 +86,9 @@ func (s *Service) handleRunMaintenance(w http.ResponseWriter, r *http.Request) { return } - s.maintenanceService.RunNow(r.Context()) + // Use background context: the request context is cancelled after the + // response is sent, which would prematurely abort the background job. + s.maintenanceService.RunNow(context.Background()) writeJSON(w, map[string]any{ "status": "triggered", diff --git a/internal/worker/handlers_sessions_rest.go b/internal/worker/handlers_sessions_rest.go index b8467381..2613a255 100644 --- a/internal/worker/handlers_sessions_rest.go +++ b/internal/worker/handlers_sessions_rest.go @@ -10,6 +10,10 @@ import ( "github.com/thebtf/engram/internal/sessions" ) +const ( + maxSessionsLimit = 200 // Server-side cap to prevent huge response payloads. +) + // handleListIndexedSessions godoc // @Summary List indexed sessions // @Description Returns indexed Claude Code sessions with optional project and workstation filters. @@ -18,7 +22,7 @@ import ( // @Security ApiKeyAuth // @Param project query string false "Filter by project ID" // @Param workstation query string false "Filter by workstation ID" -// @Param limit query int false "Number of results (default 20)" +// @Param limit query int false "Number of results (default 20, max 200)" // @Param offset query int false "Pagination offset" // @Success 200 {array} object // @Failure 500 {string} string "internal error" @@ -32,7 +36,7 @@ func (s *Service) handleListIndexedSessions(w http.ResponseWriter, r *http.Reque limit := 20 if val := r.URL.Query().Get("limit"); val != "" { if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 { - limit = parsed + limit = min(parsed, maxSessionsLimit) } } offset := 0 @@ -52,7 +56,7 @@ func (s *Service) handleListIndexedSessions(w http.ResponseWriter, r *http.Reque list, err := s.sessionIdxStore.ListSessions(r.Context(), opts) if err != nil { log.Error().Err(err).Msg("list indexed sessions failed") - http.Error(w, "list sessions: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "internal error", http.StatusInternalServerError) return } @@ -117,14 +121,14 @@ func (s *Service) handleSearchIndexedSessions(w http.ResponseWriter, r *http.Req limit := 10 if val := r.URL.Query().Get("limit"); val != "" { if parsed, err := strconv.Atoi(val); err == nil && parsed > 0 { - limit = parsed + limit = min(parsed, maxSessionsLimit) } } results, err := s.sessionIdxStore.SearchSessions(r.Context(), query, limit) if err != nil { log.Error().Err(err).Str("query", query).Msg("search indexed sessions failed") - http.Error(w, "search sessions: "+err.Error(), http.StatusInternalServerError) + http.Error(w, "internal error", http.StatusInternalServerError) return } diff --git a/internal/worker/middleware.go b/internal/worker/middleware.go index c330b937..ec321e89 100644 --- a/internal/worker/middleware.go +++ b/internal/worker/middleware.go @@ -276,19 +276,26 @@ func (ta *TokenAuth) authenticateClientToken(w http.ResponseWriter, r *http.Requ } prefix := rawToken[4:12] - token, err := store.FindByPrefix(r.Context(), prefix) + candidates, err := store.FindByPrefix(r.Context(), prefix) if err != nil { log.Error().Err(err).Msg("auth: token store lookup failed") http.Error(w, "internal error", http.StatusInternalServerError) return true } - if token == nil { + if len(candidates) == 0 { http.Error(w, "unauthorized", http.StatusUnauthorized) return true } - // Verify bcrypt hash - if err := bcrypt.CompareHashAndPassword([]byte(token.TokenHash), []byte(rawToken)); err != nil { + // Find the matching token by bcrypt comparison (handles prefix collisions). + var token *gormdb.APIToken + for i := range candidates { + if bcrypt.CompareHashAndPassword([]byte(candidates[i].TokenHash), []byte(rawToken)) == nil { + token = &candidates[i] + break + } + } + if token == nil { http.Error(w, "unauthorized", http.StatusUnauthorized) return true } diff --git a/internal/worker/token_stats.go b/internal/worker/token_stats.go index 772ac6df..fc1d29e6 100644 --- a/internal/worker/token_stats.go +++ b/internal/worker/token_stats.go @@ -35,26 +35,27 @@ func (s *Service) startTokenStatsFlusher(ctx context.Context) { for { select { case <-ctx.Done(): - // Final flush on shutdown - s.flushTokenStats(pending) + // Final flush on shutdown: use a bounded timeout so the goroutine + // does not block indefinitely after cancellation. + flushCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + s.flushTokenStats(flushCtx, pending) + cancel() return case tokenID := <-ch: pending[tokenID]++ case <-ticker.C: - s.flushTokenStats(pending) - // Reset the map - for k := range pending { - delete(pending, k) - } + s.flushTokenStats(ctx, pending) + // Reset map by allocating a fresh one (cheaper than deleting in a loop). + pending = make(map[string]int) } } }() } // flushTokenStats writes accumulated token usage counts to the database. -func (s *Service) flushTokenStats(counts map[string]int) { +func (s *Service) flushTokenStats(ctx context.Context, counts map[string]int) { if len(counts) == 0 { return } @@ -67,7 +68,7 @@ func (s *Service) flushTokenStats(counts map[string]int) { return } - if err := store.BatchIncrementStats(context.Background(), counts); err != nil { + if err := store.BatchIncrementStats(ctx, counts); err != nil { log.Warn().Err(err).Int("tokens", len(counts)).Msg("auth: failed to flush token stats") } } diff --git a/ui/src/components/layout/AppSidebar.vue b/ui/src/components/layout/AppSidebar.vue index 77dd54cf..746c32a6 100644 --- a/ui/src/components/layout/AppSidebar.vue +++ b/ui/src/components/layout/AppSidebar.vue @@ -43,12 +43,15 @@ function isActive(item: NavItem): boolean { if (item.path === '/') { return route.path === '/' } - return route.path.startsWith(item.path) + return route.path === item.path || route.path.startsWith(item.path + '/') } async function handleLogout() { - await logout() - router.push({ name: 'login' }) + try { + await logout() + } finally { + router.push({ name: 'login' }) + } } diff --git a/ui/src/components/layout/ConfirmDialog.vue b/ui/src/components/layout/ConfirmDialog.vue index 279e06f6..d4019687 100644 --- a/ui/src/components/layout/ConfirmDialog.vue +++ b/ui/src/components/layout/ConfirmDialog.vue @@ -1,5 +1,7 @@