Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,30 @@ func openDB(ds string) (*convoDB, error) {
`); err != nil {
return nil, fmt.Errorf("could not migrate db: %w", err)
}

if !hasColumn(db, "model") {
if _, err := db.Exec(`
ALTER TABLE conversations ADD COLUMN model string
`); err != nil {
return nil, fmt.Errorf("could not migrate db: %w", err)
}
}

return &convoDB{db: db}, nil
}

func hasColumn(db *sqlx.DB, col string) bool {
var count int
if err := db.Get(&count, `
SELECT count(*)
FROM pragma_table_info('conversations') c
WHERE c.name = $1
`, col); err != nil {
return false
}
return count > 0
}

type convoDB struct {
db *sqlx.DB
}
Expand All @@ -74,21 +95,23 @@ type Conversation struct {
ID string `db:"id"`
Title string `db:"title"`
UpdatedAt time.Time `db:"updated_at"`
Model *string `db:"model"`
}

func (c *convoDB) Close() error {
return c.db.Close() //nolint: wrapcheck
}

func (c *convoDB) Save(id, title string) error {
func (c *convoDB) Save(id, title, model string) error {
res, err := c.db.Exec(c.db.Rebind(`
UPDATE conversations
SET
title = ?,
model = ?,
updated_at = CURRENT_TIMESTAMP
WHERE
id = ?
`), title, id)
`), title, model, id)
if err != nil {
return fmt.Errorf("Save: %w", err)
}
Expand All @@ -104,10 +127,10 @@ func (c *convoDB) Save(id, title string) error {

if _, err := c.db.Exec(c.db.Rebind(`
INSERT INTO
conversations (id, title)
conversations (id, title, model)
VALUES
(?, ?)
`), id, title); err != nil {
(?, ?, ?)
`), id, title, model); err != nil {
return fmt.Errorf("Save: %w", err)
}

Expand Down
32 changes: 16 additions & 16 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestConvoDB(t *testing.T) {
t.Run("save", func(t *testing.T) {
db := testDB(t)

require.NoError(t, db.Save(testid, "message 1"))
require.NoError(t, db.Save(testid, "message 1", "gpt-4o"))

convo, err := db.Find("df31")
require.NoError(t, err)
Expand All @@ -44,20 +44,20 @@ func TestConvoDB(t *testing.T) {

t.Run("save no id", func(t *testing.T) {
db := testDB(t)
require.Error(t, db.Save("", "message 1"))
require.Error(t, db.Save("", "message 1", "gpt-4o"))
})

t.Run("save no message", func(t *testing.T) {
db := testDB(t)
require.Error(t, db.Save(newConversationID(), ""))
require.Error(t, db.Save(newConversationID(), "", "gpt-4o"))
})

t.Run("update", func(t *testing.T) {
db := testDB(t)

require.NoError(t, db.Save(testid, "message 1"))
require.NoError(t, db.Save(testid, "message 1", "gpt-4o"))
time.Sleep(100 * time.Millisecond)
require.NoError(t, db.Save(testid, "message 2"))
require.NoError(t, db.Save(testid, "message 2", "gpt-4o"))

convo, err := db.Find("df31")
require.NoError(t, err)
Expand All @@ -72,7 +72,7 @@ func TestConvoDB(t *testing.T) {
t.Run("find head single", func(t *testing.T) {
db := testDB(t)

require.NoError(t, db.Save(testid, "message 2"))
require.NoError(t, db.Save(testid, "message 2", "gpt-4o"))

head, err := db.FindHEAD()
require.NoError(t, err)
Expand All @@ -83,10 +83,10 @@ func TestConvoDB(t *testing.T) {
t.Run("find head multiple", func(t *testing.T) {
db := testDB(t)

require.NoError(t, db.Save(testid, "message 2"))
require.NoError(t, db.Save(testid, "message 2", "gpt-4o"))
time.Sleep(time.Millisecond * 100)
nextConvo := newConversationID()
require.NoError(t, db.Save(nextConvo, "another message"))
require.NoError(t, db.Save(nextConvo, "another message", "gpt-4o"))

head, err := db.FindHEAD()
require.NoError(t, err)
Expand All @@ -101,8 +101,8 @@ func TestConvoDB(t *testing.T) {
t.Run("find by title", func(t *testing.T) {
db := testDB(t)

require.NoError(t, db.Save(newConversationID(), "message 1"))
require.NoError(t, db.Save(testid, "message 2"))
require.NoError(t, db.Save(newConversationID(), "message 1", "gpt-4o"))
require.NoError(t, db.Save(testid, "message 2", "gpt-4o"))

convo, err := db.Find("message 2")
require.NoError(t, err)
Expand All @@ -112,24 +112,24 @@ func TestConvoDB(t *testing.T) {

t.Run("find match nothing", func(t *testing.T) {
db := testDB(t)
require.NoError(t, db.Save(testid, "message 1"))
require.NoError(t, db.Save(testid, "message 1", "gpt-4o"))
_, err := db.Find("message")
require.ErrorIs(t, err, errNoMatches)
})

t.Run("find match many", func(t *testing.T) {
db := testDB(t)
const testid2 = "df31ae23ab9b75b5641c2f846c571000edc71315"
require.NoError(t, db.Save(testid, "message 1"))
require.NoError(t, db.Save(testid2, "message 2"))
require.NoError(t, db.Save(testid, "message 1", "gpt-4o"))
require.NoError(t, db.Save(testid2, "message 2", "gpt-4o"))
_, err := db.Find("df31ae")
require.ErrorIs(t, err, errManyMatches)
})

t.Run("delete", func(t *testing.T) {
db := testDB(t)

require.NoError(t, db.Save(testid, "message 1"))
require.NoError(t, db.Save(testid, "message 1", "gpt-4o"))
require.NoError(t, db.Delete(newConversationID()))

list, err := db.List()
Expand All @@ -152,8 +152,8 @@ func TestConvoDB(t *testing.T) {
const title1 = "some title"
const testid2 = "6c33f71694bf41a18c844a96d1f62f153e5f6f44"
const title2 = "football teams"
require.NoError(t, db.Save(testid1, title1))
require.NoError(t, db.Save(testid2, title2))
require.NoError(t, db.Save(testid1, title1, "gpt-4o"))
require.NoError(t, db.Save(testid2, title2, "gpt-4o"))

results, err := db.Completions("f")
require.NoError(t, err)
Expand Down
5 changes: 4 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,9 @@ func makeOptions(conversations []Conversation) []huh.Option[string] {
timea := stdoutStyles().Timeago.Render(timeago.Of(c.UpdatedAt))
left := stdoutStyles().SHA1.Render(c.ID[:sha1short])
right := stdoutStyles().ConversationList.Render(c.Title, timea)
if c.Model != nil {
right += stdoutStyles().Comment.Render(*c.Model)
}
opts = append(opts, huh.NewOption(left+" "+right, c.ID))
}
return opts
Expand Down Expand Up @@ -688,7 +691,7 @@ func saveConversation(mods *Mods) error {
stderrStyles().InlineCode.Render("NO_CACHE"),
)}
}
if err := db.Save(id, title); err != nil {
if err := db.Save(id, title, config.Model); err != nil {
_ = cache.delete(id) // remove leftovers
return modsError{err, fmt.Sprintf(
"There was a problem writing %s to the cache. Use %s / %s to disable it.",
Expand Down
22 changes: 15 additions & 7 deletions mods.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func (m *Mods) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.Config.cacheWriteToID = msg.WriteID
m.Config.cacheWriteToTitle = msg.Title
m.Config.cacheReadFromID = msg.ReadID
m.Config.Model = msg.Model

if !m.Config.Quiet {
m.anim = newAnim(m.Config.Fanciness, m.Config.StatusText, m.renderer, m.Styles)
Expand Down Expand Up @@ -474,7 +475,7 @@ func (m *Mods) receiveCompletionStreamCmd(msg completionOutput) tea.Cmd {
}

type cacheDetailsMsg struct {
WriteID, Title, ReadID string
WriteID, Title, ReadID, Model string
}

func (m *Mods) findCacheOpsDetails() tea.Cmd {
Expand All @@ -483,6 +484,7 @@ func (m *Mods) findCacheOpsDetails() tea.Cmd {
readID := ordered.First(m.Config.Continue, m.Config.Show)
writeID := ordered.First(m.Config.Title, m.Config.Continue)
title := writeID
model := config.Model

if readID != "" || continueLast || m.Config.ShowLast {
found, err := m.findReadID(readID)
Expand All @@ -492,7 +494,12 @@ func (m *Mods) findCacheOpsDetails() tea.Cmd {
reason: "Could not find the conversation.",
}
}
readID = found
if found != nil {
readID = found.ID
if found.Model != nil {
model = *found.Model
}
}
}

// if we are continuing last, update the existing conversation
Expand All @@ -518,23 +525,24 @@ func (m *Mods) findCacheOpsDetails() tea.Cmd {
WriteID: writeID,
Title: title,
ReadID: readID,
Model: model,
}
}
}

func (m *Mods) findReadID(in string) (string, error) {
func (m *Mods) findReadID(in string) (*Conversation, error) {
convo, err := m.db.Find(in)
if err == nil {
return convo.ID, nil
return convo, nil
}
if errors.Is(err, errNoMatches) && m.Config.Show == "" {
convo, err := m.db.FindHEAD()
if err != nil {
return "", err
return nil, err
}
return convo.ID, nil
return convo, nil
}
return "", err
return nil, err
}

func (m *Mods) readStdinCmd() tea.Msg {
Expand Down
18 changes: 9 additions & 9 deletions mods_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("show id", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message"))
require.NoError(t, mods.db.Save(id, "message", "gpt-4"))
mods.Config.Show = id[:8]
msg := mods.findCacheOpsDetails()()
dets := msg.(cacheDetailsMsg)
Expand All @@ -38,7 +38,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("show title", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message 1"))
require.NoError(t, mods.db.Save(id, "message 1", "gpt-4"))
mods.Config.Show = "message 1"
msg := mods.findCacheOpsDetails()()
dets := msg.(cacheDetailsMsg)
Expand All @@ -48,7 +48,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("continue id", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message"))
require.NoError(t, mods.db.Save(id, "message", "gpt-4"))
mods.Config.Continue = id[:5]
mods.Config.Prefix = "prompt"
msg := mods.findCacheOpsDetails()()
Expand All @@ -60,7 +60,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("continue with no prompt", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message 1"))
require.NoError(t, mods.db.Save(id, "message 1", "gpt-4"))
mods.Config.ContinueLast = true
msg := mods.findCacheOpsDetails()()
dets := msg.(cacheDetailsMsg)
Expand All @@ -72,7 +72,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("continue title", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message 1"))
require.NoError(t, mods.db.Save(id, "message 1", "gpt-4"))
mods.Config.Continue = "message 1"
mods.Config.Prefix = "prompt"
msg := mods.findCacheOpsDetails()()
Expand All @@ -84,7 +84,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("continue last", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message 1"))
require.NoError(t, mods.db.Save(id, "message 1", "gpt-4"))
mods.Config.ContinueLast = true
mods.Config.Prefix = "prompt"
msg := mods.findCacheOpsDetails()()
Expand All @@ -97,7 +97,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("continue last with name", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message 1"))
require.NoError(t, mods.db.Save(id, "message 1", "gpt-4"))
mods.Config.Continue = "message 2"
mods.Config.Prefix = "prompt"
msg := mods.findCacheOpsDetails()()
Expand All @@ -122,7 +122,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("continue id and write with title", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message 1"))
require.NoError(t, mods.db.Save(id, "message 1", "gpt-4"))
mods.Config.Title = "some title"
mods.Config.Continue = id[:10]
msg := mods.findCacheOpsDetails()()
Expand All @@ -137,7 +137,7 @@ func TestFindCacheOpsDetails(t *testing.T) {
t.Run("continue title and write with title", func(t *testing.T) {
mods := newMods(t)
id := newConversationID()
require.NoError(t, mods.db.Save(id, "message 1"))
require.NoError(t, mods.db.Save(id, "message 1", "gpt-4"))
mods.Config.Title = "some title"
mods.Config.Continue = "message 1"
msg := mods.findCacheOpsDetails()()
Expand Down