From 8fe78b7e7ecc45d166335e4ea11e099b9748323d Mon Sep 17 00:00:00 2001 From: Caspar Chou Date: Sun, 19 Apr 2026 16:06:22 +0800 Subject: [PATCH] add use_streaming config to use chat completion streaming --- pkg/config/config.go | 1 + pkg/llm/classifier.go | 74 +++++++++++++++++++++++++++++++++-- pkg/llm/classifier_test.go | 80 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 3 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 2e164d0..19ea942 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -70,6 +70,7 @@ type LLMConfig struct { MaxTokens int `yaml:"max_tokens" json:"max_tokens" jsonschema:"default=500,description=Maximum tokens in response"` Timeout time.Duration `yaml:"timeout" json:"timeout" jsonschema:"default=30s,description=Request timeout"` SystemPrompt string `yaml:"system_prompt" json:"system_prompt" jsonschema:"description=System prompt for the LLM (optional)"` + UseStreaming bool `yaml:"use_streaming" json:"use_streaming" jsonschema:"default=false,description=Use streaming mode (required by some providers e.g. ChatGPT subscription via litellm)"` Classification ClassificationConfig `yaml:"classification" json:"classification" jsonschema:"description=Classification-specific settings"` } diff --git a/pkg/llm/classifier.go b/pkg/llm/classifier.go index 4f19d2b..e7cf30d 100644 --- a/pkg/llm/classifier.go +++ b/pkg/llm/classifier.go @@ -3,7 +3,9 @@ package llm import ( "context" "encoding/json" + "errors" "fmt" + "io" "log" "strings" "time" @@ -59,6 +61,72 @@ func (c *Classifier) GetBatchTimeout() time.Duration { return 5 * time.Second // default } +// createChatCompletion dispatches to streaming or non-streaming based on config. +// Streaming mode is required by some providers (e.g. ChatGPT subscription via litellm) +// whose non-streaming response path is broken. The streamed deltas are accumulated +// into a ChatCompletionResponse with the same shape as the non-streaming API returns. 
+func (c *Classifier) createChatCompletion(ctx context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
+	if !c.config.UseStreaming {
+		return c.client.CreateChatCompletion(ctx, req)
+	}
+
+	req.Stream = true // req is a by-value copy, so the caller's request stays untouched
+	req.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
+
+	stream, err := c.client.CreateChatCompletionStream(ctx, req)
+	if err != nil {
+		return openai.ChatCompletionResponse{}, err
+	}
+	defer stream.Close()
+
+	var text, thinking strings.Builder
+	var msgRole string
+	var finish openai.FinishReason
+	result := openai.ChatCompletionResponse{Object: "chat.completion"}
+	for {
+		part, recvErr := stream.Recv()
+		if recvErr != nil {
+			if errors.Is(recvErr, io.EOF) {
+				break // EOF is treated as the normal end of the stream
+			}
+			return openai.ChatCompletionResponse{}, recvErr
+		}
+		if result.ID == "" { // keep copying metadata until a chunk supplies a non-empty ID
+			result.ID, result.Created, result.Model = part.ID, part.Created, part.Model
+		}
+		if part.Usage != nil {
+			result.Usage = *part.Usage
+		}
+		if len(part.Choices) == 0 {
+			continue // chunk has no choices, nothing to accumulate
+		}
+		first := part.Choices[0] // only the first choice of each chunk is accumulated
+		if first.Delta.Role != "" {
+			msgRole = first.Delta.Role
+		}
+		text.WriteString(first.Delta.Content)
+		thinking.WriteString(first.Delta.ReasoningContent)
+		if first.FinishReason != "" {
+			finish = openai.FinishReason(first.FinishReason)
+		}
+	}
+
+	if msgRole == "" { // no delta carried a role, default to assistant
+		msgRole = openai.ChatMessageRoleAssistant
+	}
+	message := openai.ChatCompletionMessage{
+		Role:             msgRole,
+		Content:          text.String(),
+		ReasoningContent: thinking.String(),
+	}
+	result.Choices = []openai.ChatCompletionChoice{{
+		Index:        0,
+		Message:      message,
+		FinishReason: finish,
+	}}
+	return result, nil
+}
+
 // default system prompt for article classification
 const defaultSystemPrompt = `You are an AI assistant that evaluates articles for relevance to the user's interests.
Rate each article from 0-10 where: @@ -196,7 +264,7 @@ func (c *Classifier) classify(ctx context.Context, req ClassifyRequest) ([]domai } // call the LLM - resp, err := c.client.CreateChatCompletion(ctx, chatReq) + resp, err := c.createChatCompletion(ctx, chatReq) if err != nil { // all errors will be retried by repeater return fmt.Errorf("llm request failed: %w", err) @@ -535,7 +603,7 @@ func (c *Classifier) GeneratePreferenceSummary(ctx context.Context, feedback []d }, } - resp, err := c.client.CreateChatCompletion(ctx, req) + resp, err := c.createChatCompletion(ctx, req) if err != nil { return fmt.Errorf("generate preference summary failed: %w", err) } @@ -620,7 +688,7 @@ func (c *Classifier) UpdatePreferenceSummary(ctx context.Context, currentSummary }, } - resp, err := c.client.CreateChatCompletion(ctx, req) + resp, err := c.createChatCompletion(ctx, req) if err != nil { return fmt.Errorf("update preference summary failed: %w", err) }
diff --git a/pkg/llm/classifier_test.go b/pkg/llm/classifier_test.go
index 95bbd37..c62df76 100644
--- a/pkg/llm/classifier_test.go
+++ b/pkg/llm/classifier_test.go
@@ -1226,3 +1226,83 @@ func TestClassifier_CustomForbiddenPrefixes(t *testing.T) {
 	assert.False(t, classifier.hasForbiddenPrefix("The article discusses")) // not in custom list
 	assert.False(t, classifier.hasForbiddenPrefix("Results indicate"))
 }
+
+func TestClassifier_Streaming(t *testing.T) {
+	// SSE stream: three content deltas that concatenate to a JSON array scoring item1 and item2, a finish chunk carrying usage, then [DONE]
+	sseBody := `data: {"id":"cmpl-1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{"role":"assistant","content":"["}}]}
+
+data: {"id":"cmpl-1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{"content":"{\"guid\":\"item1\",\"score\":7,\"explanation\":\"ok\",\"topics\":[\"go\"],\"summary\":\"Go 1.22 ships range-over-func iterators with 50% faster compilation, better GC, and new toolchain management that simplifies Go version control across projects and CI pipelines today.\"}"}}]}
+
+data: {"id":"cmpl-1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{"content":",{\"guid\":\"item2\",\"score\":2,\"explanation\":\"no\",\"topics\":[\"sports\"],\"summary\":\"Manchester United beats Chelsea 3-1 with Bruno Fernandes scoring twice while Liverpool holds top spot after 2-2 draw with Arsenal at Emirates Stadium in Premier League weekend action.\"}]"}}]}
+
+data: {"id":"cmpl-1","object":"chat.completion.chunk","created":1,"model":"m","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":20,"total_tokens":30}}
+
+data: [DONE]
+
+`
+
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		assert.Equal(t, "/v1/chat/completions", r.URL.Path)
+		var body map[string]any
+		assert.NoError(t, json.NewDecoder(r.Body).Decode(&body))
+		assert.Equal(t, true, body["stream"], "stream must be true when UseStreaming is set")
+
+		w.Header().Set("Content-Type", "text/event-stream")
+		_, err := w.Write([]byte(sseBody))
+		assert.NoError(t, err)
+	}))
+	defer server.Close()
+
+	cfg := config.LLMConfig{
+		Endpoint:     server.URL + "/v1",
+		APIKey:       "test-key",
+		Model:        "gpt-test",
+		Temperature:  0.3,
+		MaxTokens:    100,
+		UseStreaming: true,
+	}
+	classifier := NewClassifier(cfg)
+
+	classifications, err := classifier.ClassifyItems(context.Background(), ClassifyRequest{
+		Articles: []domain.Item{
+			{GUID: "item1", Title: "Go 1.22 released"},
+			{GUID: "item2", Title: "Football news"},
+		},
+	})
+	require.NoError(t, err)
+	require.Len(t, classifications, 2)
+	assert.Equal(t, "item1", classifications[0].GUID)
+	assert.InDelta(t, 7.0, classifications[0].Score, 0.01)
+	assert.Equal(t, "item2", classifications[1].GUID)
+}
+
+func TestClassifier_Streaming_DefaultOff(t *testing.T) {
+	// UseStreaming is left unset here, so the request body must not carry stream=true
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		var body map[string]any
+		assert.NoError(t, json.NewDecoder(r.Body).Decode(&body))
+		stream, _ := body["stream"].(bool)
+		assert.False(t, stream)
+
+		resp := openai.ChatCompletionResponse{
+			Choices: []openai.ChatCompletionChoice{{
+				Message: openai.ChatCompletionMessage{Content: `[{"guid":"x","score":1,"explanation":"e","topics":["t"],"summary":"Short and punchy summary covering the main points without preamble and staying within the required 300-500 character window so the classifier accepts it."}]`},
+			}},
+		}
+		w.Header().Set("Content-Type", "application/json")
+		assert.NoError(t, json.NewEncoder(w).Encode(resp))
+	}))
+	defer server.Close()
+
+	cfg := config.LLMConfig{
+		Endpoint: server.URL + "/v1",
+		APIKey:   "test-key",
+		Model:    "gpt-test",
+	}
+	classifier := NewClassifier(cfg)
+
+	_, err := classifier.ClassifyItems(context.Background(), ClassifyRequest{
+		Articles: []domain.Item{{GUID: "x", Title: "t"}},
+	})
+	require.NoError(t, err)
+}