internal/cli/app.go (185 changes: 153 additions & 32 deletions)
@@ -646,13 +646,27 @@ func (a *App) runCluster(ctx context.Context, args []string) error {
}

type embedResult struct {
Repository string `json:"repository"`
Model string `json:"model"`
Basis string `json:"basis"`
Selected int `json:"selected"`
Embedded int `json:"embedded"`
Skipped int `json:"skipped"`
RunID int64 `json:"run_id"`
Repository string `json:"repository"`
Model string `json:"model"`
Basis string `json:"basis"`
Selected int `json:"selected"`
Embedded int `json:"embedded"`
Skipped int `json:"skipped"`
Failed int `json:"failed,omitempty"`
Retries int `json:"retries,omitempty"`
Status string `json:"status,omitempty"`
Failures []embedFailureStat `json:"failures,omitempty"`
RunID int64 `json:"run_id"`
}

type embedFailureStat struct {
BatchStart int `json:"batch_start"`
BatchEnd int `json:"batch_end"`
Attempts int `json:"attempts"`
Status int `json:"status,omitempty"`
Type string `json:"type,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message"`
}
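
// NOTE (illustration, not part of this PR): given the json tags above, a
// partial run would emit a document shaped roughly like the following. The
// values are invented, except the failure fields, which mirror the 400
// response used in the tests below.
//
//	{
//	  "repository": "openclaw/openclaw",
//	  "model": "<embed model from config>",
//	  "basis": "<embedding basis from config>",
//	  "selected": 3,
//	  "embedded": 2,
//	  "skipped": 0,
//	  "failed": 1,
//	  "status": "partial",
//	  "failures": [
//	    {"batch_start": 0, "batch_end": 1, "attempts": 1,
//	     "status": 400, "type": "invalid_request_error", "message": "bad input"}
//	  ],
//	  "run_id": 1
//	}
//
// "retries" is absent here because it is zero and tagged omitempty.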

func (a *App) runEmbed(ctx context.Context, args []string) error {
@@ -731,42 +745,71 @@ func (a *App) embedRepository(ctx context.Context, owner, repoName string, optio
return embedResult{}, err
}
started := time.Now().UTC().Format(time.RFC3339Nano)
embedded := 0
batchSize := rt.Config.OpenAI.BatchSize
if batchSize <= 0 {
batchSize = 64
}
client := openai.New(openai.Options{APIKey: token.Value, BaseURL: openAIBaseURL(), Dimensions: rt.Config.OpenAI.EmbedDimensions})
client := openai.New(openai.Options{APIKey: token.Value, BaseURL: openAIBaseURL(), Dimensions: rt.Config.OpenAI.EmbedDimensions, Retry: embedRetryOverride()})

type pendingBatch struct {
start, end int
attempts int
}
var queue []pendingBatch
for start := 0; start < len(tasks); start += batchSize {
end := start + batchSize
if end > len(tasks) {
end = len(tasks)
}
batch := tasks[start:end]
texts := make([]string, 0, len(batch))
for _, task := range batch {
queue = append(queue, pendingBatch{start: start, end: end})
}

embedded := 0
totalRetries := 0
var failures []embedFailureStat
cancelled := false
var cancelErr error

const maxBatchAttempts = 2
for len(queue) > 0 {
batch := queue[0]
queue = queue[1:]
batch.attempts++
slice := tasks[batch.start:batch.end]
texts := make([]string, 0, len(slice))
for _, task := range slice {
texts = append(texts, task.Text)
}
fmt.Fprintf(a.Stderr, "[embed] embedding %d-%d of %d\n", start+1, end, len(tasks))
if truncated := truncatedEmbeddingTaskCount(batch); truncated > 0 {
fmt.Fprintf(a.Stderr, "[embed] truncated %d input(s) to %d runes\n", truncated, store.MaxEmbeddingTextRunes)
fmt.Fprintf(a.Stderr, "[embed] embedding %d-%d of %d (attempt %d)\n", batch.start+1, batch.end, len(tasks), batch.attempts)
if batch.attempts == 1 {
if truncated := truncatedEmbeddingTaskCount(slice); truncated > 0 {
fmt.Fprintf(a.Stderr, "[embed] truncated %d input(s) to %d runes\n", truncated, store.MaxEmbeddingTextRunes)
}
}
vectors, err := client.Embed(ctx, rt.Config.OpenAI.EmbedModel, texts)
if err != nil {
_, _ = rt.Store.RecordRun(ctx, store.RunRecord{
RepoID: repo.ID,
Kind: "embedding",
Scope: "repo",
Status: "error",
StartedAt: started,
FinishedAt: time.Now().UTC().Format(time.RFC3339Nano),
ErrorText: err.Error(),
})
return embedResult{}, err
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
cancelled = true
cancelErr = err
break
}
retryable := true
if apiErr := openai.AsAPIError(err); apiErr != nil {
retryable = apiErr.Retryable()
}
if retryable && batch.attempts < maxBatchAttempts {
totalRetries++
fmt.Fprintf(a.Stderr, "[embed] batch %d-%d failed (%s), requeueing\n", batch.start+1, batch.end, summarizeEmbedErr(err))
queue = append(queue, batch)
continue
}
fmt.Fprintf(a.Stderr, "[embed] batch %d-%d failed permanently: %s\n", batch.start+1, batch.end, summarizeEmbedErr(err))
failures = append(failures, makeEmbedFailureStat(batch.start, batch.end, batch.attempts, err))
continue
}
now := time.Now().UTC().Format(time.RFC3339Nano)
for index, vector := range vectors {
task := batch[index]
task := slice[index]
if err := rt.Store.UpsertThreadVector(ctx, store.ThreadVector{
ThreadID: task.ThreadID,
Basis: rt.Config.EmbeddingBasis,
Expand All @@ -783,31 +826,109 @@ func (a *App) embedRepository(ctx context.Context, owner, repoName string, optio
embedded++
}
}

failedRows := 0
for _, f := range failures {
failedRows += f.BatchEnd - f.BatchStart
}

status := "success"
switch {
case cancelled:
status = "cancelled"
case len(failures) > 0 && embedded == 0:
status = "error"
case len(failures) > 0:
status = "partial"
}

result := embedResult{
Repository: repo.FullName,
Model: rt.Config.OpenAI.EmbedModel,
Basis: rt.Config.EmbeddingBasis,
Selected: len(tasks),
Embedded: embedded,
RunID: 0,
Failed: failedRows,
Retries: totalRetries,
Status: status,
Failures: failures,
}
statsJSON, _ := json.Marshal(result)
runID, err := rt.Store.RecordRun(ctx, store.RunRecord{
runRecord := store.RunRecord{
RepoID: repo.ID,
Kind: "embedding",
Scope: "repo",
Status: "success",
Status: status,
StartedAt: started,
FinishedAt: time.Now().UTC().Format(time.RFC3339Nano),
StatsJSON: string(statsJSON),
})
if err != nil {
return embedResult{}, err
}
if cancelled && cancelErr != nil {
runRecord.ErrorText = cancelErr.Error()
} else if status == "error" && len(failures) > 0 {
runRecord.ErrorText = failures[0].Message
}
recordCtx := ctx
if cancelled {
var cancelRecord context.CancelFunc
recordCtx, cancelRecord = context.WithTimeout(context.Background(), 5*time.Second)
defer cancelRecord()
}
runID, recordErr := rt.Store.RecordRun(recordCtx, runRecord)
if recordErr != nil && !cancelled {
return embedResult{}, recordErr
}
result.RunID = runID

if cancelled {
return result, cancelErr
}
if status == "error" {
return result, fmt.Errorf("openai embeddings failed: %s", failures[0].Message)
}
return result, nil
}
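
// NOTE (illustration, not part of this PR): the loop above is a bounded-retry
// FIFO. Each batch gets at most maxBatchAttempts (2) tries: a retryable error
// sends it to the back of the queue once; a second failure, or any
// non-retryable error, is recorded as a permanent failure and the loop moves
// on to the remaining batches. A minimal self-contained sketch of the same
// pattern, with the OpenAI call replaced by hypothetical process,
// isRetryable, and recordFailure funcs:
//
//	const maxAttempts = 2
//	queue := []pendingBatch{{start: 0, end: 64}, {start: 64, end: 128}}
//	for len(queue) > 0 {
//		b := queue[0]
//		queue = queue[1:]
//		b.attempts++
//		if err := process(b); err != nil {
//			if isRetryable(err) && b.attempts < maxAttempts {
//				queue = append(queue, b) // one more try, at the back
//				continue
//			}
//			recordFailure(b, err) // give up on this batch only
//		}
//	}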

func summarizeEmbedErr(err error) string {
if apiErr := openai.AsAPIError(err); apiErr != nil {
parts := []string{fmt.Sprintf("status=%d", apiErr.Status)}
if apiErr.Type != "" {
parts = append(parts, "type="+apiErr.Type)
}
if apiErr.Code != "" {
parts = append(parts, "code="+apiErr.Code)
}
return strings.Join(parts, " ")
}
return err.Error()
}
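
// For example, the 400 returned by the test server below would summarize as
// "status=400 type=invalid_request_error"; a non-API error falls back to its
// plain Error() text.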

func makeEmbedFailureStat(start, end, attempts int, err error) embedFailureStat {
stat := embedFailureStat{
BatchStart: start,
BatchEnd: end,
Attempts: attempts,
Message: err.Error(),
}
if apiErr := openai.AsAPIError(err); apiErr != nil {
stat.Status = apiErr.Status
stat.Type = apiErr.Type
stat.Code = apiErr.Code
if apiErr.Message != "" {
stat.Message = apiErr.Message
}
}
return stat
}

func embedRetryOverride() *openai.RetryConfig {
if strings.TrimSpace(os.Getenv("GITCRAWL_OPENAI_RETRY_DISABLED")) == "1" {
cfg := openai.NoRetry()
return &cfg
}
return nil
}
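
// NOTE (illustration, not part of this PR): GITCRAWL_OPENAI_RETRY_DISABLED=1
// swaps in openai.NoRetry(), which presumably disables the client's own
// backoff so only the batch-level requeue above remains. The tests below set
// it so a failing batch surfaces after a single HTTP attempt:
//
//	t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1")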

func truncatedEmbeddingTaskCount(tasks []store.EmbeddingTask) int {
count := 0
for _, task := range tasks {
…
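
NOTE (illustration, not part of this PR): when the run is cancelled, embedRepository records the run under a fresh five-second context so the bookkeeping write is not killed by the same cancellation. A self-contained sketch of that pattern; saveRun is a hypothetical stand-in for a ctx-aware store write such as RecordRun:

package main

import (
	"context"
	"fmt"
	"time"
)

// saveRun stands in for a ctx-aware store write such as RecordRun.
func saveRun(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err() // a cancelled request context kills the write
	case <-time.After(10 * time.Millisecond):
		return nil
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // simulate the user interrupting the embed run

	// Derive the short-lived context from Background, not from ctx,
	// purely for the final bookkeeping write.
	recordCtx, done := context.WithTimeout(context.Background(), 5*time.Second)
	defer done()

	fmt.Println(saveRun(ctx))       // context canceled
	fmt.Println(saveRun(recordCtx)) // <nil>
}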
internal/cli/app_test.go (137 changes: 137 additions & 0 deletions)
@@ -1438,6 +1438,7 @@ func TestEmbedErrorBranchesRecordFailures(t *testing.T) {
}))
defer server.Close()
t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL)
t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1")
if err := New().Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw", "--limit", "1"}); err == nil {
t.Fatal("OpenAI error should fail")
}
Expand All @@ -1459,6 +1460,142 @@ func TestEmbedErrorBranchesRecordFailures(t *testing.T) {
}
}

func TestEmbedRunPartialOnSomeFailedBatches(t *testing.T) {
ctx := context.Background()
dir := t.TempDir()
configPath := filepath.Join(dir, "config.toml")
dbPath := filepath.Join(dir, "gitcrawl.db")
if err := New().Run(ctx, []string{"--config", configPath, "init", "--db", dbPath}); err != nil {
t.Fatalf("init: %v", err)
}
seedCommandFlowStore(t, dbPath)

cfg, err := config.Load(configPath)
if err != nil {
t.Fatalf("load config: %v", err)
}
cfg.OpenAI.BatchSize = 1
if err := config.Save(configPath, cfg); err != nil {
t.Fatalf("save config: %v", err)
}

var calls int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
calls++
var payload struct {
Input []string `json:"input"`
}
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
t.Fatalf("decode: %v", err)
}
// First input is permanently bad — return non-retryable 400.
if len(payload.Input) == 1 && strings.Contains(payload.Input[0], "Gateway websocket stalls") {
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]any{
"error": map[string]any{"message": "bad input", "type": "invalid_request_error"},
})
return
}
data := make([]map[string]any, 0, len(payload.Input))
for index := range payload.Input {
data = append(data, map[string]any{"index": index, "embedding": []float64{1, 0.5 * float64(index)}})
}
_ = json.NewEncoder(w).Encode(map[string]any{"data": data})
}))
defer server.Close()
t.Setenv("OPENAI_API_KEY", "test-openai-key")
t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL)
t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1")

app := New()
var stdout bytes.Buffer
app.Stdout = &stdout
if err := app.Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw", "--json"}); err != nil {
t.Fatalf("embed: %v", err)
}

var result embedResult
if err := json.Unmarshal(stdout.Bytes(), &result); err != nil {
t.Fatalf("decode embed result: %v\n%s", err, stdout.String())
}
if result.Status != "partial" {
t.Fatalf("status = %q, want partial", result.Status)
}
if result.Embedded != 2 {
t.Fatalf("embedded = %d, want 2", result.Embedded)
}
if result.Failed != 1 {
t.Fatalf("failed = %d, want 1", result.Failed)
}
if len(result.Failures) != 1 {
t.Fatalf("failures = %+v", result.Failures)
}
if result.Failures[0].Status != http.StatusBadRequest {
t.Fatalf("failure status = %d", result.Failures[0].Status)
}

st, err := store.Open(ctx, dbPath)
if err != nil {
t.Fatalf("open: %v", err)
}
defer st.Close()
repo, err := st.RepositoryByFullName(ctx, "openclaw/openclaw")
if err != nil {
t.Fatalf("repo: %v", err)
}
runs, err := st.ListRuns(ctx, repo.ID, "embedding", 1)
if err != nil {
t.Fatalf("runs: %v", err)
}
if len(runs) != 1 || runs[0].Status != "partial" {
t.Fatalf("run = %+v", runs)
}
}

func TestEmbedRunCancelledRecordsCancelledStatus(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
dir := t.TempDir()
configPath := filepath.Join(dir, "config.toml")
dbPath := filepath.Join(dir, "gitcrawl.db")
if err := New().Run(ctx, []string{"--config", configPath, "init", "--db", dbPath}); err != nil {
t.Fatalf("init: %v", err)
}
seedCommandFlowStore(t, dbPath)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cancel()
select {
case <-r.Context().Done():
case <-time.After(2 * time.Second):
}
}))
defer server.Close()
t.Setenv("OPENAI_API_KEY", "test-openai-key")
t.Setenv("GITCRAWL_OPENAI_BASE_URL", server.URL)
t.Setenv("GITCRAWL_OPENAI_RETRY_DISABLED", "1")

if err := New().Run(ctx, []string{"--config", configPath, "embed", "openclaw/openclaw"}); err == nil {
t.Fatal("expected cancellation error")
}

st, err := store.Open(context.Background(), dbPath)
if err != nil {
t.Fatalf("open store: %v", err)
}
defer st.Close()
repo, err := st.RepositoryByFullName(context.Background(), "openclaw/openclaw")
if err != nil {
t.Fatalf("repo: %v", err)
}
runs, err := st.ListRuns(context.Background(), repo.ID, "embedding", 1)
if err != nil {
t.Fatalf("runs: %v", err)
}
if len(runs) != 1 || runs[0].Status != "cancelled" {
t.Fatalf("expected cancelled run, got %+v", runs)
}
}

func TestTruncatedEmbeddingTaskCount(t *testing.T) {
tasks := []store.EmbeddingTask{
{Number: 1},
…