diff --git a/internal/adapter/claude/standard_request.go b/internal/adapter/claude/standard_request.go index 7b16c96..d73ffda 100644 --- a/internal/adapter/claude/standard_request.go +++ b/internal/adapter/claude/standard_request.go @@ -36,7 +36,7 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma thinkingEnabled = false searchEnabled = false } - finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"])) + finalPrompt := deepseek.MessagesPrepareWithThinking(toMessageMaps(dsPayload["messages"]), thinkingEnabled) toolNames := extractClaudeToolNames(toolsRequested) if len(toolNames) == 0 && len(toolsRequested) > 0 { toolNames = []string{"__any_tool__"} diff --git a/internal/adapter/gemini/convert_request.go b/internal/adapter/gemini/convert_request.go index 34eb2a2..5a9ff95 100644 --- a/internal/adapter/gemini/convert_request.go +++ b/internal/adapter/gemini/convert_request.go @@ -28,7 +28,7 @@ func normalizeGeminiRequest(store ConfigReader, routeModel string, req map[strin } toolsRaw := convertGeminiTools(req["tools"]) - finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "") + finalPrompt, toolNames := openai.BuildPromptForAdapter(messagesRaw, toolsRaw, "", thinkingEnabled) passThrough := collectGeminiPassThrough(req) return util.StandardRequest{ diff --git a/internal/adapter/openai/prompt_build.go b/internal/adapter/openai/prompt_build.go index d6823b2..2e1d891 100644 --- a/internal/adapter/openai/prompt_build.go +++ b/internal/adapter/openai/prompt_build.go @@ -5,22 +5,22 @@ import ( "ds2api/internal/util" ) -func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string) (string, []string) { - return buildOpenAIFinalPromptWithPolicy(messagesRaw, toolsRaw, traceID, util.DefaultToolChoicePolicy()) +func buildOpenAIFinalPrompt(messagesRaw []any, toolsRaw any, traceID string, thinkingEnabled bool) (string, []string) { + return buildOpenAIFinalPromptWithPolicy(messagesRaw, 
toolsRaw, traceID, util.DefaultToolChoicePolicy(), thinkingEnabled) } -func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy) (string, []string) { +func buildOpenAIFinalPromptWithPolicy(messagesRaw []any, toolsRaw any, traceID string, toolPolicy util.ToolChoicePolicy, thinkingEnabled bool) (string, []string) { messages := normalizeOpenAIMessagesForPrompt(messagesRaw, traceID) toolNames := []string{} if tools, ok := toolsRaw.([]any); ok && len(tools) > 0 { messages, toolNames = injectToolPrompt(messages, tools, toolPolicy) } - return deepseek.MessagesPrepare(messages), toolNames + return deepseek.MessagesPrepareWithThinking(messages, thinkingEnabled), toolNames } // BuildPromptForAdapter exposes the OpenAI-compatible prompt building flow so // other protocol adapters (for example Gemini) can reuse the same tool/history // normalization logic and remain behavior-compatible with chat/completions. -func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string) (string, []string) { - return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID) +func BuildPromptForAdapter(messagesRaw []any, toolsRaw any, traceID string, thinkingEnabled bool) (string, []string) { + return buildOpenAIFinalPrompt(messagesRaw, toolsRaw, traceID, thinkingEnabled) } diff --git a/internal/adapter/openai/prompt_build_test.go b/internal/adapter/openai/prompt_build_test.go index 390cbd4..724fef8 100644 --- a/internal/adapter/openai/prompt_build_test.go +++ b/internal/adapter/openai/prompt_build_test.go @@ -40,7 +40,7 @@ func TestBuildOpenAIFinalPrompt_HandlerPathIncludesToolRoundtripSemantics(t *tes }, } - finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "") + finalPrompt, toolNames := buildOpenAIFinalPrompt(messages, tools, "", false) if len(toolNames) != 1 || toolNames[0] != "get_weather" { t.Fatalf("unexpected tool names: %#v", toolNames) } @@ -73,7 +73,7 @@ func 
TestBuildOpenAIFinalPrompt_VercelPreparePathKeepsFinalAnswerInstruction(t * }, } - finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "") + finalPrompt, _ := buildOpenAIFinalPrompt(messages, tools, "", false) if !strings.Contains(finalPrompt, "Remember: Output ONLY the ... XML block when calling tools.") { t.Fatalf("vercel prepare finalPrompt missing final tool-call anchor instruction: %q", finalPrompt) } diff --git a/internal/adapter/openai/standard_request.go b/internal/adapter/openai/standard_request.go index e9904f7..39ccb01 100644 --- a/internal/adapter/openai/standard_request.go +++ b/internal/adapter/openai/standard_request.go @@ -24,7 +24,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID responseModel = resolvedModel } toolPolicy := util.DefaultToolChoicePolicy() - finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy) + finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy, thinkingEnabled) toolNames = ensureToolDetectionEnabled(toolNames, req["tools"]) passThrough := collectOpenAIChatPassThrough(req) @@ -74,7 +74,7 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra if err != nil { return util.StandardRequest{}, err } - finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy) + finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy, thinkingEnabled) toolNames = ensureToolDetectionEnabled(toolNames, req["tools"]) if !toolPolicy.IsNone() { toolPolicy.Allowed = namesToSet(toolNames) diff --git a/internal/deepseek/prompt.go b/internal/deepseek/prompt.go index 2410390..77fd36f 100644 --- a/internal/deepseek/prompt.go +++ b/internal/deepseek/prompt.go @@ -5,3 +5,7 @@ import "ds2api/internal/prompt" func MessagesPrepare(messages []map[string]any) string { return 
prompt.MessagesPrepare(messages) } + +func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool) string { + return prompt.MessagesPrepareWithThinking(messages, thinkingEnabled) +} diff --git a/internal/prompt/messages.go b/internal/prompt/messages.go index fe69f72..91a3b84 100644 --- a/internal/prompt/messages.go +++ b/internal/prompt/messages.go @@ -10,6 +10,7 @@ import ( var markdownImagePattern = regexp.MustCompile(`!\[(.*?)\]\((.*?)\)`) const ( + beginSentenceMarker = "<|begin▁of▁sentence|>" systemMarker = "<|System|>" userMarker = "<|User|>" assistantMarker = "<|Assistant|>" @@ -17,9 +18,15 @@ const ( endSentenceMarker = "<|end▁of▁sentence|>" endToolResultsMarker = "<|end▁of▁toolresults|>" endInstructionsMarker = "<|end▁of▁instructions|>" + openThinkMarker = "<think>" + closeThinkMarker = "</think>" ) func MessagesPrepare(messages []map[string]any) string { + return MessagesPrepareWithThinking(messages, false) +} + +func MessagesPrepareWithThinking(messages []map[string]any, thinkingEnabled bool) string { type block struct { Role string Text string } @@ -41,11 +48,14 @@ func MessagesPrepare(messages []map[string]any) string { } merged = append(merged, msg) } - parts := make([]string, 0, len(merged)) + parts := make([]string, 0, len(merged)+2) + parts = append(parts, beginSentenceMarker) + lastRole := "" for _, m := range merged { + lastRole = m.Role switch m.Role { case "assistant": - parts = append(parts, formatRoleBlock(assistantMarker, m.Text, endSentenceMarker)) + parts = append(parts, formatRoleBlock(assistantMarker, closeThinkMarker+m.Text, endSentenceMarker)) case "tool": if strings.TrimSpace(m.Text) != "" { parts = append(parts, formatRoleBlock(toolMarker, m.Text, endToolResultsMarker)) @@ -62,6 +72,13 @@ func MessagesPrepare(messages []map[string]any) string { } } } + if lastRole != "assistant" { + thinkPrefix := closeThinkMarker + if thinkingEnabled { + thinkPrefix = openThinkMarker + } + parts = append(parts, assistantMarker+thinkPrefix) + } 
out := strings.Join(parts, "\n\n") return markdownImagePattern.ReplaceAllString(out, `[${1}](${2})`) } diff --git a/internal/prompt/messages_test.go b/internal/prompt/messages_test.go index 5465c7a..a86f9db 100644 --- a/internal/prompt/messages_test.go +++ b/internal/prompt/messages_test.go @@ -32,13 +32,16 @@ func TestMessagesPrepareUsesTurnSuffixes(t *testing.T) { {"role": "assistant", "content": "Answer"}, } got := MessagesPrepare(messages) + if !strings.HasPrefix(got, "<|begin▁of▁sentence|>") { + t.Fatalf("expected begin-of-sentence marker, got %q", got) + } if !strings.Contains(got, "<|System|>\nSystem rule<|end▁of▁instructions|>") { t.Fatalf("expected system instructions suffix, got %q", got) } if !strings.Contains(got, "<|User|>\nQuestion<|end▁of▁sentence|>") { t.Fatalf("expected user sentence suffix, got %q", got) } - if !strings.Contains(got, "<|Assistant|>\nAnswer<|end▁of▁sentence|>") { + if !strings.Contains(got, "<|Assistant|>\n</think>Answer<|end▁of▁sentence|>") { t.Fatalf("expected assistant sentence suffix, got %q", got) } } @@ -51,3 +54,11 @@ func TestNormalizeContentArrayFallsBackToContentWhenTextEmpty(t *testing.T) { t.Fatalf("expected fallback to content when text is empty, got %q", got) } } + +func TestMessagesPrepareWithThinkingEndsWithOpenThink(t *testing.T) { + messages := []map[string]any{{"role": "user", "content": "Question"}} + got := MessagesPrepareWithThinking(messages, true) + if !strings.HasSuffix(got, "<|Assistant|><think>") { + t.Fatalf("expected thinking suffix, got %q", got) + } +} diff --git a/internal/util/messages_test.go b/internal/util/messages_test.go index 1fd2024..9a09a26 100644 --- a/internal/util/messages_test.go +++ b/internal/util/messages_test.go @@ -12,7 +12,7 @@ func TestMessagesPrepareBasic(t *testing.T) { if got == "" { t.Fatal("expected non-empty prompt") } - if got != "<|User|>\nHello<|end▁of▁sentence|>" { + if got != "<|begin▁of▁sentence|>\n\n<|User|>\nHello<|end▁of▁sentence|>\n\n<|Assistant|></think>" { t.Fatalf("unexpected prompt: 
%q", got) } } @@ -29,10 +29,13 @@ func TestMessagesPrepareRoles(t *testing.T) { if !contains(got, "<|System|>\nYou are helper<|end▁of▁instructions|>\n\n<|User|>\nHi<|end▁of▁sentence|>") { t.Fatalf("expected system/user separation in %q", got) } - if !contains(got, "<|User|>\nHi<|end▁of▁sentence|>\n\n<|Assistant|>\nHello<|end▁of▁sentence|>") { + if !contains(got, "<|begin▁of▁sentence|>") { + t.Fatalf("expected begin marker in %q", got) + } + if !contains(got, "<|User|>\nHi<|end▁of▁sentence|>\n\n<|Assistant|>\n</think>Hello<|end▁of▁sentence|>") { t.Fatalf("expected user/assistant separation in %q", got) } - if !contains(got, "<|Assistant|>\nHello<|end▁of▁sentence|>\n\n<|Tool|>\nSearch results<|end▁of▁toolresults|>") { + if !contains(got, "<|Assistant|>\n</think>Hello<|end▁of▁sentence|>\n\n<|Tool|>\nSearch results<|end▁of▁toolresults|>") { t.Fatalf("expected assistant/tool separation in %q", got) } if !contains(got, "<|Tool|>\nSearch results<|end▁of▁toolresults|>\n\n<|User|>\nHow are you<|end▁of▁sentence|>") { @@ -74,7 +77,7 @@ func TestMessagesPrepareArrayTextVariants(t *testing.T) { }, } got := MessagesPrepare(messages) - if got != "<|User|>\nline1\nline2<|end▁of▁sentence|>" { + if got != "<|begin▁of▁sentence|>\n\n<|User|>\nline1\nline2<|end▁of▁sentence|>\n\n<|Assistant|></think>" { t.Fatalf("unexpected content from text variants: %q", got) } } diff --git a/internal/util/util_edge_test.go b/internal/util/util_edge_test.go index 621df2f..4b0dda5 100644 --- a/internal/util/util_edge_test.go +++ b/internal/util/util_edge_test.go @@ -162,7 +162,7 @@ func TestMessagesPrepareMergesConsecutiveSameRole(t *testing.T) { {"role": "user", "content": "World"}, } got := MessagesPrepare(messages) - if !strings.HasPrefix(got, "<|User|>") { + if !strings.HasPrefix(got, "<|begin▁of▁sentence|>") { t.Fatalf("expected user marker at the start, got %q", got) } if !strings.Contains(got, "Hello") || !strings.Contains(got, "World") { @@ -193,7 +193,7 @@ func TestMessagesPrepareAssistantMarkers(t *testing.T) { if 
strings.Count(got, "<|end▁of▁sentence|>") != 2 { t.Fatalf("expected both turns to be terminated, got %q", got) } - if !strings.Contains(got, "<|Assistant|>\nHello!<|end▁of▁sentence|>") { + if !strings.Contains(got, "<|Assistant|>\n</think>Hello!<|end▁of▁sentence|>") { t.Fatalf("expected assistant EOS suffix, got %q", got) } if strings.Contains(got, "<think>") {