Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type LLMConfig struct {
MaxTokens int `yaml:"max_tokens" json:"max_tokens" jsonschema:"default=500,description=Maximum tokens in response"`
Timeout time.Duration `yaml:"timeout" json:"timeout" jsonschema:"default=30s,description=Request timeout"`
SystemPrompt string `yaml:"system_prompt" json:"system_prompt" jsonschema:"description=System prompt for the LLM (optional)"`
UseStreaming bool `yaml:"use_streaming" json:"use_streaming" jsonschema:"default=false,description=Use streaming mode (required by some providers e.g. ChatGPT subscription via litellm)"`
Classification ClassificationConfig `yaml:"classification" json:"classification" jsonschema:"description=Classification-specific settings"`
}

Expand Down
74 changes: 71 additions & 3 deletions pkg/llm/classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package llm
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"strings"
"time"
Expand Down Expand Up @@ -59,6 +61,72 @@ func (c *Classifier) GetBatchTimeout() time.Duration {
return 5 * time.Second // default
}

// createChatCompletion dispatches to streaming or non-streaming based on config.
// Streaming mode is required by some providers (e.g. ChatGPT subscription via litellm)
// whose non-streaming response path is broken. The streamed deltas are accumulated
// into a ChatCompletionResponse with the same shape as the non-streaming API returns.
func (c *Classifier) createChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
	// non-streaming path: hand the request straight to the client
	if !c.config.UseStreaming {
		return c.client.CreateChatCompletion(ctx, req)
	}

	// streaming path: ask for usage in the trailing chunk so token accounting survives
	req.Stream = true
	req.StreamOptions = &openai.StreamOptions{IncludeUsage: true}

	stream, err := c.client.CreateChatCompletionStream(ctx, req)
	if err != nil {
		return openai.ChatCompletionResponse{}, err
	}
	defer stream.Close()

	var (
		body, thoughts strings.Builder
		msgRole        string
		finish         string
	)
	result := openai.ChatCompletionResponse{Object: "chat.completion"}
	for {
		chunk, recvErr := stream.Recv()
		if errors.Is(recvErr, io.EOF) {
			break // normal end of stream
		}
		if recvErr != nil {
			return openai.ChatCompletionResponse{}, recvErr
		}

		// capture response metadata from the first chunk that carries it
		if result.ID == "" {
			result.ID, result.Created, result.Model = chunk.ID, chunk.Created, chunk.Model
		}
		// usage arrives in a final chunk when IncludeUsage is set
		if chunk.Usage != nil {
			result.Usage = *chunk.Usage
		}
		if len(chunk.Choices) == 0 {
			continue
		}

		choice := chunk.Choices[0]
		if choice.Delta.Role != "" {
			msgRole = choice.Delta.Role
		}
		body.WriteString(choice.Delta.Content)
		thoughts.WriteString(choice.Delta.ReasoningContent)
		if choice.FinishReason != "" {
			finish = string(choice.FinishReason)
		}
	}

	// providers may omit the role delta entirely; default to assistant
	if msgRole == "" {
		msgRole = openai.ChatMessageRoleAssistant
	}
	result.Choices = []openai.ChatCompletionChoice{{
		Index: 0,
		Message: openai.ChatCompletionMessage{
			Role:             msgRole,
			Content:          body.String(),
			ReasoningContent: thoughts.String(),
		},
		FinishReason: openai.FinishReason(finish),
	}}
	return result, nil
}

