Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions pkg/session/branch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
"github.com/docker/cagent/pkg/tools"
)

func buildBranchedSession(parent *Session, branchAtPosition int) (*Session, error) {
// BranchSession creates a new session branched from the parent at the given position.
// Messages up to (but not including) branchAtPosition are deep-cloned into the new session.
func BranchSession(parent *Session, branchAtPosition int) (*Session, error) {
if parent == nil {
return nil, fmt.Errorf("parent session is nil")
}
Expand Down Expand Up @@ -242,15 +244,3 @@ func recalculateSessionTotals(sess *Session) {
sess.OutputTokens = outputTokens
sess.Cost = cost
}

func collectSessionIDs(sess *Session, ids map[string]struct{}) {
if sess == nil || ids == nil {
return
}
ids[sess.ID] = struct{}{}
for _, item := range sess.Messages {
if item.SubSession != nil {
collectSessionIDs(item.SubSession, ids)
}
}
}
12 changes: 6 additions & 6 deletions pkg/session/branch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,30 +90,30 @@ func TestCloneSessionItem(t *testing.T) {
})
}

func TestBuildBranchedSession(t *testing.T) {
func TestBranchSession(t *testing.T) {
t.Run("nil parent returns error", func(t *testing.T) {
_, err := buildBranchedSession(nil, 0)
_, err := BranchSession(nil, 0)
require.Error(t, err)
assert.Contains(t, err.Error(), "parent session is nil")
})

t.Run("negative position returns error", func(t *testing.T) {
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
_, err := buildBranchedSession(parent, -1)
_, err := BranchSession(parent, -1)
require.Error(t, err)
assert.Contains(t, err.Error(), "out of range")
})

t.Run("position beyond messages returns error", func(t *testing.T) {
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
_, err := buildBranchedSession(parent, 2)
_, err := BranchSession(parent, 2)
require.Error(t, err)
assert.Contains(t, err.Error(), "out of range")
})

t.Run("position equal to messages length returns error", func(t *testing.T) {
parent := &Session{Messages: []Item{NewMessageItem(UserMessage("test"))}}
_, err := buildBranchedSession(parent, 1)
_, err := BranchSession(parent, 1)
require.Error(t, err)
assert.Contains(t, err.Error(), "out of range")
})
Expand All @@ -129,7 +129,7 @@ func TestBuildBranchedSession(t *testing.T) {
},
}

branched, err := buildBranchedSession(parent, 2)
branched, err := BranchSession(parent, 2)
require.NoError(t, err)
assert.NotNil(t, branched)

Expand Down
128 changes: 18 additions & 110 deletions pkg/session/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,18 @@ func ResolveSessionID(ctx context.Context, store Store, ref string) (string, err
if !isRelative {
return ref, nil
}
return store.GetSessionByOffset(ctx, offset)

summaries, err := store.GetSessionSummaries(ctx)
if err != nil {
return "", fmt.Errorf("getting session summaries: %w", err)
}

index := offset - 1
if index >= len(summaries) {
return "", fmt.Errorf("session offset %d out of range (have %d sessions)", offset, len(summaries))
}

return summaries[index].ID, nil
}

// Summary contains lightweight session metadata for listing purposes.
Expand All @@ -72,12 +83,6 @@ type Store interface {
DeleteSession(ctx context.Context, id string) error
UpdateSession(ctx context.Context, session *Session) error // Updates metadata only (not messages/items)
SetSessionStarred(ctx context.Context, id string, starred bool) error
BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error)

// GetSessionByOffset returns the session ID at the given offset from the most recent.
// Offset 1 returns the most recent session, 2 returns the second most recent, etc.
// Only root sessions are considered (sub-sessions are excluded).
GetSessionByOffset(ctx context.Context, offset int) (string, error)

// === Granular item operations ===

Expand Down Expand Up @@ -147,6 +152,9 @@ func (s *InMemorySessionStore) GetSessions(_ context.Context) ([]*Session, error
func (s *InMemorySessionStore) GetSessionSummaries(_ context.Context) ([]Summary, error) {
summaries := make([]Summary, 0, s.sessions.Length())
s.sessions.Range(func(_ string, value *Session) bool {
if value.ParentID != "" {
return true
}
summaries = append(summaries, Summary{
ID: value.ID,
Title: value.Title,
Expand All @@ -156,6 +164,9 @@ func (s *InMemorySessionStore) GetSessionSummaries(_ context.Context) ([]Summary
})
return true
})
sort.Slice(summaries, func(i, j int) bool {
return summaries[i].CreatedAt.After(summaries[j].CreatedAt)
})
return summaries, nil
}

Expand Down Expand Up @@ -206,25 +217,6 @@ func (s *InMemorySessionStore) SetSessionStarred(_ context.Context, id string, s
return nil
}

// BranchSession creates a new session branched from the parent at the given position.
func (s *InMemorySessionStore) BranchSession(_ context.Context, parentSessionID string, branchAtPosition int) (*Session, error) {
if parentSessionID == "" {
return nil, ErrEmptyID
}
parent, exists := s.sessions.Load(parentSessionID)
if !exists {
return nil, ErrNotFound
}

branched, err := buildBranchedSession(parent, branchAtPosition)
if err != nil {
return nil, err
}

s.sessions.Store(branched.ID, branched)
return branched, nil
}

// AddMessage adds a message to a session at the next position.
// Returns the ID of the created message (for in-memory, this is a simple counter).
func (s *InMemorySessionStore) AddMessage(_ context.Context, sessionID string, msg *Message) (int64, error) {
Expand Down Expand Up @@ -358,34 +350,6 @@ func (s *InMemorySessionStore) UpdateSessionTitle(_ context.Context, sessionID,
return nil
}

// GetSessionByOffset returns the session ID at the given offset from the most recent.
func (s *InMemorySessionStore) GetSessionByOffset(_ context.Context, offset int) (string, error) {
if offset < 1 {
return "", fmt.Errorf("offset must be >= 1, got %d", offset)
}

// Collect and sort sessions by creation time (newest first)
var sessions []*Session
s.sessions.Range(func(_ string, value *Session) bool {
// Only include root sessions (not sub-sessions)
if value.ParentID == "" {
sessions = append(sessions, value)
}
return true
})

sort.Slice(sessions, func(i, j int) bool {
return sessions[i].CreatedAt.After(sessions[j].CreatedAt)
})

index := offset - 1 // offset 1 means index 0 (most recent session)
if index >= len(sessions) {
return "", fmt.Errorf("session offset %d out of range (have %d sessions)", offset, len(sessions))
}

return sessions[index].ID, nil
}

// NewSQLiteSessionStore creates a new SQLite session store
func NewSQLiteSessionStore(path string) (Store, error) {
store, err := openAndMigrateSQLiteStore(path)
Expand Down Expand Up @@ -1091,37 +1055,6 @@ func (s *SQLiteSessionStore) SetSessionStarred(ctx context.Context, id string, s
return nil
}

// BranchSession creates a new session branched from the parent at the given position.
func (s *SQLiteSessionStore) BranchSession(ctx context.Context, parentSessionID string, branchAtPosition int) (*Session, error) {
if parentSessionID == "" {
return nil, ErrEmptyID
}

parent, err := s.GetSession(ctx, parentSessionID)
if err != nil {
return nil, err
}

branched, err := buildBranchedSession(parent, branchAtPosition)
if err != nil {
return nil, err
}

if err := s.AddSession(ctx, branched); err != nil {
return nil, err
}

ids := make(map[string]struct{})
collectSessionIDs(branched, ids)
for id := range ids {
if err := s.syncMessagesColumn(ctx, id); err != nil {
slog.Warn("[STORE] Failed to sync messages column after branch", "session_id", id, "error", err)
}
}

return branched, nil
}

// Close closes the database connection
func (s *SQLiteSessionStore) Close() error {
return s.db.Close()
Expand Down Expand Up @@ -1400,28 +1333,3 @@ func (s *SQLiteSessionStore) UpdateSessionTitle(ctx context.Context, sessionID,
title, sessionID)
return err
}

// GetSessionByOffset returns the session ID at the given offset from the most recent.
func (s *SQLiteSessionStore) GetSessionByOffset(ctx context.Context, offset int) (string, error) {
if offset < 1 {
return "", fmt.Errorf("offset must be >= 1, got %d", offset)
}

// Query sessions ordered by creation time (newest first), limited to offset
// Only include root sessions (not sub-sessions)
var sessionID string
err := s.db.QueryRowContext(ctx,
`SELECT id FROM sessions
WHERE parent_id IS NULL OR parent_id = ''
ORDER BY created_at DESC
LIMIT 1 OFFSET ?`,
offset-1).Scan(&sessionID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", fmt.Errorf("session offset %d out of range", offset)
}
return "", err
}

return sessionID, nil
}
116 changes: 12 additions & 104 deletions pkg/session/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,13 @@ func TestBranchSessionCopiesPrefix(t *testing.T) {

require.NoError(t, store.AddSession(t.Context(), parent))

branched, err := store.BranchSession(t.Context(), parent.ID, 2)
parentLoaded, err := store.GetSession(t.Context(), parent.ID)
require.NoError(t, err)

branched, err := BranchSession(parentLoaded, 2)
require.NoError(t, err)

require.NoError(t, store.AddSession(t.Context(), branched))
require.NotNil(t, branched.BranchParentPosition)
assert.Equal(t, parent.ID, branched.BranchParentSessionID)
assert.Equal(t, 2, *branched.BranchParentPosition)
Expand Down Expand Up @@ -289,9 +294,14 @@ func TestBranchSessionClonesSubSession(t *testing.T) {

require.NoError(t, store.AddSession(t.Context(), parent))

branched, err := store.BranchSession(t.Context(), parent.ID, 2)
parentLoaded, err := store.GetSession(t.Context(), parent.ID)
require.NoError(t, err)

branched, err := BranchSession(parentLoaded, 2)
require.NoError(t, err)

require.NoError(t, store.AddSession(t.Context(), branched))

loaded, err := store.GetSession(t.Context(), branched.ID)
require.NoError(t, err)
require.Len(t, loaded.Messages, 2)
Expand Down Expand Up @@ -1274,105 +1284,3 @@ func TestResolveSessionID_InMemory(t *testing.T) {
assert.Equal(t, "some-uuid", id)
})
}

func TestGetSessionByOffset_SQLite(t *testing.T) {
tempDB := filepath.Join(t.TempDir(), "test_offset.db")

store, err := NewSQLiteSessionStore(tempDB)
require.NoError(t, err)
defer store.(*SQLiteSessionStore).Close()

// Create sessions with known timestamps
baseTime := time.Now()
sessions := []struct {
id string
createdAt time.Time
}{
{"oldest", baseTime.Add(-3 * time.Hour)},
{"middle", baseTime.Add(-2 * time.Hour)},
{"newest", baseTime.Add(-1 * time.Hour)},
}

for _, s := range sessions {
err := store.AddSession(t.Context(), &Session{
ID: s.id,
CreatedAt: s.createdAt,
})
require.NoError(t, err)
}

t.Run("offset 1 returns newest", func(t *testing.T) {
id, err := store.GetSessionByOffset(t.Context(), 1)
require.NoError(t, err)
assert.Equal(t, "newest", id)
})

t.Run("offset 2 returns middle", func(t *testing.T) {
id, err := store.GetSessionByOffset(t.Context(), 2)
require.NoError(t, err)
assert.Equal(t, "middle", id)
})

t.Run("offset 3 returns oldest", func(t *testing.T) {
id, err := store.GetSessionByOffset(t.Context(), 3)
require.NoError(t, err)
assert.Equal(t, "oldest", id)
})

t.Run("offset 0 returns error", func(t *testing.T) {
_, err := store.GetSessionByOffset(t.Context(), 0)
require.Error(t, err)
})

t.Run("out of range offset returns error", func(t *testing.T) {
_, err := store.GetSessionByOffset(t.Context(), 4)
require.Error(t, err)
assert.Contains(t, err.Error(), "out of range")
})
}

func TestGetSessionByOffset_InMemory(t *testing.T) {
store := NewInMemorySessionStore()

// Create sessions with known timestamps
baseTime := time.Now()
sessions := []struct {
id string
createdAt time.Time
}{
{"oldest", baseTime.Add(-3 * time.Hour)},
{"middle", baseTime.Add(-2 * time.Hour)},
{"newest", baseTime.Add(-1 * time.Hour)},
}

for _, s := range sessions {
err := store.AddSession(t.Context(), &Session{
ID: s.id,
CreatedAt: s.createdAt,
})
require.NoError(t, err)
}

t.Run("offset 1 returns newest", func(t *testing.T) {
id, err := store.GetSessionByOffset(t.Context(), 1)
require.NoError(t, err)
assert.Equal(t, "newest", id)
})

t.Run("offset 2 returns middle", func(t *testing.T) {
id, err := store.GetSessionByOffset(t.Context(), 2)
require.NoError(t, err)
assert.Equal(t, "middle", id)
})

t.Run("offset 0 returns error", func(t *testing.T) {
_, err := store.GetSessionByOffset(t.Context(), 0)
require.Error(t, err)
})

t.Run("out of range offset returns error", func(t *testing.T) {
_, err := store.GetSessionByOffset(t.Context(), 4)
require.Error(t, err)
assert.Contains(t, err.Error(), "out of range")
})
}
Loading
Loading