diff --git a/internal/cli/app.go b/internal/cli/app.go index 28ae229..59cb8a3 100644 --- a/internal/cli/app.go +++ b/internal/cli/app.go @@ -646,13 +646,27 @@ func (a *App) runCluster(ctx context.Context, args []string) error { } type embedResult struct { - Repository string `json:"repository"` - Model string `json:"model"` - Basis string `json:"basis"` - Selected int `json:"selected"` - Embedded int `json:"embedded"` - Skipped int `json:"skipped"` - RunID int64 `json:"run_id"` + Repository string `json:"repository"` + Model string `json:"model"` + Basis string `json:"basis"` + Selected int `json:"selected"` + Embedded int `json:"embedded"` + Skipped int `json:"skipped"` + Failed int `json:"failed,omitempty"` + Retries int `json:"retries,omitempty"` + Status string `json:"status,omitempty"` + Failures []embedFailureStat `json:"failures,omitempty"` + RunID int64 `json:"run_id"` +} + +type embedFailureStat struct { + BatchStart int `json:"batch_start"` + BatchEnd int `json:"batch_end"` + Attempts int `json:"attempts"` + Status int `json:"status,omitempty"` + Type string `json:"type,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message"` } func (a *App) runEmbed(ctx context.Context, args []string) error { @@ -731,42 +745,71 @@ func (a *App) embedRepository(ctx context.Context, owner, repoName string, optio return embedResult{}, err } started := time.Now().UTC().Format(time.RFC3339Nano) - embedded := 0 batchSize := rt.Config.OpenAI.BatchSize if batchSize <= 0 { batchSize = 64 } - client := openai.New(openai.Options{APIKey: token.Value, BaseURL: openAIBaseURL(), Dimensions: rt.Config.OpenAI.EmbedDimensions}) + client := openai.New(openai.Options{APIKey: token.Value, BaseURL: openAIBaseURL(), Dimensions: rt.Config.OpenAI.EmbedDimensions, Retry: embedRetryOverride()}) + + type pendingBatch struct { + start, end int + attempts int + } + var queue []pendingBatch for start := 0; start < len(tasks); start += batchSize { end := start + batchSize if end > len(tasks) { end = len(tasks) } - batch := tasks[start:end] - texts := make([]string, 0, len(batch)) - for _, task := range batch { + queue = append(queue, pendingBatch{start: start, end: end}) + } + + embedded := 0 + totalRetries := 0 + var failures []embedFailureStat + cancelled := false + var cancelErr error + + const maxBatchAttempts = 2 + for len(queue) > 0 { + batch := queue[0] + queue = queue[1:] + batch.attempts++ + slice := tasks[batch.start:batch.end] + texts := make([]string, 0, len(slice)) + for _, task := range slice { texts = append(texts, task.Text) } - fmt.Fprintf(a.Stderr, "[embed] embedding %d-%d of %d\n", start+1, end, len(tasks)) - if truncated := truncatedEmbeddingTaskCount(batch); truncated > 0 { - fmt.Fprintf(a.Stderr, "[embed] truncated %d input(s) to %d runes\n", truncated, store.MaxEmbeddingTextRunes) + fmt.Fprintf(a.Stderr, "[embed] embedding %d-%d of %d (attempt %d)\n", batch.start+1, batch.end, len(tasks), batch.attempts) + if batch.attempts == 1 { + if truncated := truncatedEmbeddingTaskCount(slice); truncated > 0 { + fmt.Fprintf(a.Stderr, "[embed] truncated %d input(s) to %d runes\n", truncated, store.MaxEmbeddingTextRunes) + } } vectors, err := client.Embed(ctx, rt.Config.OpenAI.EmbedModel, texts) if err != nil { - _, _ = rt.Store.RecordRun(ctx, store.RunRecord{ - RepoID: repo.ID, - Kind: "embedding", - Scope: "repo", - Status: "error", - StartedAt: started, - FinishedAt: time.Now().UTC().Format(time.RFC3339Nano), - ErrorText: err.Error(), - }) - return embedResult{}, err + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + cancelled = true + cancelErr = err + break + } + retryable := true + if apiErr := openai.AsAPIError(err); apiErr != nil { + retryable = apiErr.Retryable() + } + if retryable && batch.attempts < maxBatchAttempts { + totalRetries++ + fmt.Fprintf(a.Stderr, "[embed] batch %d-%d failed (%s), requeueing\n", batch.start+1, batch.end, summarizeEmbedErr(err)) + queue = append(queue, batch) + continue + } + fmt.Fprintf(a.Stderr, "[embed] batch %d-%d failed permanently: %s\n", batch.start+1, batch.end, summarizeEmbedErr(err)) + failures = append(failures, makeEmbedFailureStat(batch.start, batch.end, batch.attempts, err)) + continue } now := time.Now().UTC().Format(time.RFC3339Nano) for index, vector := range vectors { - task := batch[index] + task := slice[index] if err := rt.Store.UpsertThreadVector(ctx, store.ThreadVector{ ThreadID: task.ThreadID, Basis: rt.Config.EmbeddingBasis, @@ -783,31 +826,109 @@ func (a *App) embedRepository(ctx context.Context, owner, repoName string, optio embedded++ } } + + failedRows := 0 + for _, f := range failures { + failedRows += f.BatchEnd - f.BatchStart + } + + status := "success" + switch { + case cancelled: + status = "cancelled" + case len(failures) > 0 && embedded == 0: + status = "error" + case len(failures) > 0: + status = "partial" + } + result := embedResult{ Repository: repo.FullName, Model: rt.Config.OpenAI.EmbedModel, Basis: rt.Config.EmbeddingBasis, Selected: len(tasks), Embedded: embedded, - RunID: 0, + Failed: failedRows, + Retries: totalRetries, + Status: status, + Failures: failures, } statsJSON, _ := json.Marshal(result) - runID, err := rt.Store.RecordRun(ctx, store.RunRecord{ + runRecord := store.RunRecord{ RepoID: repo.ID, Kind: "embedding", Scope: "repo", - Status: "success", + Status: status, StartedAt: started, FinishedAt: time.Now().UTC().Format(time.RFC3339Nano), StatsJSON: string(statsJSON), - }) - if err != nil { - return embedResult{}, err + } + if cancelled && cancelErr != nil { + runRecord.ErrorText = cancelErr.Error() + } else if status == "error" && len(failures) > 0 { + runRecord.ErrorText = failures[0].Message + } + recordCtx := ctx + if cancelled { + var cancelRecord context.CancelFunc + recordCtx, cancelRecord = context.WithTimeout(context.Background(), 5*time.Second) + defer cancelRecord() + } + runID, recordErr := rt.Store.RecordRun(recordCtx, runRecord) + if recordErr != nil && !cancelled { + return embedResult{}, recordErr } result.RunID = runID + + if cancelled { + return result, cancelErr + } + if status == "error" { + return result, fmt.Errorf("openai embeddings failed: %s", failures[0].Message) + } return result, nil } +func summarizeEmbedErr(err error) string { + if apiErr := openai.AsAPIError(err); apiErr != nil { + parts := []string{fmt.Sprintf("status=%d", apiErr.Status)} + if apiErr.Type != "" { + parts = append(parts, "type="+apiErr.Type) + } + if apiErr.Code != "" { + parts = append(parts, "code="+apiErr.Code) + } + return strings.Join(parts, " ") + } + return err.Error() +} + +func makeEmbedFailureStat(start, end, attempts int, err error) embedFailureStat { + stat := embedFailureStat{ + BatchStart: start, + BatchEnd: end, + Attempts: attempts, + Message: err.Error(), + } + if apiErr := openai.AsAPIError(err); apiErr != nil { + stat.Status = apiErr.Status + stat.Type = apiErr.Type + stat.Code = apiErr.Code + if apiErr.Message != "" { + stat.Message = apiErr.Message + } + } + return stat +} + +func embedRetryOverride() *openai.RetryConfig { + if strings.TrimSpace(os.Getenv("GITCRAWL_OPENAI_RETRY_DISABLED")) == "1" { + cfg := openai.NoRetry() + return &cfg + } + return nil +} + func truncatedEmbeddingTaskCount(tasks []store.EmbeddingTask) int { count := 0 for _, task := range tasks { diff --git a/internal/cli/app_test.go b/internal/cli/app_test.go index 9cea73a..9fa04c1 100644 --- a/internal/cli/app_test.go +++ b/internal/cli/app_test.go @@ -1438,6 +1438,7 @@ func TestEmbedErrorBranchesRecordFailures(t *testing.T) { })) defer server.Close() t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL) + t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1") if err := New().Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw", "--limit", "1"}); err == nil { t.Fatal("OpenAI error should fail") } @@ -1459,6 +1460,142 @@ func TestEmbedErrorBranchesRecordFailures(t *testing.T) { } } +func TestEmbedRunPartialOnSomeFailedBatches(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + configPath := filepath.Join(dir, "config.toml") + dbPath := filepath.Join(dir, "gitcrawl.db") + if err := New().Run(ctx, []string{"--config", configPath, "init", "--db", dbPath}); err != nil { + t.Fatalf("init: %v", err) + } + seedCommandFlowStore(t, dbPath) + + cfg, err := config.Load(configPath) + if err != nil { + t.Fatalf("load config: %v", err) + } + cfg.OpenAI.BatchSize = 1 + if err := config.Save(configPath, cfg); err != nil { + t.Fatalf("save config: %v", err) + } + + var calls int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + var payload struct { + Input []string `json:"input"` + } + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode: %v", err) + } + // First input is permanently bad — return non-retryable 400. + if len(payload.Input) == 1 && strings.Contains(payload.Input[0], "Gateway websocket stalls") { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{"message": "bad input", "type": "invalid_request_error"}, + }) + return + } + data := make([]map[string]any, 0, len(payload.Input)) + for index := range payload.Input { + data = append(data, map[string]any{"index": index, "embedding": []float64{1, 0.5 * float64(index)}}) + } + _ = json.NewEncoder(w).Encode(map[string]any{"data": data}) + })) + defer server.Close() + t.Setenv("OPENAI_API_KEY", "test-openai-key") + t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL) + t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1") + + app := New() + var stdout bytes.Buffer + app.Stdout = &stdout + if err := app.Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw", "--json"}); err != nil { + t.Fatalf("embed: %v", err) + } + + var result embedResult + if err := json.Unmarshal(stdout.Bytes(), &result); err != nil { + t.Fatalf("decode embed result: %v\n%s", err, stdout.String()) + } + if result.Status != "partial" { + t.Fatalf("status = %q, want partial", result.Status) + } + if result.Embedded != 2 { + t.Fatalf("embedded = %d, want 2", result.Embedded) + } + if result.Failed != 1 { + t.Fatalf("failed = %d, want 1", result.Failed) + } + if len(result.Failures) != 1 { + t.Fatalf("failures = %+v", result.Failures) + } + if result.Failures[0].Status != http.StatusBadRequest { + t.Fatalf("failure status = %d", result.Failures[0].Status) + } + + st, err := store.Open(ctx, dbPath) + if err != nil { + t.Fatalf("open: %v", err) + } + defer st.Close() + repo, err := st.RepositoryByFullName(ctx, "openclaw/openclaw") + if err != nil { + t.Fatalf("repo: %v", err) + } + runs, err := st.ListRuns(ctx, repo.ID, "embedding", 1) + if err != nil { + t.Fatalf("runs: %v", err) + } + if len(runs) != 1 || runs[0].Status != "partial" { + t.Fatalf("run = %+v", runs) + } +} + +func TestEmbedRunCancelledRecordsCancelledStatus(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + dir := t.TempDir() + configPath := filepath.Join(dir, "config.toml") + dbPath := filepath.Join(dir, "gitcrawl.db") + if err := New().Run(ctx, []string{"--config", configPath, "init", "--db", dbPath}); err != nil { + t.Fatalf("init: %v", err) + } + seedCommandFlowStore(t, dbPath) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cancel() + select { + case <-r.Context().Done(): + case <-time.After(2 * time.Second): + } + })) + defer server.Close() + t.Setenv("OPENAI_API_KEY", "test-openai-key") + t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL) + t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1") + + if err := New().Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw"}); err == nil { + t.Fatal("expected cancellation error") + } + + st, err := store.Open(context.Background(), dbPath) + if err != nil { + t.Fatalf("open store: %v", err) + } + defer st.Close() + repo, err := st.RepositoryByFullName(context.Background(), "openclaw/openclaw") + if err != nil { + t.Fatalf("repo: %v", err) + } + runs, err := st.ListRuns(context.Background(), repo.ID, "embedding", 1) + if err != nil { + t.Fatalf("runs: %v", err) + } + if len(runs) != 1 || runs[0].Status != "cancelled" { + t.Fatalf("expected cancelled run, got %+v", runs) + } +} + func TestTruncatedEmbeddingTaskCount(t *testing.T) { tasks := []store.EmbeddingTask{ {Number: 1}, diff --git a/internal/openai/client.go b/internal/openai/client.go index 93356b1..797800b 100644 --- a/internal/openai/client.go +++ b/internal/openai/client.go @@ -4,10 +4,13 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "math/rand" "net/http" "strings" + "sync" "time" ) @@ -17,11 +20,41 @@ const ( maxEmbeddingInputRunes = 6_000 ) +type RetryConfig struct { + MaxAttempts int + BaseDelay time.Duration + OverloadedBase time.Duration + MaxDelay time.Duration + MaxElapsed time.Duration + Jitter float64 +} + +func DefaultRetryConfig() RetryConfig { + return RetryConfig{ + MaxAttempts: 6, + BaseDelay: time.Second, + OverloadedBase: 15 * time.Second, + MaxDelay: 60 * time.Second, + MaxElapsed: 5 * time.Minute, + Jitter: 0.2, + } +} + +func NoRetry() RetryConfig { + return RetryConfig{MaxAttempts: 1} +} + type Client struct { apiKey string baseURL string httpClient *http.Client dimensions int + retry RetryConfig + + now func() time.Time + sleep func(context.Context, time.Duration) error + rand *rand.Rand + randM sync.Mutex } type Options struct { @@ -29,6 +62,10 @@ type Options struct { BaseURL string Dimensions int HTTPClient *http.Client + Retry *RetryConfig + + Now func() time.Time + Sleep func(context.Context, time.Duration) error } type embeddingRequest struct { @@ -45,6 +82,7 @@ type embeddingResponse struct { Error *struct { Message string `json:"message"` Type string `json:"type"` + Code string `json:"code"` } `json:"error,omitempty"` } @@ -57,11 +95,30 @@ func New(options Options) *Client { if httpClient == nil { httpClient = &http.Client{Timeout: 60 * time.Second} } + retry := DefaultRetryConfig() + if options.Retry != nil { + retry = *options.Retry + } + if retry.MaxAttempts <= 0 { + retry.MaxAttempts = 1 + } + now := options.Now + if now == nil { + now = time.Now + } + sleep := options.Sleep + if sleep == nil { + sleep = sleepCtx + } return &Client{ apiKey: strings.TrimSpace(options.APIKey), baseURL: baseURL, httpClient: httpClient, dimensions: options.Dimensions, + retry: retry, + now: now, + sleep: sleep, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } @@ -77,13 +134,67 @@ func (c *Client) Embed(ctx context.Context, model string, texts []string) ([][]f return nil, fmt.Errorf("OpenAI API key is required") } texts = capEmbeddingInputs(texts) + + deadline := c.now().Add(c.retry.MaxElapsed) + var lastErr error + for attempt := 0; attempt < c.retry.MaxAttempts; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + vectors, apiErr, err := c.embedOnce(ctx, model, texts) + if err != nil { + if isContextErr(err) { + return nil, err + } + lastErr = err + if attempt+1 >= c.retry.MaxAttempts { + return nil, err + } + delay := c.backoff(attempt, c.retry.BaseDelay, 0) + if !c.canSleep(deadline, delay) { + return nil, err + } + if sleepErr := c.sleep(ctx, delay); sleepErr != nil { + return nil, sleepErr + } + continue + } + if apiErr == nil { + return vectors, nil + } + lastErr = apiErr + if !apiErr.Retryable() { + return nil, apiErr + } + if attempt+1 >= c.retry.MaxAttempts { + return nil, apiErr + } + base := c.retry.BaseDelay + if apiErr.IsOverloaded() { + base = c.retry.OverloadedBase + } + delay := c.backoff(attempt, base, apiErr.RetryAfter) + if !c.canSleep(deadline, delay) { + return nil, apiErr + } + if sleepErr := c.sleep(ctx, delay); sleepErr != nil { + return nil, sleepErr + } + } + if lastErr == nil { + lastErr = fmt.Errorf("openai embeddings: exhausted %d attempts", c.retry.MaxAttempts) + } + return nil, lastErr +} + +func (c *Client) embedOnce(ctx context.Context, model string, texts []string) ([][]float64, *APIError, error) { payload, err := json.Marshal(embeddingRequest{Model: model, Input: texts, Dimensions: c.dimensions}) if err != nil { - return nil, fmt.Errorf("marshal embeddings request: %w", err) + return nil, nil, fmt.Errorf("marshal embeddings request: %w", err) } req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/embeddings", bytes.NewReader(payload)) if err != nil { - return nil, err + return nil, nil, err } req.Header.Set("Authorization", "Bearer "+c.apiKey) req.Header.Set("Content-Type", "application/json") @@ -91,40 +202,103 @@ func (c *Client) Embed(ctx context.Context, model string, texts []string) ([][]f resp, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("openai embeddings request: %w", err) + return nil, nil, fmt.Errorf("openai embeddings request: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, maxEmbeddingResponseBytes)) if err != nil { - return nil, fmt.Errorf("read embeddings response: %w", err) + return nil, nil, fmt.Errorf("read embeddings response: %w", err) } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + apiErr := &APIError{Status: resp.StatusCode} var parsed embeddingResponse - if err := json.Unmarshal(body, &parsed); err == nil && parsed.Error != nil && parsed.Error.Message != "" { - return nil, fmt.Errorf("openai embeddings failed with status %d: %s", resp.StatusCode, parsed.Error.Message) + if jerr := json.Unmarshal(body, &parsed); jerr == nil && parsed.Error != nil { + apiErr.Message = parsed.Error.Message + apiErr.Type = parsed.Error.Type + apiErr.Code = parsed.Error.Code + } else { + apiErr.Message = strings.TrimSpace(string(body)) } - return nil, fmt.Errorf("openai embeddings failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + apiErr.RetryAfter = parseRetryAfter(resp.Header.Get("Retry-After"), c.now()) + return nil, apiErr, nil } + var parsed embeddingResponse if err := json.Unmarshal(body, &parsed); err != nil { - return nil, fmt.Errorf("decode embeddings response: %w", err) + return nil, nil, fmt.Errorf("decode embeddings response: %w", err) } if len(parsed.Data) != len(texts) { - return nil, fmt.Errorf("openai embeddings returned %d vectors for %d inputs", len(parsed.Data), len(texts)) + return nil, nil, fmt.Errorf("openai embeddings returned %d vectors for %d inputs", len(parsed.Data), len(texts)) } out := make([][]float64, len(texts)) for _, item := range parsed.Data { if item.Index < 0 || item.Index >= len(texts) { - return nil, fmt.Errorf("openai embeddings returned invalid index %d", item.Index) + return nil, nil, fmt.Errorf("openai embeddings returned invalid index %d", item.Index) } out[item.Index] = item.Embedding } for index, vector := range out { if len(vector) == 0 { - return nil, fmt.Errorf("openai embeddings returned empty vector at index %d", index) + return nil, nil, fmt.Errorf("openai embeddings returned empty vector at index %d", index) + } + } + return out, nil, nil +} + +func (c *Client) backoff(attempt int, base time.Duration, retryAfter time.Duration) time.Duration { + if retryAfter > 0 { + if retryAfter > c.retry.MaxDelay { + return c.retry.MaxDelay } + return retryAfter + } + if base <= 0 { + base = time.Second + } + shift := attempt + if shift > 6 { + shift = 6 + } + delay := base * (1 << shift) + if delay > c.retry.MaxDelay { + delay = c.retry.MaxDelay } - return out, nil + if c.retry.Jitter > 0 { + c.randM.Lock() + offset := (c.rand.Float64()*2 - 1) * c.retry.Jitter * float64(delay) + c.randM.Unlock() + delay += time.Duration(offset) + if delay < 0 { + delay = 0 + } + } + return delay +} + +func (c *Client) canSleep(deadline time.Time, delay time.Duration) bool { + if c.retry.MaxElapsed <= 0 { + return true + } + return c.now().Add(delay).Before(deadline) +} + +func sleepCtx(ctx context.Context, d time.Duration) error { + if d <= 0 { + return nil + } + t := time.NewTimer(d) + defer t.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + return nil + } +} + +func isContextErr(err error) bool { + return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) } func capEmbeddingInputs(texts []string) []string { diff --git a/internal/openai/client_test.go b/internal/openai/client_test.go index 4f11441..559151f 100644 --- a/internal/openai/client_test.go +++ b/internal/openai/client_test.go @@ -3,10 +3,13 @@ package openai import ( "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" + "time" ) func TestEmbedAcceptsLargeBatchResponse(t *testing.T) { @@ -40,7 +43,8 @@ func TestEmbedAcceptsLargeBatchResponse(t *testing.T) { for index := range inputs { inputs[index] = "thread text" } - vectors, err := New(Options{APIKey: "test", BaseURL: server.URL, Dimensions: 1024}).Embed(context.Background(), "text-embedding-3-small", inputs) + noRetry := NoRetry() + vectors, err := New(Options{APIKey: "test", BaseURL: server.URL, Dimensions: 1024, Retry: &noRetry}).Embed(context.Background(), "text-embedding-3-small", inputs) if err != nil { t.Fatalf("embed: %v", err) } @@ -82,14 +86,15 @@ func TestEmbedCapsOversizedInputsBeforeRequest(t *testing.T) { } func TestEmbedErrorBranches(t *testing.T) { - client := New(Options{APIKey: "test"}) + noRetry := NoRetry() + client := New(Options{APIKey: "test", Retry: &noRetry}) if _, err := client.Embed(context.Background(), "", []string{"text"}); err == nil { t.Fatal("missing model should fail") } if vectors, err := client.Embed(context.Background(), "model", nil); err != nil || vectors != nil { t.Fatalf("empty inputs = %+v err=%v", vectors, err) } - if _, err := New(Options{}).Embed(context.Background(), "model", []string{"text"}); err == nil { + if _, err := New(Options{Retry: &noRetry}).Embed(context.Background(), "model", []string{"text"}); err == nil { t.Fatal("missing API key should fail") } @@ -100,6 +105,7 @@ func TestEmbedErrorBranches(t *testing.T) { _ = json.NewEncoder(w).Encode(embeddingResponse{Error: &struct { Message string `json:"message"` Type string `json:"type"` + Code string `json:"code"` }{Message: "bad input", Type: "invalid_request"}}) case strings.Contains(r.URL.Path, "wrong-count"): _ = json.NewEncoder(w).Encode(embeddingResponse{}) @@ -119,9 +125,198 @@ func TestEmbedErrorBranches(t *testing.T) { })) defer server.Close() for _, suffix := range []string{"/api-error", "/wrong-count", "/bad-index", "/empty-vector", ""} { - _, err := New(Options{APIKey: "test", BaseURL: server.URL + suffix}).Embed(context.Background(), "model", []string{"text"}) + _, err := New(Options{APIKey: "test", BaseURL: server.URL + suffix, Retry: &noRetry}).Embed(context.Background(), "model", []string{"text"}) if err == nil { t.Fatalf("expected error for %q", suffix) } } } + +func newSingleVectorServer(handler http.HandlerFunc) *httptest.Server { + return httptest.NewServer(handler) +} + +func writeSingleVector(w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(embeddingResponse{Data: []struct { + Index int `json:"index"` + Embedding []float64 `json:"embedding"` + }{{Index: 0, Embedding: []float64{0.1}}}}) +} + +func TestEmbedRetriesOn429AndHonorsRetryAfter(t *testing.T) { + var calls int32 + server := newSingleVectorServer(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + w.Header().Set("Retry-After", "2") + w.WriteHeader(http.StatusTooManyRequests) + _ = json.NewEncoder(w).Encode(embeddingResponse{Error: &struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + }{Message: "rate limited", Type: "rate_limit_exceeded"}}) + return + } + writeSingleVector(w) + }) + defer server.Close() + + var slept []time.Duration + retry := RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: time.Hour, MaxElapsed: time.Hour} + client := New(Options{APIKey: "test", BaseURL: server.URL, Retry: &retry, Sleep: func(_ context.Context, d time.Duration) error { + slept = append(slept, d) + return nil + }}) + if _, err := client.Embed(context.Background(), "model", []string{"hi"}); err != nil { + t.Fatalf("embed: %v", err) + } + if calls != 2 { + t.Fatalf("calls = %d, want 2", calls) + } + if len(slept) != 1 || slept[0] != 2*time.Second { + t.Fatalf("expected single sleep of 2s honoring Retry-After, got %v", slept) + } +} + +func TestEmbedDoesNotSleepAfterFinalRetryableError(t *testing.T) { + var calls int32 + server := newSingleVectorServer(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.WriteHeader(http.StatusTooManyRequests) + _ = json.NewEncoder(w).Encode(embeddingResponse{Error: &struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + }{Message: "rate limited", Type: "rate_limit_exceeded"}}) + }) + defer server.Close() + + var slept []time.Duration + retry := RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: time.Hour, MaxElapsed: time.Hour} + client := New(Options{APIKey: "test", BaseURL: server.URL, Retry: &retry, Sleep: func(_ context.Context, d time.Duration) error { + slept = append(slept, d) + return nil + }}) + _, err := client.Embed(context.Background(), "model", []string{"hi"}) + if err == nil { + t.Fatalf("expected final retryable error") + } + if calls != 3 { + t.Fatalf("calls = %d, want 3", calls) + } + if len(slept) != 2 { + t.Fatalf("slept %d times, want 2 before final attempt: %v", len(slept), slept) + } +} + +func TestEmbedDoesNotRetryInsufficientQuota(t *testing.T) { + var calls int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&calls, 1) + w.WriteHeader(http.StatusTooManyRequests) + _ = json.NewEncoder(w).Encode(embeddingResponse{Error: &struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + }{Message: "out of money", Type: "insufficient_quota", Code: "insufficient_quota"}}) + })) + defer server.Close() + retry := RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: time.Hour, MaxElapsed: time.Hour} + client := New(Options{APIKey: "test", BaseURL: server.URL, Retry: &retry}) + _, err := client.Embed(context.Background(), "model", []string{"hi"}) + if err == nil { + t.Fatalf("expected error") + } + if calls != 1 { + t.Fatalf("calls = %d, want 1 (no retry on insufficient_quota)", calls) + } + apiErr := AsAPIError(err) + if apiErr == nil { + t.Fatalf("expected typed APIError, got %T: %v", err, err) + } + if apiErr.Code != "insufficient_quota" { + t.Fatalf("code = %q, want insufficient_quota", apiErr.Code) + } +} + +func TestEmbedOverloadedUsesLongerBackoff(t *testing.T) { + var calls int32 + server := newSingleVectorServer(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + w.WriteHeader(http.StatusServiceUnavailable) + _ = json.NewEncoder(w).Encode(embeddingResponse{Error: &struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + }{Message: "overloaded", Type: "overloaded_error"}}) + return + } + writeSingleVector(w) + }) + defer server.Close() + + var slept []time.Duration + retry := RetryConfig{MaxAttempts: 3, BaseDelay: 10 * time.Millisecond, OverloadedBase: 5 * time.Second, MaxDelay: time.Hour, MaxElapsed: time.Hour} + client := New(Options{APIKey: "test", BaseURL: server.URL, Retry: &retry, Sleep: func(_ context.Context, d time.Duration) error { + slept = append(slept, d) + return nil + }}) + if _, err := client.Embed(context.Background(), "model", []string{"hi"}); err != nil { + t.Fatalf("embed: %v", err) + } + if len(slept) != 1 { + t.Fatalf("slept = %v, want one entry", slept) + } + if slept[0] < 4*time.Second || slept[0] > 6*time.Second { + t.Fatalf("overloaded backoff = %v, expected ~5s ± jitter", slept[0]) + } +} + +func TestEmbedPropagatesContextCancellation(t *testing.T) { + // Server returns 429 so retry would normally engage. We pre-cancel the + // context to assert that cancellation short-circuits the retry loop and + // is not classified as a retryable failure. + server := newSingleVectorServer(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + }) + defer server.Close() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + retry := RetryConfig{MaxAttempts: 5, BaseDelay: time.Millisecond, MaxDelay: time.Hour, MaxElapsed: time.Hour} + client := New(Options{APIKey: "test", BaseURL: server.URL, Retry: &retry}) + _, err := client.Embed(ctx, "model", []string{"hi"}) + if err == nil { + t.Fatal("expected cancellation error") + } + if !errors.Is(err, context.Canceled) { + t.Fatalf("err = %v, want context.Canceled", err) + } +} + +func TestEmbedRetryAfterDateForm(t *testing.T) { + var calls int32 + server := newSingleVectorServer(func(w http.ResponseWriter, r *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + w.Header().Set("Retry-After", time.Now().Add(3*time.Second).UTC().Format(http.TimeFormat)) + w.WriteHeader(http.StatusTooManyRequests) + return + } + writeSingleVector(w) + }) + defer server.Close() + + var slept []time.Duration + retry := RetryConfig{MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: time.Hour, MaxElapsed: time.Hour} + client := New(Options{APIKey: "test", BaseURL: server.URL, Retry: &retry, Sleep: func(_ context.Context, d time.Duration) error { + slept = append(slept, d) + return nil + }}) + if _, err := client.Embed(context.Background(), "model", []string{"hi"}); err != nil { + t.Fatalf("embed: %v", err) + } + if len(slept) != 1 || slept[0] < time.Second || slept[0] > 4*time.Second { + t.Fatalf("expected ~3s sleep from HTTP-date Retry-After, got %v", slept) + } +} diff --git a/internal/openai/errors.go b/internal/openai/errors.go new file mode 100644 index 0000000..38cb575 --- /dev/null +++ b/internal/openai/errors.go @@ -0,0 +1,88 @@ +package openai + +import ( + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +type APIError struct { + Status int + Type string + Code string + Message string + RetryAfter time.Duration +} + +func (e *APIError) Error() string { + parts := []string{fmt.Sprintf("openai embeddings status=%d", e.Status)} + if e.Type != "" { + parts = append(parts, "type="+e.Type) + } + if e.Code != "" { + parts = append(parts, "code="+e.Code) + } + if e.Message != "" { + parts = append(parts, "message="+e.Message) + } + return strings.Join(parts, " ") +} + +func (e *APIError) Retryable() bool { + if e == nil { + return false + } + switch e.Status { + case http.StatusRequestTimeout, http.StatusTooManyRequests: + return e.Type != "insufficient_quota" && e.Code != "insufficient_quota" + case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + return true + default: + return false + } +} + +func (e *APIError) IsOverloaded() bool { + return e != nil && (e.Type == "overloaded_error" || (e.Status == http.StatusServiceUnavailable && e.Code == "overloaded")) +} + +func AsAPIError(err error) *APIError { + if err == nil { + return nil + } + var apiErr *APIError + if errors.As(err, &apiErr) { + return apiErr + } + return nil +} + +func parseRetryAfter(header string, now time.Time) time.Duration { + header = strings.TrimSpace(header) + if header == "" { + return 0 + } + if seconds, err := strconv.Atoi(header); err == nil { + if seconds < 0 { + return 0 + } + return time.Duration(seconds) * time.Second + } + if seconds, err := strconv.ParseFloat(header, 64); err == nil { + if seconds < 0 { + return 0 + } + return time.Duration(seconds * float64(time.Second)) + } + if when, err := http.ParseTime(header); err == nil { + delta := when.Sub(now) + if delta < 0 { + return 0 + } + return delta + } + return 0 +}