// default system prompt for article classification
const defaultSystemPrompt = `You are an AI assistant that evaluates articles for relevance to the user's interests.
Rate each article from 0-10 where:
Expand Down Expand Up @@ -196,7 +264,7 @@ func (c *Classifier) classify(ctx context.Context, req ClassifyRequest) ([]domai
}

// call the LLM
resp, err := c.client.CreateChatCompletion(ctx, chatReq)
resp, err := c.createChatCompletion(ctx, chatReq)
if err != nil {
// all errors will be retried by repeater
return fmt.Errorf("llm request failed: %w", err)
Expand Down Expand Up @@ -535,7 +603,7 @@ func (c *Classifier) GeneratePreferenceSummary(ctx context.Context, feedback []d
},
}

resp, err := c.client.CreateChatCompletion(ctx, req)
resp, err := c.createChatCompletion(ctx, req)
if err != nil {
return fmt.Errorf("generate preference summary failed: %w", err)
}
Expand Down Expand Up @@ -620,7 +688,7 @@ func (c *Classifier) UpdatePreferenceSummary(ctx context.Context, currentSummary
},
}

resp, err := c.client.CreateChatCompletion(ctx, req)
resp, err := c.createChatCompletion(ctx, req)
if err != nil {
return fmt.Errorf("update preference summary failed: %w", err)
}
Expand Down
80 changes: 80 additions & 0 deletions pkg/llm/classifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1226,3 +1226,83 @@ func TestClassifier_CustomForbiddenPrefixes(t *testing.T) {
assert.False(t, classifier.hasForbiddenPrefix("The article discusses")) // not in custom list
assert.False(t, classifier.hasForbiddenPrefix("Results indicate"))
}

// TestClassifier_Streaming verifies that when UseStreaming is enabled the
// classifier sends stream=true and correctly reassembles SSE delta chunks
// (content split across chunks, trailing usage chunk, [DONE] terminator)
// into a response that the normal classification parsing accepts.
func TestClassifier_Streaming(t *testing.T) {
	// SSE response covering the 2 articles from the classify prompt
	// NOTE: the blank line after each "data:" line is part of the SSE framing.
	sseBody := `data: {"id":"cmpl-1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{"role":"assistant","content":"["}}]}

data: {"id":"cmpl-1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{"content":"{\"guid\":\"item1\",\"score\":7,\"explanation\":\"ok\",\"topics\":[\"go\"],\"summary\":\"Go 1.22 ships range-over-func iterators with 50% faster compilation, better GC, and new toolchain management that simplifies Go version control across projects and CI pipelines today.\"}"}}]}

data: {"id":"cmpl-1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{"content":",{\"guid\":\"item2\",\"score\":2,\"explanation\":\"no\",\"topics\":[\"sports\"],\"summary\":\"Manchester United beats Chelsea 3-1 with Bruno Fernandes scoring twice while Liverpool holds top spot after 2-2 draw with Arsenal at Emirates Stadium in Premier League weekend action.\"}]"}}]}

data: {"id":"cmpl-1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":20,"total_tokens":30}}

data: [DONE]

`

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// assert (not require) inside the handler: it runs on a non-test goroutine
		assert.Equal(t, "/v1/chat/completions", r.URL.Path)
		var body map[string]any
		assert.NoError(t, json.NewDecoder(r.Body).Decode(&body))
		assert.Equal(t, true, body["stream"], "stream must be true when UseStreaming is set")

		w.Header().Set("Content-Type", "text/event-stream")
		_, err := w.Write([]byte(sseBody))
		assert.NoError(t, err)
	}))
	defer server.Close()

	cfg := config.LLMConfig{
		Endpoint:     server.URL + "/v1",
		APIKey:       "test-key",
		Model:        "gpt-test",
		Temperature:  0.3,
		MaxTokens:    100,
		UseStreaming: true, // the behavior under test
	}
	classifier := NewClassifier(cfg)

	classifications, err := classifier.ClassifyItems(context.Background(), ClassifyRequest{
		Articles: []domain.Item{
			{GUID: "item1", Title: "Go 1.22 released"},
			{GUID: "item2", Title: "Football news"},
		},
	})
	require.NoError(t, err)
	require.Len(t, classifications, 2)
	assert.Equal(t, "item1", classifications[0].GUID)
	assert.InDelta(t, 7.0, classifications[0].Score, 0.01)
	assert.Equal(t, "item2", classifications[1].GUID)
}

// TestClassifier_Streaming_DefaultOff checks that with UseStreaming left at its
// zero value the classifier issues a plain (non-streaming) completion request.
func TestClassifier_Streaming_DefaultOff(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// the outgoing request must not ask for streaming
		var payload map[string]any
		assert.NoError(t, json.NewDecoder(r.Body).Decode(&payload))
		streaming, _ := payload["stream"].(bool)
		assert.False(t, streaming)

		reply := openai.ChatCompletionResponse{
			Choices: []openai.ChatCompletionChoice{{
				Message: openai.ChatCompletionMessage{Content: `[{"guid":"x","score":1,"explanation":"e","topics":["t"],"summary":"Short and punchy summary covering the main points without preamble and staying within the required 300-500 character window so the classifier accepts it."}]`},
			}},
		}
		w.Header().Set("Content-Type", "application/json")
		assert.NoError(t, json.NewEncoder(w).Encode(reply))
	}))
	defer ts.Close()

	classifier := NewClassifier(config.LLMConfig{
		Endpoint: ts.URL + "/v1",
		APIKey:   "test-key",
		Model:    "gpt-test",
	})

	_, err := classifier.ClassifyItems(context.Background(), ClassifyRequest{
		Articles: []domain.Item{{GUID: "x", Title: "t"}},
	})
	require.NoError(t, err)
}