From 2a460dd71bfaa8ba22b66599fc00f82110d05fc1 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Tue, 18 Jun 2024 17:36:45 -0300 Subject: [PATCH] feat: continue using the same model fixes #263 --- db.go | 33 ++++++++++++++++++++++++++++----- db_test.go | 32 ++++++++++++++++---------------- main.go | 5 ++++- mods.go | 22 +++++++++++++++------- mods_test.go | 18 +++++++++--------- 5 files changed, 72 insertions(+), 38 deletions(-) diff --git a/db.go b/db.go index ef46763b..714ef26d 100644 --- a/db.go +++ b/db.go @@ -62,9 +62,30 @@ func openDB(ds string) (*convoDB, error) { `); err != nil { return nil, fmt.Errorf("could not migrate db: %w", err) } + + if !hasColumn(db, "model") { + if _, err := db.Exec(` + ALTER TABLE conversations ADD COLUMN model string + `); err != nil { + return nil, fmt.Errorf("could not migrate db: %w", err) + } + } + return &convoDB{db: db}, nil } +func hasColumn(db *sqlx.DB, col string) bool { + var count int + if err := db.Get(&count, ` + SELECT count(*) + FROM pragma_table_info('conversations') c + WHERE c.name = $1 + `, col); err != nil { + return false + } + return count > 0 +} + type convoDB struct { db *sqlx.DB } @@ -74,21 +95,23 @@ type Conversation struct { ID string `db:"id"` Title string `db:"title"` UpdatedAt time.Time `db:"updated_at"` + Model *string `db:"model"` } func (c *convoDB) Close() error { return c.db.Close() //nolint: wrapcheck } -func (c *convoDB) Save(id, title string) error { +func (c *convoDB) Save(id, title, model string) error { res, err := c.db.Exec(c.db.Rebind(` UPDATE conversations SET title = ?, + model = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ? - `), title, id) + `), title, model, id) if err != nil { return fmt.Errorf("Save: %w", err) } @@ -104,10 +127,10 @@ func (c *convoDB) Save(id, title string) error { if _, err := c.db.Exec(c.db.Rebind(` INSERT INTO - conversations (id, title) + conversations (id, title, model) VALUES - (?, ?) - `), id, title); err != nil { + (?, ?, ?) + `), id, title, model); err != nil { return fmt.Errorf("Save: %w", err) } diff --git a/db_test.go b/db_test.go index 35bb7e22..1e0930f3 100644 --- a/db_test.go +++ b/db_test.go @@ -30,7 +30,7 @@ func TestConvoDB(t *testing.T) { t.Run("save", func(t *testing.T) { db := testDB(t) - require.NoError(t, db.Save(testid, "message 1")) + require.NoError(t, db.Save(testid, "message 1", "gpt-4o")) convo, err := db.Find("df31") require.NoError(t, err) @@ -44,20 +44,20 @@ func TestConvoDB(t *testing.T) { t.Run("save no id", func(t *testing.T) { db := testDB(t) - require.Error(t, db.Save("", "message 1")) + require.Error(t, db.Save("", "message 1", "gpt-4o")) }) t.Run("save no message", func(t *testing.T) { db := testDB(t) - require.Error(t, db.Save(newConversationID(), "")) + require.Error(t, db.Save(newConversationID(), "", "gpt-4o")) }) t.Run("update", func(t *testing.T) { db := testDB(t) - require.NoError(t, db.Save(testid, "message 1")) + require.NoError(t, db.Save(testid, "message 1", "gpt-4o")) time.Sleep(100 * time.Millisecond) - require.NoError(t, db.Save(testid, "message 2")) + require.NoError(t, db.Save(testid, "message 2", "gpt-4o")) convo, err := db.Find("df31") require.NoError(t, err) @@ -72,7 +72,7 @@ func TestConvoDB(t *testing.T) { t.Run("find head single", func(t *testing.T) { db := testDB(t) - require.NoError(t, db.Save(testid, "message 2")) + require.NoError(t, db.Save(testid, "message 2", "gpt-4o")) head, err := db.FindHEAD() require.NoError(t, err) @@ -83,10 +83,10 @@ func TestConvoDB(t *testing.T) { t.Run("find head multiple", func(t *testing.T) { db := testDB(t) - require.NoError(t, db.Save(testid, "message 2")) + require.NoError(t, db.Save(testid, "message 2", "gpt-4o")) time.Sleep(time.Millisecond * 100) nextConvo := newConversationID() - require.NoError(t, db.Save(nextConvo, "another message")) + require.NoError(t, db.Save(nextConvo, "another message", "gpt-4o")) head, err := db.FindHEAD() require.NoError(t, err) @@ -101,8 +101,8 @@ func TestConvoDB(t *testing.T) { t.Run("find by title", func(t *testing.T) { db := testDB(t) - require.NoError(t, db.Save(newConversationID(), "message 1")) - require.NoError(t, db.Save(testid, "message 2")) + require.NoError(t, db.Save(newConversationID(), "message 1", "gpt-4o")) + require.NoError(t, db.Save(testid, "message 2", "gpt-4o")) convo, err := db.Find("message 2") require.NoError(t, err) @@ -112,7 +112,7 @@ func TestConvoDB(t *testing.T) { t.Run("find match nothing", func(t *testing.T) { db := testDB(t) - require.NoError(t, db.Save(testid, "message 1")) + require.NoError(t, db.Save(testid, "message 1", "gpt-4o")) _, err := db.Find("message") require.ErrorIs(t, err, errNoMatches) }) @@ -120,8 +120,8 @@ func TestConvoDB(t *testing.T) { t.Run("find match many", func(t *testing.T) { db := testDB(t) const testid2 = "df31ae23ab9b75b5641c2f846c571000edc71315" - require.NoError(t, db.Save(testid, "message 1")) - require.NoError(t, db.Save(testid2, "message 2")) + require.NoError(t, db.Save(testid, "message 1", "gpt-4o")) + require.NoError(t, db.Save(testid2, "message 2", "gpt-4o")) _, err := db.Find("df31ae") require.ErrorIs(t, err, errManyMatches) }) @@ -129,7 +129,7 @@ func TestConvoDB(t *testing.T) { t.Run("delete", func(t *testing.T) { db := testDB(t) - require.NoError(t, db.Save(testid, "message 1")) + require.NoError(t, db.Save(testid, "message 1", "gpt-4o")) require.NoError(t, db.Delete(newConversationID())) list, err := db.List() @@ -152,8 +152,8 @@ func TestConvoDB(t *testing.T) { const title1 = "some title" const testid2 = "6c33f71694bf41a18c844a96d1f62f153e5f6f44" const title2 = "football teams" - require.NoError(t, db.Save(testid1, title1)) - require.NoError(t, db.Save(testid2, title2)) + require.NoError(t, db.Save(testid1, title1, "gpt-4o")) + require.NoError(t, db.Save(testid2, title2, "gpt-4o")) results, err := db.Completions("f") require.NoError(t, err) diff --git a/main.go b/main.go index 14c98918..5de8e797 100644 --- a/main.go +++ b/main.go @@ -609,6 +609,9 @@ func makeOptions(conversations []Conversation) []huh.Option[string] { timea := stdoutStyles().Timeago.Render(timeago.Of(c.UpdatedAt)) left := stdoutStyles().SHA1.Render(c.ID[:sha1short]) right := stdoutStyles().ConversationList.Render(c.Title, timea) + if c.Model != nil { + right += stdoutStyles().Comment.Render(*c.Model) + } opts = append(opts, huh.NewOption(left+" "+right, c.ID)) } return opts @@ -688,7 +691,7 @@ func saveConversation(mods *Mods) error { stderrStyles().InlineCode.Render("NO_CACHE"), )} } - if err := db.Save(id, title); err != nil { + if err := db.Save(id, title, config.Model); err != nil { _ = cache.delete(id) // remove leftovers return modsError{err, fmt.Sprintf( "There was a problem writing %s to the cache. Use %s / %s to disable it.", diff --git a/mods.go b/mods.go index 157eb056..aa9bdf84 100644 --- a/mods.go +++ b/mods.go @@ -111,6 +111,7 @@ func (m *Mods) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.Config.cacheWriteToID = msg.WriteID m.Config.cacheWriteToTitle = msg.Title m.Config.cacheReadFromID = msg.ReadID + m.Config.Model = msg.Model if !m.Config.Quiet { m.anim = newAnim(m.Config.Fanciness, m.Config.StatusText, m.renderer, m.Styles) @@ -474,7 +475,7 @@ func (m *Mods) receiveCompletionStreamCmd(msg completionOutput) tea.Cmd { } type cacheDetailsMsg struct { - WriteID, Title, ReadID string + WriteID, Title, ReadID, Model string } func (m *Mods) findCacheOpsDetails() tea.Cmd { @@ -483,6 +484,7 @@ func (m *Mods) findCacheOpsDetails() tea.Cmd { readID := ordered.First(m.Config.Continue, m.Config.Show) writeID := ordered.First(m.Config.Title, m.Config.Continue) title := writeID + model := config.Model if readID != "" || continueLast || m.Config.ShowLast { found, err := m.findReadID(readID) @@ -492,7 +494,12 @@ func (m *Mods) findCacheOpsDetails() tea.Cmd { reason: "Could not find the conversation.", } } - readID = found + if found != nil { + readID = found.ID + if found.Model != nil { + model = *found.Model + } + } } // if we are continuing last, update the existing conversation @@ -518,23 +525,24 @@ func (m *Mods) findCacheOpsDetails() tea.Cmd { WriteID: writeID, Title: title, ReadID: readID, + Model: model, } } } -func (m *Mods) findReadID(in string) (string, error) { +func (m *Mods) findReadID(in string) (*Conversation, error) { convo, err := m.db.Find(in) if err == nil { - return convo.ID, nil + return convo, nil } if errors.Is(err, errNoMatches) && m.Config.Show == "" { convo, err := m.db.FindHEAD() if err != nil { - return "", err + return nil, err } - return convo.ID, nil + return convo, nil } - return "", err + return nil, err } func (m *Mods) readStdinCmd() tea.Msg { diff --git a/mods_test.go b/mods_test.go index c30d1d5b..59d3ba0d 100644 --- a/mods_test.go +++ b/mods_test.go @@ -28,7 +28,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("show id", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message")) + require.NoError(t, mods.db.Save(id, "message", "gpt-4")) mods.Config.Show = id[:8] msg := mods.findCacheOpsDetails()() dets := msg.(cacheDetailsMsg) @@ -38,7 +38,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("show title", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message 1")) + require.NoError(t, mods.db.Save(id, "message 1", "gpt-4")) mods.Config.Show = "message 1" msg := mods.findCacheOpsDetails()() dets := msg.(cacheDetailsMsg) @@ -48,7 +48,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("continue id", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message")) + require.NoError(t, mods.db.Save(id, "message", "gpt-4")) mods.Config.Continue = id[:5] mods.Config.Prefix = "prompt" msg := mods.findCacheOpsDetails()() @@ -60,7 +60,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("continue with no prompt", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message 1")) + require.NoError(t, mods.db.Save(id, "message 1", "gpt-4")) mods.Config.ContinueLast = true msg := mods.findCacheOpsDetails()() dets := msg.(cacheDetailsMsg) @@ -72,7 +72,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("continue title", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message 1")) + require.NoError(t, mods.db.Save(id, "message 1", "gpt-4")) mods.Config.Continue = "message 1" mods.Config.Prefix = "prompt" msg := mods.findCacheOpsDetails()() @@ -84,7 +84,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("continue last", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message 1")) + require.NoError(t, mods.db.Save(id, "message 1", "gpt-4")) mods.Config.ContinueLast = true mods.Config.Prefix = "prompt" msg := mods.findCacheOpsDetails()() @@ -97,7 +97,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("continue last with name", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message 1")) + require.NoError(t, mods.db.Save(id, "message 1", "gpt-4")) mods.Config.Continue = "message 2" mods.Config.Prefix = "prompt" msg := mods.findCacheOpsDetails()() @@ -122,7 +122,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("continue id and write with title", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message 1")) + require.NoError(t, mods.db.Save(id, "message 1", "gpt-4")) mods.Config.Title = "some title" mods.Config.Continue = id[:10] msg := mods.findCacheOpsDetails()() @@ -137,7 +137,7 @@ func TestFindCacheOpsDetails(t *testing.T) { t.Run("continue title and write with title", func(t *testing.T) { mods := newMods(t) id := newConversationID() - require.NoError(t, mods.db.Save(id, "message 1")) + require.NoError(t, mods.db.Save(id, "message 1", "gpt-4")) mods.Config.Title = "some title" mods.Config.Continue = "message 1" msg := mods.findCacheOpsDetails()()