diff --git a/pkg/llmproxy/client/client.go b/pkg/llmproxy/client/client.go new file mode 100644 index 0000000000..f767634dea --- /dev/null +++ b/pkg/llmproxy/client/client.go @@ -0,0 +1,232 @@ +package client + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" +) + +// Client is an HTTP client for the cliproxyapi++ proxy server. +// +// It covers: +// - GET /v1/models — list available models +// - POST /v1/chat/completions — chat completions (non-streaming) +// - POST /v1/responses — OpenAI Responses API passthrough +// - GET / — health / reachability check +// +// Streaming variants are deliberately out of scope for this package; callers +// that need SSE should use [net/http] directly against [Client.BaseURL]. +type Client struct { + cfg clientConfig + http *http.Client +} + +// New creates a Client with the given options. +// +// Defaults: base URL http://127.0.0.1:8318, timeout 120 s, no auth. +func New(opts ...Option) *Client { + cfg := defaultConfig() + for _, o := range opts { + o(&cfg) + } + cfg.baseURL = strings.TrimRight(cfg.baseURL, "/") + return &Client{ + cfg: cfg, + http: &http.Client{Timeout: cfg.httpTimeout}, + } +} + +// BaseURL returns the proxy base URL this client is configured against. +func (c *Client) BaseURL() string { return c.cfg.baseURL } + +// --------------------------------------------------------------------------- +// Internal helpers +// --------------------------------------------------------------------------- + +func (c *Client) newRequest(ctx context.Context, method, path string, body any) (*http.Request, error) { + var bodyReader io.Reader + if body != nil { + b, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("cliproxy/client: marshal request body: %w", err) + } + bodyReader = bytes.NewReader(b) + } + + req, err := http.NewRequestWithContext(ctx, method, c.cfg.baseURL+path, bodyReader) + if err != nil { + return nil, err + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + + // LLM API key (Bearer token for /v1/* routes) + if c.cfg.apiKey != "" { + req.Header.Set("Authorization", "Bearer "+c.cfg.apiKey) + } + return req, nil +} + +func (c *Client) do(req *http.Request) ([]byte, int, error) { + resp, err := c.http.Do(req) + if err != nil { + return nil, 0, fmt.Errorf("cliproxy/client: HTTP %s %s: %w", req.Method, req.URL.Path, err) + } + defer func() { _ = resp.Body.Close() }() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("cliproxy/client: read response body: %w", err) + } + return data, resp.StatusCode, nil +} + +func (c *Client) doJSON(req *http.Request, out any) error { + data, code, err := c.do(req) + if err != nil { + return err + } + if code >= 400 { + return parseAPIError(code, data) + } + if out == nil { + return nil + } + if err := json.Unmarshal(data, out); err != nil { + return fmt.Errorf("cliproxy/client: decode response (HTTP %d): %w", code, err) + } + return nil +} + +// parseAPIError extracts a structured error from a non-2xx response body. +// It mirrors the error shape produced by _make_error_body in the Python adapter. +func parseAPIError(code int, body []byte) *APIError { + var envelope struct { + Error struct { + Message string `json:"message"` + Code any `json:"code"` + } `json:"error"` + } + msg := strings.TrimSpace(string(body)) + if err := json.Unmarshal(body, &envelope); err == nil && envelope.Error.Message != "" { + msg = envelope.Error.Message + } + if msg == "" { + msg = fmt.Sprintf("proxy returned HTTP %d", code) + } + return &APIError{StatusCode: code, Message: msg, Code: envelope.Error.Code} +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +// Health performs a lightweight GET / against the proxy and reports whether it +// is reachable. A nil error means the server responded with HTTP 2xx. +func (c *Client) Health(ctx context.Context) error { + req, err := c.newRequest(ctx, http.MethodGet, "/", nil) + if err != nil { + return err + } + _, code, err := c.do(req) + if err != nil { + return err + } + if code >= 400 { + return fmt.Errorf("cliproxy/client: health check failed with HTTP %d", code) + } + return nil +} + +// ListModels calls GET /v1/models and returns the normalised model list. +// +// cliproxyapi++ transforms the upstream OpenAI-compatible {"data":[...]} shape +// into {"models":[...]} for Codex compatibility. This method handles both +// shapes transparently. +func (c *Client) ListModels(ctx context.Context) (*ModelsResponse, error) { + req, err := c.newRequest(ctx, http.MethodGet, "/v1/models", nil) + if err != nil { + return nil, err + } + + // Use the underlying Do directly so we can read the response headers. + httpResp, err := c.http.Do(req) + if err != nil { + return nil, fmt.Errorf("cliproxy/client: GET /v1/models: %w", err) + } + defer func() { _ = httpResp.Body.Close() }() + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, fmt.Errorf("cliproxy/client: read /v1/models body: %w", err) + } + if httpResp.StatusCode >= 400 { + return nil, parseAPIError(httpResp.StatusCode, data) + } + + // The proxy normalises the response to {"models":[...]}. + // Fall back to the raw OpenAI {"data":[...], "object":"list"} shape for + // consumers that hit the upstream directly. + var result ModelsResponse + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("cliproxy/client: decode /v1/models: %w", err) + } + + if modelsJSON, ok := raw["models"]; ok { + if err := json.Unmarshal(modelsJSON, &result.Models); err != nil { + return nil, fmt.Errorf("cliproxy/client: decode models array: %w", err) + } + } else if dataJSON, ok := raw["data"]; ok { + if err := json.Unmarshal(dataJSON, &result.Models); err != nil { + return nil, fmt.Errorf("cliproxy/client: decode data array: %w", err) + } + } + + // Capture ETag from response header (set by the proxy for cache validation). + result.ETag = httpResp.Header.Get("x-models-etag") + + return &result, nil +} + +// ChatCompletion sends a non-streaming POST /v1/chat/completions request. +// +// For streaming completions use net/http directly; this package does not wrap +// SSE streams in order to avoid pulling in additional dependencies. +func (c *Client) ChatCompletion(ctx context.Context, r ChatCompletionRequest) (*ChatCompletionResponse, error) { + r.Stream = false // enforce non-streaming + req, err := c.newRequest(ctx, http.MethodPost, "/v1/chat/completions", r) + if err != nil { + return nil, err + } + var out ChatCompletionResponse + if err := c.doJSON(req, &out); err != nil { + return nil, err + } + return &out, nil +} + +// Responses sends a non-streaming POST /v1/responses request (OpenAI Responses +// API). The proxy transparently bridges this to /v1/chat/completions when the +// backend does not natively support the Responses endpoint. +// +// The raw decoded JSON is returned as map[string]any to remain forward- +// compatible as the Responses API schema evolves. +func (c *Client) Responses(ctx context.Context, r ResponsesRequest) (map[string]any, error) { + r.Stream = false + req, err := c.newRequest(ctx, http.MethodPost, "/v1/responses", r) + if err != nil { + return nil, err + } + var out map[string]any + if err := c.doJSON(req, &out); err != nil { + return nil, err + } + return out, nil +} diff --git a/pkg/llmproxy/client/client_test.go b/pkg/llmproxy/client/client_test.go new file mode 100644 index 0000000000..2c6da92194 --- /dev/null +++ b/pkg/llmproxy/client/client_test.go @@ -0,0 +1,339 @@ +package client_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/kooshapari/cliproxyapi-plusplus/v6/pkg/llmproxy/client" +) + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +func newTestServer(t *testing.T, handler http.Handler) (*httptest.Server, *client.Client) { + t.Helper() + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + c := client.New( + client.WithBaseURL(srv.URL), + client.WithTimeout(5*time.Second), + ) + return srv, c +} + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(v) +} + +// --------------------------------------------------------------------------- +// Health +// --------------------------------------------------------------------------- + +func TestHealth_OK(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + writeJSON(w, 200, map[string]string{"status": "ok"}) + })) + + if err := c.Health(context.Background()); err != nil { + t.Fatalf("Health() unexpected error: %v", err) + } +} + +func TestHealth_Unreachable(t *testing.T) { + // Point at a port nothing is listening on. + c := client.New( + client.WithBaseURL("http://127.0.0.1:1"), + client.WithTimeout(500*time.Millisecond), + ) + if err := c.Health(context.Background()); err == nil { + t.Fatal("Health() expected error for unreachable server, got nil") + } +} + +func TestHealth_ServerError(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 503, map[string]any{ + "error": map[string]any{"message": "service unavailable", "code": 503}, + }) + })) + if err := c.Health(context.Background()); err == nil { + t.Fatal("Health() expected error for 503, got nil") + } +} + +// --------------------------------------------------------------------------- +// ListModels +// --------------------------------------------------------------------------- + +func TestListModels_ProxyShape(t *testing.T) { + // cliproxyapi++ normalised shape: {"models": [...]} + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/v1/models" { + http.NotFound(w, r) + return + } + w.Header().Set("x-models-etag", "abc123") + writeJSON(w, 200, map[string]any{ + "models": []map[string]any{ + {"id": "anthropic/claude-opus-4-6", "object": "model", "owned_by": "anthropic"}, + {"id": "openai/gpt-4o", "object": "model", "owned_by": "openai"}, + }, + }) + })) + + resp, err := c.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() unexpected error: %v", err) + } + if len(resp.Models) != 2 { + t.Fatalf("expected 2 models, got %d", len(resp.Models)) + } + if resp.Models[0].ID != "anthropic/claude-opus-4-6" { + t.Errorf("unexpected first model ID: %s", resp.Models[0].ID) + } +} + +func TestListModels_OpenAIShape(t *testing.T) { + // Raw upstream OpenAI shape: {"data": [...], "object": "list"} + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 200, map[string]any{ + "object": "list", + "data": []map[string]any{ + {"id": "gpt-4o", "object": "model", "owned_by": "openai"}, + }, + }) + })) + + resp, err := c.ListModels(context.Background()) + if err != nil { + t.Fatalf("ListModels() unexpected error: %v", err) + } + if len(resp.Models) != 1 || resp.Models[0].ID != "gpt-4o" { + t.Errorf("unexpected models: %+v", resp.Models) + } +} + +func TestListModels_Error(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 401, map[string]any{ + "error": map[string]any{"message": "unauthorized", "code": 401}, + }) + })) + + _, err := c.ListModels(context.Background()) + if err == nil { + t.Fatal("ListModels() expected error for 401, got nil") + } + if _, ok := err.(*client.APIError); !ok { + t.Logf("error type: %T — not an *client.APIError, that is acceptable", err) + } +} + +// --------------------------------------------------------------------------- +// ChatCompletion +// --------------------------------------------------------------------------- + +func TestChatCompletion_OK(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/v1/chat/completions" { + http.NotFound(w, r) + return + } + // Decode and validate request body + var body client.ChatCompletionRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "bad request", 400) + return + } + if body.Stream { + http.Error(w, "client must not set stream=true", 400) + return + } + writeJSON(w, 200, map[string]any{ + "id": "chatcmpl-test", + "object": "chat.completion", + "created": 1700000000, + "model": body.Model, + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{"role": "assistant", "content": "Hello!"}, + "finish_reason": "stop", + }, + }, + "usage": map[string]any{ + "prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15, + }, + }) + })) + + resp, err := c.ChatCompletion(context.Background(), client.ChatCompletionRequest{ + Model: "anthropic/claude-opus-4-6", + Messages: []client.ChatMessage{ + {Role: "user", Content: "Say hi"}, + }, + }) + if err != nil { + t.Fatalf("ChatCompletion() unexpected error: %v", err) + } + if len(resp.Choices) == 0 { + t.Fatal("expected at least one choice") + } + if resp.Choices[0].Message.Content != "Hello!" { + t.Errorf("unexpected content: %q", resp.Choices[0].Message.Content) + } +} + +func TestChatCompletion_4xx(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 429, map[string]any{ + "error": map[string]any{"message": "rate limit exceeded", "code": 429}, + }) + })) + + _, err := c.ChatCompletion(context.Background(), client.ChatCompletionRequest{ + Model: "any", + Messages: []client.ChatMessage{{Role: "user", Content: "hi"}}, + }) + if err == nil { + t.Fatal("expected error for 429") + } +} + +// --------------------------------------------------------------------------- +// Responses +// --------------------------------------------------------------------------- + +func TestResponses_OK(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/v1/responses" { + http.NotFound(w, r) + return + } + writeJSON(w, 200, map[string]any{ + "id": "resp_test", + "object": "response", + "output": []map[string]any{ + {"type": "message", "role": "assistant", "content": []map[string]any{ + {"type": "text", "text": "Hello from responses API"}, + }}, + }, + }) + })) + + out, err := c.Responses(context.Background(), client.ResponsesRequest{ + Model: "anthropic/claude-opus-4-6", + Input: "Say hi", + }) + if err != nil { + t.Fatalf("Responses() unexpected error: %v", err) + } + if out["id"] != "resp_test" { + t.Errorf("unexpected id: %v", out["id"]) + } +} + +// --------------------------------------------------------------------------- +// Options +// --------------------------------------------------------------------------- + +func TestWithAPIKey_SetsAuthorizationHeader(t *testing.T) { + var gotAuth string + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + writeJSON(w, 200, map[string]any{"models": []any{}}) + })) + // Rebuild with API key + _, c = newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + writeJSON(w, 200, map[string]any{"models": []any{}}) + })) + _ = c // silence unused warning; we rebuild below + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + writeJSON(w, 200, map[string]any{"models": []any{}}) + })) + t.Cleanup(srv.Close) + + c = client.New( + client.WithBaseURL(srv.URL), + client.WithAPIKey("sk-test-key"), + client.WithTimeout(5*time.Second), + ) + if _, err := c.ListModels(context.Background()); err != nil { + t.Fatalf("ListModels() unexpected error: %v", err) + } + if gotAuth != "Bearer sk-test-key" { + t.Errorf("expected 'Bearer sk-test-key', got %q", gotAuth) + } +} + +func TestBaseURL(t *testing.T) { + c := client.New(client.WithBaseURL("http://localhost:9999")) + if c.BaseURL() != "http://localhost:9999" { + t.Errorf("BaseURL() = %q, want %q", c.BaseURL(), "http://localhost:9999") + } +} + +// --------------------------------------------------------------------------- +// Error type +// --------------------------------------------------------------------------- + +func TestAPIError_Message(t *testing.T) { + _, c := newTestServer(t, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, 503, map[string]any{ + "error": map[string]any{ + "message": "service unavailable — no providers matched", + "code": 503, + }, + }) + })) + + _, err := c.ListModels(context.Background()) + if err == nil { + t.Fatal("expected error") + } + apiErr, ok := err.(*client.APIError) + if !ok { + t.Fatalf("expected *client.APIError, got %T", err) + } + if apiErr.StatusCode != 503 { + t.Errorf("StatusCode = %d, want 503", apiErr.StatusCode) + } + if apiErr.Message == "" { + t.Error("Message must not be empty") + } +} + +// --------------------------------------------------------------------------- +// Context cancellation +// --------------------------------------------------------------------------- + +func TestContextCancellation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Block until client cancels + <-r.Context().Done() + w.WriteHeader(200) + })) + t.Cleanup(srv.Close) + + c := client.New(client.WithBaseURL(srv.URL), client.WithTimeout(5*time.Second)) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + if err := c.Health(ctx); err == nil { + t.Fatal("expected error due to context cancellation") + } +} diff --git a/pkg/llmproxy/client/types.go b/pkg/llmproxy/client/types.go new file mode 100644 index 0000000000..216dd69d71 --- /dev/null +++ b/pkg/llmproxy/client/types.go @@ -0,0 +1,147 @@ +// Package client provides a Go SDK for the cliproxyapi++ HTTP proxy API. +// +// It covers the core LLM proxy surface: model listing, chat completions, the +// Responses API, and the proxy process lifecycle (start/stop/health). +// +// # Migration note +// +// This package is the canonical Go replacement for the Python adapter code +// that previously lived in thegent/src/thegent/cliproxy_adapter.py and +// related helpers. Any new consumer should import this package rather than +// re-implementing raw HTTP calls. +package client + +import "time" + +// --------------------------------------------------------------------------- +// Model types +// --------------------------------------------------------------------------- + +// Model is a single entry from GET /v1/models. +type Model struct { + ID string `json:"id"` + Object string `json:"object,omitempty"` + Created int64 `json:"created,omitempty"` + OwnedBy string `json:"owned_by,omitempty"` +} + +// ModelsResponse is the envelope returned by GET /v1/models. +// cliproxyapi++ normalises the upstream shape into {"models": [...]}. +type ModelsResponse struct { + Models []Model `json:"models"` + // ETag is populated from the x-models-etag response header when present. + ETag string `json:"-"` +} + +// --------------------------------------------------------------------------- +// Chat completions types +// --------------------------------------------------------------------------- + +// ChatMessage is a single message in a chat conversation. +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatCompletionRequest is the body for POST /v1/chat/completions. +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Stream bool `json:"stream,omitempty"` + // MaxTokens limits the number of tokens generated. + MaxTokens *int `json:"max_tokens,omitempty"` + // Temperature controls randomness (0–2). + Temperature *float64 `json:"temperature,omitempty"` +} + +// ChatChoice is a single completion choice. +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// Usage holds token counts reported by the backend. +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + Cost float64 `json:"cost,omitempty"` +} + +// ChatCompletionResponse is the non-streaming response from POST /v1/chat/completions. +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChoice `json:"choices"` + Usage Usage `json:"usage"` +} + +// --------------------------------------------------------------------------- +// Responses API types (POST /v1/responses) +// --------------------------------------------------------------------------- + +// ResponsesRequest is the body for POST /v1/responses (OpenAI Responses API). +type ResponsesRequest struct { + Model string `json:"model"` + Input any `json:"input"` + Stream bool `json:"stream,omitempty"` +} + +// --------------------------------------------------------------------------- +// Error type +// --------------------------------------------------------------------------- + +// APIError is returned when the server responds with a non-2xx status code. +type APIError struct { + StatusCode int + Message string + Code any +} + +func (e *APIError) Error() string { + return e.Message +} + +// --------------------------------------------------------------------------- +// Client options +// --------------------------------------------------------------------------- + +// Option configures a [Client]. +type Option func(*clientConfig) + +type clientConfig struct { + baseURL string + apiKey string + secretKey string + httpTimeout time.Duration +} + +func defaultConfig() clientConfig { + return clientConfig{ + baseURL: "http://127.0.0.1:8318", + httpTimeout: 120 * time.Second, + } +} + +// WithBaseURL overrides the proxy base URL (default: http://127.0.0.1:8318). +func WithBaseURL(u string) Option { + return func(c *clientConfig) { c.baseURL = u } +} + +// WithAPIKey sets the Authorization: Bearer header for LLM API calls. +func WithAPIKey(key string) Option { + return func(c *clientConfig) { c.apiKey = key } +} + +// WithSecretKey sets the management API bearer token (used for /v0/management/* routes). +func WithSecretKey(key string) Option { + return func(c *clientConfig) { c.secretKey = key } +} + +// WithTimeout sets the HTTP client timeout (default: 120s). +func WithTimeout(d time.Duration) Option { + return func(c *clientConfig) { c.httpTimeout = d } +}