From f78dccb849052df4898e841d057c8a083ef59c05 Mon Sep 17 00:00:00 2001 From: Djordje Lukic Date: Mon, 9 Feb 2026 23:14:07 +0100 Subject: [PATCH 1/2] Remove GetSessionByOffset from Store interface ResolveSessionID now uses the existing GetSessionSummaries method, which already returns root sessions sorted by created_at DESC. Also fixes InMemorySessionStore.GetSessionSummaries to filter out sub-sessions and sort by CreatedAt descending, matching the SQLite implementation. Assisted-By: cagent Signed-off-by: Djordje Lukic --- pkg/session/store.go | 77 +++++++--------------------- pkg/session/store_test.go | 102 -------------------------------------- 2 files changed, 18 insertions(+), 161 deletions(-) diff --git a/pkg/session/store.go b/pkg/session/store.go index a23f664b8..e48ccf522 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -49,7 +49,18 @@ func ResolveSessionID(ctx context.Context, store Store, ref string) (string, err if !isRelative { return ref, nil } - return store.GetSessionByOffset(ctx, offset) + + summaries, err := store.GetSessionSummaries(ctx) + if err != nil { + return "", fmt.Errorf("getting session summaries: %w", err) + } + + index := offset - 1 + if index >= len(summaries) { + return "", fmt.Errorf("session offset %d out of range (have %d sessions)", offset, len(summaries)) + } + + return summaries[index].ID, nil } // Summary contains lightweight session metadata for listing purposes. @@ -74,11 +85,6 @@ type Store interface { SetSessionStarred(ctx context.Context, id string, starred bool) error BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error) - // GetSessionByOffset returns the session ID at the given offset from the most recent. - // Offset 1 returns the most recent session, 2 returns the second most recent, etc. - // Only root sessions are considered (sub-sessions are excluded). - GetSessionByOffset(ctx context.Context, offset int) (string, error) - // === Granular item operations === // AddMessage adds a message to a session at the next position. @@ -147,6 +153,9 @@ func (s *InMemorySessionStore) GetSessions(_ context.Context) ([]*Session, error func (s *InMemorySessionStore) GetSessionSummaries(_ context.Context) ([]Summary, error) { summaries := make([]Summary, 0, s.sessions.Length()) s.sessions.Range(func(_ string, value *Session) bool { + if value.ParentID != "" { + return true + } summaries = append(summaries, Summary{ ID: value.ID, Title: value.Title, @@ -156,6 +165,9 @@ func (s *InMemorySessionStore) GetSessionSummaries(_ context.Context) ([]Summary }) return true }) + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].CreatedAt.After(summaries[j].CreatedAt) + }) return summaries, nil } @@ -358,34 +370,6 @@ func (s *InMemorySessionStore) UpdateSessionTitle(_ context.Context, sessionID, return nil } -// GetSessionByOffset returns the session ID at the given offset from the most recent. -func (s *InMemorySessionStore) GetSessionByOffset(_ context.Context, offset int) (string, error) { - if offset < 1 { - return "", fmt.Errorf("offset must be >= 1, got %d", offset) - } - - // Collect and sort sessions by creation time (newest first) - var sessions []*Session - s.sessions.Range(func(_ string, value *Session) bool { - // Only include root sessions (not sub-sessions) - if value.ParentID == "" { - sessions = append(sessions, value) - } - return true - }) - - sort.Slice(sessions, func(i, j int) bool { - return sessions[i].CreatedAt.After(sessions[j].CreatedAt) - }) - - index := offset - 1 // offset 1 means index 0 (most recent session) - if index >= len(sessions) { - return "", fmt.Errorf("session offset %d out of range (have %d sessions)", offset, len(sessions)) - } - - return sessions[index].ID, nil -} - // NewSQLiteSessionStore creates a new SQLite session store func NewSQLiteSessionStore(path string) (Store, error) { store, err := openAndMigrateSQLiteStore(path) @@ -1400,28 +1384,3 @@ func (s *SQLiteSessionStore) UpdateSessionTitle(ctx context.Context, sessionID, title, sessionID) return err } - -// GetSessionByOffset returns the session ID at the given offset from the most recent. -func (s *SQLiteSessionStore) GetSessionByOffset(ctx context.Context, offset int) (string, error) { - if offset < 1 { - return "", fmt.Errorf("offset must be >= 1, got %d", offset) - } - - // Query sessions ordered by creation time (newest first), limited to offset - // Only include root sessions (not sub-sessions) - var sessionID string - err := s.db.QueryRowContext(ctx, - `SELECT id FROM sessions - WHERE parent_id IS NULL OR parent_id = '' - ORDER BY created_at DESC - LIMIT 1 OFFSET ?`, - offset-1).Scan(&sessionID) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return "", fmt.Errorf("session offset %d out of range", offset) - } - return "", err - } - - return sessionID, nil -} diff --git a/pkg/session/store_test.go b/pkg/session/store_test.go index b9883016f..bfc226483 100644 --- a/pkg/session/store_test.go +++ b/pkg/session/store_test.go @@ -1274,105 +1274,3 @@ func TestResolveSessionID_InMemory(t *testing.T) { assert.Equal(t, "some-uuid", id) }) } - -func TestGetSessionByOffset_SQLite(t *testing.T) { - tempDB := filepath.Join(t.TempDir(), "test_offset.db") - - store, err := NewSQLiteSessionStore(tempDB) - require.NoError(t, err) - defer store.(*SQLiteSessionStore).Close() - - // Create sessions with known timestamps - baseTime := time.Now() - sessions := []struct { - id string - createdAt time.Time - }{ - {"oldest", baseTime.Add(-3 * time.Hour)}, - {"middle", baseTime.Add(-2 * time.Hour)}, - {"newest", baseTime.Add(-1 * time.Hour)}, - } - - for _, s := range sessions { - err := store.AddSession(t.Context(), &Session{ - ID: s.id, - CreatedAt: s.createdAt, - }) - require.NoError(t, err) - } - - t.Run("offset 1 returns newest", func(t *testing.T) { - id, err := store.GetSessionByOffset(t.Context(), 1) - require.NoError(t, err) - assert.Equal(t, "newest", id) - }) - - t.Run("offset 2 returns middle", func(t *testing.T) { - id, err := store.GetSessionByOffset(t.Context(), 2) - require.NoError(t, err) - assert.Equal(t, "middle", id) - }) - - t.Run("offset 3 returns oldest", func(t *testing.T) { - id, err := store.GetSessionByOffset(t.Context(), 3) - require.NoError(t, err) - assert.Equal(t, "oldest", id) - }) - - t.Run("offset 0 returns error", func(t *testing.T) { - _, err := store.GetSessionByOffset(t.Context(), 0) - require.Error(t, err) - }) - - t.Run("out of range offset returns error", func(t *testing.T) { - _, err := store.GetSessionByOffset(t.Context(), 4) - require.Error(t, err) - assert.Contains(t, err.Error(), "out of range") - }) -} - -func TestGetSessionByOffset_InMemory(t *testing.T) { - store := NewInMemorySessionStore() - - // Create sessions with known timestamps - baseTime := time.Now() - sessions := []struct { - id string - createdAt time.Time - }{ - {"oldest", baseTime.Add(-3 * time.Hour)}, - {"middle", baseTime.Add(-2 * time.Hour)}, - {"newest", baseTime.Add(-1 * time.Hour)}, - } - - for _, s := range sessions { - err := store.AddSession(t.Context(), &Session{ - ID: s.id, - CreatedAt: s.createdAt, - }) - require.NoError(t, err) - } - - t.Run("offset 1 returns newest", func(t *testing.T) { - id, err := store.GetSessionByOffset(t.Context(), 1) - require.NoError(t, err) - assert.Equal(t, "newest", id) - }) - - t.Run("offset 2 returns middle", func(t *testing.T) { - id, err := store.GetSessionByOffset(t.Context(), 2) - require.NoError(t, err) - assert.Equal(t, "middle", id) - }) - - t.Run("offset 0 returns error", func(t *testing.T) { - _, err := store.GetSessionByOffset(t.Context(), 0) - require.Error(t, err) - }) - - t.Run("out of range offset returns error", func(t *testing.T) { - _, err := store.GetSessionByOffset(t.Context(), 4) - require.Error(t, err) - assert.Contains(t, err.Error(), "out of range") - }) -} From dfddca567291f5d291029d255354b1476ef54d88 Mon Sep 17 00:00:00 2001 From: Djordje Lukic Date: Mon, 9 Feb 2026 23:18:35 +0100 Subject: [PATCH 2/2] Remove BranchSession from Store interface Export buildBranchedSession as session.BranchSession so callers compose it with GetSession and AddSession directly, rather than having the store own branching logic. Also removes the now-unused collectSessionIDs helper. Assisted-By: cagent Signed-off-by: Djordje Lukic --- pkg/session/branch.go | 16 +++--------- pkg/session/branch_test.go | 12 ++++----- pkg/session/store.go | 51 -------------------------------------- pkg/session/store_test.go | 14 +++++++++-- pkg/tui/handlers.go | 12 ++++++++- 5 files changed, 32 insertions(+), 73 deletions(-) diff --git a/pkg/session/branch.go b/pkg/session/branch.go index 496b8d3ad..11b947471 100644 --- a/pkg/session/branch.go +++ b/pkg/session/branch.go @@ -9,7 +9,9 @@ import ( "github.com/docker/cagent/pkg/tools" ) -func buildBranchedSession(parent *Session, branchAtPosition int) (*Session, error) { +// BranchSession creates a new session branched from the parent at the given position. +// Messages up to (but not including) branchAtPosition are deep-cloned into the new session. +func BranchSession(parent *Session, branchAtPosition int) (*Session, error) { if parent == nil { return nil, fmt.Errorf("parent session is nil") } @@ -242,15 +244,3 @@ func recalculateSessionTotals(sess *Session) { sess.OutputTokens = outputTokens sess.Cost = cost } - -func collectSessionIDs(sess *Session, ids map[string]struct{}) { - if sess == nil || ids == nil { - return - } - ids[sess.ID] = struct{}{} - for _, item := range sess.Messages { - if item.SubSession != nil { - collectSessionIDs(item.SubSession, ids) - } - } -} diff --git a/pkg/session/branch_test.go b/pkg/session/branch_test.go index efb0b86b4..d7bf5a6de 100644 --- a/pkg/session/branch_test.go +++ b/pkg/session/branch_test.go @@ -90,30 +90,30 @@ func TestCloneSessionItem(t *testing.T) { }) } -func TestBuildBranchedSession(t *testing.T) { +func TestBranchSession(t *testing.T) { t.Run("nil parent returns error", func(t *testing.T) { - _, err := buildBranchedSession(nil, 0) + _, err := BranchSession(nil, 0) require.Error(t, err) assert.Contains(t, err.Error(), "parent session is nil") }) t.Run("negative position returns error", func(t *testing.T) { parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}} - _, err := buildBranchedSession(parent, -1) + _, err := BranchSession(parent, -1) require.Error(t, err) assert.Contains(t, err.Error(), "out of range") }) t.Run("position beyond messages returns error", func(t *testing.T) { parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}} - _, err := buildBranchedSession(parent, 2) + _, err := BranchSession(parent, 2) require.Error(t, err) assert.Contains(t, err.Error(), "out of range") }) t.Run("position equal to messages length returns error", func(t *testing.T) { parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}} - _, err := buildBranchedSession(parent, 1) + _, err := BranchSession(parent, 1) require.Error(t, err) assert.Contains(t, err.Error(), "out of range") }) @@ -129,7 +129,7 @@ func TestBuildBranchedSession(t *testing.T) { }, } - branched, err := buildBranchedSession(parent, 2) + branched, err := BranchSession(parent, 2) require.NoError(t, err) assert.NotNil(t, branched) diff --git a/pkg/session/store.go b/pkg/session/store.go index e48ccf522..ccfd2d305 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -83,7 +83,6 @@ type Store interface { DeleteSession(ctx context.Context, id string) error UpdateSession(ctx context.Context, session *Session) error // Updates metadata only (not messages/items) SetSessionStarred(ctx context.Context, id string, starred bool) error - BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error) // === Granular item operations === @@ -218,25 +217,6 @@ func (s *InMemorySessionStore) SetSessionStarred(_ context.Context, id string, s return nil } -// BranchSession creates a new session branched from the parent at the given position. -func (s *InMemorySessionStore) BranchSession(_ context.Context, parentSessionID string, branchAtPosition int) (*Session, error) { - if parentSessionID == "" { - return nil, ErrEmptyID - } - parent, exists := s.sessions.Load(parentSessionID) - if !exists { - return nil, ErrNotFound - } - - branched, err := buildBranchedSession(parent, branchAtPosition) - if err != nil { - return nil, err - } - - s.sessions.Store(branched.ID, branched) - return branched, nil -} - // AddMessage adds a message to a session at the next position. // Returns the ID of the created message (for in-memory, this is a simple counter). func (s *InMemorySessionStore) AddMessage(_ context.Context, sessionID string, msg *Message) (int64, error) { @@ -1075,37 +1055,6 @@ func (s *SQLiteSessionStore) SetSessionStarred(ctx context.Context, id string, s return nil } -// BranchSession creates a new session branched from the parent at the given position. -func (s *SQLiteSessionStore) BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error) { - if parentSessionID == "" { - return nil, ErrEmptyID - } - - parent, err := s.GetSession(ctx, parentSessionID) - if err != nil { - return nil, err - } - - branched, err := buildBranchedSession(parent, branchAtPosition) - if err != nil { - return nil, err - } - - if err := s.AddSession(ctx, branched); err != nil { - return nil, err - } - - ids := make(map[string]struct{}) - collectSessionIDs(branched, ids) - for id := range ids { - if err := s.syncMessagesColumn(ctx, id); err != nil { - slog.Warn("[STORE] Failed to sync messages column after branch", "session_id", id, "error", err) - } - } - - return branched, nil -} - // Close closes the database connection func (s *SQLiteSessionStore) Close() error { return s.db.Close() diff --git a/pkg/session/store_test.go b/pkg/session/store_test.go index bfc226483..104914990 100644 --- a/pkg/session/store_test.go +++ b/pkg/session/store_test.go @@ -244,8 +244,13 @@ func TestBranchSessionCopiesPrefix(t *testing.T) { require.NoError(t, store.AddSession(t.Context(), parent)) - branched, err := store.BranchSession(t.Context(), parent.ID, 2) + parentLoaded, err := store.GetSession(t.Context(), parent.ID) require.NoError(t, err) + + branched, err := BranchSession(parentLoaded, 2) + require.NoError(t, err) + + require.NoError(t, store.AddSession(t.Context(), branched)) require.NotNil(t, branched.BranchParentPosition) assert.Equal(t, parent.ID, branched.BranchParentSessionID) assert.Equal(t, 2, *branched.BranchParentPosition) @@ -289,9 +294,14 @@ func TestBranchSessionClonesSubSession(t *testing.T) { require.NoError(t, store.AddSession(t.Context(), parent)) - branched, err := store.BranchSession(t.Context(), parent.ID, 2) + parentLoaded, err := store.GetSession(t.Context(), parent.ID) require.NoError(t, err) + branched, err := BranchSession(parentLoaded, 2) + require.NoError(t, err) + + require.NoError(t, store.AddSession(t.Context(), branched)) + loaded, err := store.GetSession(t.Context(), branched.ID) require.NoError(t, err) require.Len(t, loaded.Messages, 2) diff --git a/pkg/tui/handlers.go b/pkg/tui/handlers.go index b03e8ca51..941f289ea 100644 --- a/pkg/tui/handlers.go +++ b/pkg/tui/handlers.go @@ -15,6 +15,7 @@ import ( "github.com/docker/cagent/pkg/browser" "github.com/docker/cagent/pkg/evaluation" "github.com/docker/cagent/pkg/modelsdev" + "github.com/docker/cagent/pkg/session" "github.com/docker/cagent/pkg/tools" mcptools "github.com/docker/cagent/pkg/tools/mcp" "github.com/docker/cagent/pkg/tui/components/notification" @@ -106,11 +107,20 @@ func (a *appModel) handleBranchFromEdit(msg messages.BranchFromEditMsg) (tea.Mod return a, notification.ErrorCmd("No parent session for branch") } - newSess, err := store.BranchSession(context.Background(), msg.ParentSessionID, msg.BranchAtPosition) + parent, err := store.GetSession(context.Background(), msg.ParentSessionID) + if err != nil { + return a, notification.ErrorCmd(fmt.Sprintf("Failed to load parent session: %v", err)) + } + + newSess, err := session.BranchSession(parent, msg.BranchAtPosition) if err != nil { return a, notification.ErrorCmd(fmt.Sprintf("Failed to branch session: %v", err)) } + if err := store.AddSession(context.Background(), newSess); err != nil { + return a, notification.ErrorCmd(fmt.Sprintf("Failed to save branched session: %v", err)) + } + if current := a.application.Session(); current != nil { newSess.HideToolResults = current.HideToolResults newSess.ToolsApproved = current.ToolsApproved