From 26992414cbf4bda5e671b93f3dcd5bb18a60289b Mon Sep 17 00:00:00 2001 From: Guillaume Tardif Date: Mon, 16 Mar 2026 09:38:57 +0100 Subject: [PATCH 1/2] Handle context cancelling during rag initialization & query Assisted-By: docker-agent Signed-off-by: Guillaume Tardif --- pkg/fsx/collect.go | 25 ++++++- pkg/fsx/collect_cancellation_test.go | 87 ++++++++++++++++++++++++ pkg/fsx/collect_test.go | 20 +++--- pkg/fsx/fs.go | 25 ++++++- pkg/rag/strategy/bm25.go | 53 +++++++++++++-- pkg/rag/strategy/vector_store.go | 99 +++++++++++++++++++++++++--- pkg/tools/builtin/filesystem.go | 4 +- 7 files changed, 282 insertions(+), 31 deletions(-) create mode 100644 pkg/fsx/collect_cancellation_test.go diff --git a/pkg/fsx/collect.go b/pkg/fsx/collect.go index 63b1466ec..612917c83 100644 --- a/pkg/fsx/collect.go +++ b/pkg/fsx/collect.go @@ -1,6 +1,7 @@ package fsx import ( + "context" "fmt" "os" "path/filepath" @@ -13,11 +14,18 @@ import ( // Supports glob patterns (via doublestar), directories, and individual files. // Skips paths that don't exist instead of returning an error. // Optional shouldIgnore filter can exclude files/directories (return true to skip). -func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string, error) { +// Respects context cancellation. +func CollectFiles(ctx context.Context, paths []string, shouldIgnore func(path string) bool) ([]string, error) { var files []string seen := make(map[string]bool) for _, pattern := range paths { + // Check for context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } expanded, err := expandPattern(pattern) if err != nil { return nil, err @@ -27,6 +35,12 @@ func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string } for _, entry := range expanded { + // Check for context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } normalized := normalizePath(entry) // Check if this path should be ignored @@ -44,7 +58,7 @@ func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string if info.IsDir() { // Use DirectoryTree to collect files from directory - tree, err := DirectoryTree(normalized, func(string) error { return nil }, shouldIgnore, 0) + tree, err := DirectoryTreeWithContext(ctx, normalized, func(string) error { return nil }, shouldIgnore, 0) if err != nil { return nil, fmt.Errorf("failed to read directory %s: %w", normalized, err) } @@ -52,6 +66,13 @@ func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string var dirFiles []string CollectFilesFromTree(tree, filepath.Dir(normalized), &dirFiles) for _, f := range dirFiles { + // Check for context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + absPath := normalizePath(f) if !seen[absPath] { files = append(files, absPath) diff --git a/pkg/fsx/collect_cancellation_test.go b/pkg/fsx/collect_cancellation_test.go new file mode 100644 index 000000000..db53efe02 --- /dev/null +++ b/pkg/fsx/collect_cancellation_test.go @@ -0,0 +1,87 @@ +package fsx + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCollectFiles_ContextCancellation(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + // Create a large directory structure to ensure context cancellation has time to kick in + for i := range 100 { + subDir := filepath.Join(tmpDir, "dir", "subdir", "deepdir", fmt.Sprintf("dir%d", i)) + require.NoError(t, os.MkdirAll(subDir, 0o755)) + for j := range 10 { + filePath := filepath.Join(subDir, fmt.Sprintf("file%d.txt", j)) + require.NoError(t, os.WriteFile(filePath, []byte("test content"), 0o644)) + } + } + + t.Run("respects context cancellation", func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + + // Cancel context immediately + cancel() + + _, err := CollectFiles(ctx, []string{tmpDir}, nil) + assert.ErrorIs(t, err, context.Canceled) + }) + + t.Run("respects context timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 1*time.Nanosecond) + defer cancel() + + // Give time for timeout to trigger + time.Sleep(10 * time.Millisecond) + + _, err := CollectFiles(ctx, []string{tmpDir}, nil) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) +} + +func TestDirectoryTree_ContextCancellation(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + // Create a large directory structure + for i := range 100 { + subDir := filepath.Join(tmpDir, "dir", "subdir", fmt.Sprintf("dir%d", i)) + require.NoError(t, os.MkdirAll(subDir, 0o755)) + for j := range 10 { + filePath := filepath.Join(subDir, fmt.Sprintf("file%d.txt", j)) + require.NoError(t, os.WriteFile(filePath, []byte("test content"), 0o644)) + } + } + + t.Run("respects context cancellation", func(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + + // Cancel context immediately + cancel() + + _, err := DirectoryTreeWithContext(ctx, tmpDir, func(string) error { return nil }, nil, 0) + assert.ErrorIs(t, err, context.Canceled) + }) + + t.Run("respects context timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), 1*time.Nanosecond) + defer cancel() + + // Give time for timeout to trigger + time.Sleep(10 * time.Millisecond) + + _, err := DirectoryTreeWithContext(ctx, tmpDir, func(string) error { return nil }, nil, 0) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) +} diff --git a/pkg/fsx/collect_test.go b/pkg/fsx/collect_test.go index 991381686..75129fdf0 100644 --- a/pkg/fsx/collect_test.go +++ b/pkg/fsx/collect_test.go @@ -52,7 +52,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) { t.Run("no filter collects all files", func(t *testing.T) { t.Parallel() - got, err := CollectFiles([]string{tmpDir}, nil) + got, err := CollectFiles(t.Context(), []string{tmpDir}, nil) require.NoError(t, err) assert.Len(t, got, 5) }) @@ -66,7 +66,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) { return base == "vendor" || base == "node_modules" } - got, err := CollectFiles([]string{tmpDir}, shouldIgnore) + got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore) require.NoError(t, err) // Should only have src/*.go and build/output.bin @@ -87,7 +87,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) { return strings.HasSuffix(path, ".bin") } - got, err := CollectFiles([]string{tmpDir}, shouldIgnore) + got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore) require.NoError(t, err) assert.Len(t, got, 4) @@ -104,7 +104,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) { return strings.Contains(path, "vendor") } - got, err := CollectFiles([]string{tmpDir}, shouldIgnore) + got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore) require.NoError(t, err) for _, f := range got { @@ -148,7 +148,7 @@ func TestCollectFiles_GitDirectoryExclusion(t *testing.T) { t.Run("without filter includes .git", func(t *testing.T) { t.Parallel() - got, err := CollectFiles([]string{tmpDir}, nil) + got, err := CollectFiles(t.Context(), []string{tmpDir}, nil) require.NoError(t, err) // Should include .git files @@ -177,7 +177,7 @@ func TestCollectFiles_GitDirectoryExclusion(t *testing.T) { strings.HasPrefix(normalized, ".git/") } - got, err := CollectFiles([]string{tmpDir}, shouldIgnore) + got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore) require.NoError(t, err) // Should only have src/main.go @@ -229,7 +229,7 @@ func TestCollectFiles_GlobsWithFilter(t *testing.T) { return strings.HasSuffix(path, "_test.go") } - got, err := CollectFiles([]string{filepath.Join(tmpDir, "pkg", "**", "*.go")}, shouldIgnore) + got, err := CollectFiles(t.Context(), []string{filepath.Join(tmpDir, "pkg", "**", "*.go")}, shouldIgnore) require.NoError(t, err) // Should only have non-test .go files @@ -329,7 +329,7 @@ func TestCollectFiles_Deduplication(t *testing.T) { tmpDir, // Will also include test.go } - got, err := CollectFiles(patterns, nil) + got, err := CollectFiles(t.Context(), patterns, nil) require.NoError(t, err) // Should only have one entry @@ -352,7 +352,7 @@ func TestCollectFiles_NonExistentPaths(t *testing.T) { filepath.Join(tmpDir, "also", "missing", "file.go"), } - got, err := CollectFiles(patterns, nil) + got, err := CollectFiles(t.Context(), patterns, nil) require.NoError(t, err) // Should only have the real file @@ -371,7 +371,7 @@ func TestCollectFiles_SortedOutput(t *testing.T) { require.NoError(t, os.WriteFile(filepath.Join(tmpDir, f), []byte("package test"), 0o644)) } - got, err := CollectFiles([]string{tmpDir}, nil) + got, err := CollectFiles(t.Context(), []string{tmpDir}, nil) require.NoError(t, err) // Verify we got all files diff --git a/pkg/fsx/fs.go b/pkg/fsx/fs.go index 3dae9ca1b..dd36ea620 100644 --- a/pkg/fsx/fs.go +++ b/pkg/fsx/fs.go @@ -17,10 +17,22 @@ type TreeNode struct { func DirectoryTree(path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) { itemCount := 0 - return directoryTree(path, isPathAllowed, shouldIgnore, maxItems, &itemCount) + return directoryTree(context.Background(), path, isPathAllowed, shouldIgnore, maxItems, &itemCount) } -func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int, itemCount *int) (*TreeNode, error) { +// DirectoryTreeWithContext is a context-aware version of DirectoryTree. +func DirectoryTreeWithContext(ctx context.Context, path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) { + itemCount := 0 + return directoryTree(ctx, path, isPathAllowed, shouldIgnore, maxItems, &itemCount) +} + +func directoryTree(ctx context.Context, path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int, itemCount *int) (*TreeNode, error) { + // Check for context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } if maxItems > 0 && *itemCount >= maxItems { return nil, nil } @@ -47,6 +59,13 @@ func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore f } for _, entry := range entries { + // Check for context cancellation + select { + case <-ctx.Done(): + return node, ctx.Err() + default: + } + childPath := filepath.Join(path, entry.Name()) if err := isPathAllowed(childPath); err != nil { continue // Skip disallowed paths @@ -57,7 +76,7 @@ func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore f continue } - childNode, err := directoryTree(childPath, isPathAllowed, shouldIgnore, maxItems, itemCount) + childNode, err := directoryTree(ctx, childPath, isPathAllowed, shouldIgnore, maxItems, itemCount) if err != nil || childNode == nil { continue } diff --git a/pkg/rag/strategy/bm25.go b/pkg/rag/strategy/bm25.go index 72cb0fe56..f7938d96b 100644 --- a/pkg/rag/strategy/bm25.go +++ b/pkg/rag/strategy/bm25.go @@ -157,7 +157,7 @@ func (s *BM25Strategy) Initialize(ctx context.Context, docPaths []string, chunki // Collect all files slog.Debug("Collecting files", "strategy", s.name, "paths", docPaths) - files, err := fsx.CollectFiles(docPaths, s.shouldIgnore) + files, err := fsx.CollectFiles(ctx, docPaths, s.shouldIgnore) if err != nil { s.emitEvent(types.Event{Type: types.EventTypeError, Error: err}) return fmt.Errorf("failed to collect files: %w", err) @@ -165,6 +165,12 @@ func (s *BM25Strategy) Initialize(ctx context.Context, docPaths []string, chunki seenFilesForCleanup := make(map[string]bool) for _, f := range files { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } seenFilesForCleanup[f] = true } @@ -188,6 +194,13 @@ func (s *BM25Strategy) Initialize(ctx context.Context, docPaths []string, chunki filesToIndex := 0 for _, filePath := range files { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + seenFiles[filePath] = true needsIndexing, err := s.needsIndexing(ctx, filePath) @@ -324,7 +337,7 @@ func (s *BM25Strategy) Query(ctx context.Context, query string, numResults int, // CheckAndReindexChangedFiles checks for file changes and re-indexes if needed func (s *BM25Strategy) CheckAndReindexChangedFiles(ctx context.Context, docPaths []string, chunking ChunkingConfig) error { - files, err := fsx.CollectFiles(docPaths, s.shouldIgnore) + files, err := fsx.CollectFiles(ctx, docPaths, s.shouldIgnore) if err != nil { return fmt.Errorf("failed to collect files: %w", err) } @@ -332,6 +345,13 @@ func (s *BM25Strategy) CheckAndReindexChangedFiles(ctx context.Context, docPaths seenFiles := make(map[string]bool) for _, filePath := range files { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + seenFiles[filePath] = true needsIndexing, err := s.needsIndexing(ctx, filePath) @@ -372,7 +392,7 @@ func (s *BM25Strategy) StartFileWatcher(ctx context.Context, docPaths []string, s.watcher = watcher for _, docPath := range docPaths { - if err := s.addPathToWatcher(docPath); err != nil { + if err := s.addPathToWatcher(ctx, docPath); err != nil { slog.Warn("Failed to watch path", "strategy", s.name, "path", docPath, "error", err) continue } @@ -542,6 +562,13 @@ func (s *BM25Strategy) indexFile(ctx context.Context, filePath string) error { storedChunks := 0 for _, chunk := range chunks { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if chunk.Content == "" { continue } @@ -582,6 +609,13 @@ func (s *BM25Strategy) cleanupOrphanedDocuments(ctx context.Context, seenFiles m } for _, meta := range metadata { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if !seenFiles[meta.SourcePath] { if err := s.db.DeleteDocumentsByPath(ctx, meta.SourcePath); err != nil { slog.Error("Failed to delete orphaned documents", "path", meta.SourcePath, "error", err) @@ -600,7 +634,7 @@ func (s *BM25Strategy) cleanupOrphanedDocuments(ctx context.Context, seenFiles m return nil } -func (s *BM25Strategy) addPathToWatcher(path string) error { +func (s *BM25Strategy) addPathToWatcher(ctx context.Context, path string) error { absPath, err := filepath.Abs(path) if err != nil { return fmt.Errorf("failed to get absolute path: %w", err) @@ -619,7 +653,7 @@ func (s *BM25Strategy) addPathToWatcher(path string) error { } if stat.IsDir() { - files, err := fsx.CollectFiles([]string{absPath}, s.shouldIgnore) + files, err := fsx.CollectFiles(ctx, []string{absPath}, s.shouldIgnore) if err != nil { return fmt.Errorf("failed to collect files: %w", err) } @@ -657,6 +691,13 @@ func (s *BM25Strategy) watchLoop(ctx context.Context, docPaths []string) { } for _, file := range changedFiles { + // Check for context cancellation + select { + case <-ctx.Done(): + return // Stop processing if context is cancelled + default: + } + // Check if the file matches any of the configured document paths/patterns matches, matchErr := fsx.Matches(file, docPaths) if matchErr != nil { @@ -705,7 +746,7 @@ func (s *BM25Strategy) watchLoop(ctx context.Context, docPaths []string) { if event.Op&fsnotify.Create != 0 { s.watcherMu.Lock() - _ = s.addPathToWatcher(event.Name) + _ = s.addPathToWatcher(ctx, event.Name) s.watcherMu.Unlock() } diff --git a/pkg/rag/strategy/vector_store.go b/pkg/rag/strategy/vector_store.go index f15a60ea0..dd036851f 100644 --- a/pkg/rag/strategy/vector_store.go +++ b/pkg/rag/strategy/vector_store.go @@ -231,7 +231,7 @@ func (s *VectorStore) Initialize(ctx context.Context, docPaths []string, chunkin // Collect all files slog.Debug("Collecting files", "strategy", s.name, "paths", docPaths) - files, err := fsx.CollectFiles(docPaths, s.shouldIgnore) + files, err := fsx.CollectFiles(ctx, docPaths, s.shouldIgnore) if err != nil { s.emitEvent(types.Event{Type: types.EventTypeError, Error: err}) return fmt.Errorf("failed to collect files: %w", err) @@ -240,6 +240,12 @@ func (s *VectorStore) Initialize(ctx context.Context, docPaths []string, chunkin // Track seen files for cleanup seenFilesForCleanup := make(map[string]bool) for _, f := range files { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } seenFilesForCleanup[f] = true } @@ -268,6 +274,13 @@ func (s *VectorStore) Initialize(ctx context.Context, docPaths []string, chunkin filesToIndex := 0 for _, filePath := range files { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + seenFiles[filePath] = true needsIndexing, err := s.needsIndexing(ctx, filePath) @@ -307,7 +320,7 @@ func (s *VectorStore) Initialize(ctx context.Context, docPaths []string, chunkin } g.Go(func() error { - // Check for context cancellation + // Check for context cancellation at start of goroutine select { case <-gctx.Done(): return gctx.Err() @@ -389,7 +402,7 @@ func (s *VectorStore) Query(ctx context.Context, query string, numResults int, t // CheckAndReindexChangedFiles checks for file changes and re-indexes if needed func (s *VectorStore) CheckAndReindexChangedFiles(ctx context.Context, docPaths []string, chunking ChunkingConfig) error { - files, err := fsx.CollectFiles(docPaths, s.shouldIgnore) + files, err := fsx.CollectFiles(ctx, docPaths, s.shouldIgnore) if err != nil { return fmt.Errorf("failed to collect files: %w", err) } @@ -397,6 +410,13 @@ func (s *VectorStore) CheckAndReindexChangedFiles(ctx context.Context, docPaths seenFiles := make(map[string]bool) for _, filePath := range files { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + seenFiles[filePath] = true needsIndexing, err := s.needsIndexing(ctx, filePath) @@ -432,7 +452,7 @@ func (s *VectorStore) StartFileWatcher(ctx context.Context, docPaths []string, c s.watcher = watcher for _, docPath := range docPaths { - if err := s.addPathToWatcher(docPath); err != nil { + if err := s.addPathToWatcher(ctx, docPath); err != nil { slog.Warn("Failed to watch path", "strategy", s.name, "path", docPath, "error", err) continue } @@ -545,6 +565,13 @@ func (s *VectorStore) indexFile(ctx context.Context, filePath string) error { // Filter out empty chunks var validChunks []chunk.Chunk for _, ch := range chunks { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if ch.Content == "" { continue } @@ -579,6 +606,13 @@ func (s *VectorStore) indexFile(ctx context.Context, filePath string) error { // Store all documents storedChunks := 0 for i, ch := range validChunks { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + doc := database.Document{ ID: fmt.Sprintf("%s_%d_%d", filePath, ch.Index, time.Now().UnixNano()), SourcePath: filePath, @@ -624,7 +658,21 @@ func (s *VectorStore) buildEmbeddingInputs(ctx context.Context, filePath string, g.SetLimit(s.embeddingConcurrency) for i, ch := range chunks { + // Check for context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + g.Go(func() error { + // Check for context cancellation + select { + case <-gctx.Done(): + return gctx.Err() + default: + } + text, berr := s.embeddingInputBuilder.BuildEmbeddingInput(gctx, filePath, ch) if berr != nil || strings.TrimSpace(text) == "" { slog.Warn("Embedding input builder failed; falling back to raw chunk content", @@ -644,6 +692,13 @@ func (s *VectorStore) buildEmbeddingInputs(ctx context.Context, filePath string, } } else { for i, ch := range chunks { + // Check for context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + text, berr := s.embeddingInputBuilder.BuildEmbeddingInput(ctx, filePath, ch) if berr != nil || strings.TrimSpace(text) == "" { slog.Warn("Embedding input builder failed; falling back to raw chunk content", @@ -668,6 +723,13 @@ func (s *VectorStore) cleanupOrphanedDocuments(ctx context.Context, seenFiles ma deletedCount := 0 for _, meta := range metadata { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if seenFiles[meta.SourcePath] { continue } @@ -700,9 +762,9 @@ func (s *VectorStore) cleanupOrphanedDocuments(ctx context.Context, seenFiles ma return nil } -func (s *VectorStore) addPathToWatcher(path string) error { +func (s *VectorStore) addPathToWatcher(ctx context.Context, path string) error { // Resolve path(s) using Processor (handles globs, directories, files) - files, err := fsx.CollectFiles([]string{path}, s.shouldIgnore) + files, err := fsx.CollectFiles(ctx, []string{path}, s.shouldIgnore) if err != nil { return fmt.Errorf("failed to collect files for watching: %w", err) } @@ -779,6 +841,13 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { filesToReindex := make([]string, 0) for _, file := range changedFiles { + // Check for context cancellation + select { + case <-ctx.Done(): + return // Stop processing if context is cancelled + default: + } + // Check if the file matches any of the configured document paths/patterns matches, matchErr := fsx.Matches(file, docPaths) if matchErr != nil { @@ -812,6 +881,14 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { }) for i, file := range filesToReindex { + // Check for context cancellation + select { + case <-ctx.Done(): + slog.Info("File watcher stopped during reindexing due to context cancellation", "strategy", s.name) + return + default: + } + s.emitEvent(types.Event{ Type: "indexing_progress", Message: "Re-indexing: " + filepath.Base(file), @@ -862,7 +939,7 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { if event.Op&fsnotify.Create != 0 { s.watcherMu.Lock() - if err := s.addPathToWatcher(event.Name); err != nil { + if err := s.addPathToWatcher(ctx, event.Name); err != nil { slog.Debug("Could not watch new path", "path", event.Name, "error", err) } s.watcherMu.Unlock() @@ -906,13 +983,19 @@ func (s *VectorStore) watchLoop(ctx context.Context, docPaths []string) { } func (s *VectorStore) cleanupOrphanedDocumentsFromDisk(ctx context.Context, docPaths []string) error { - files, err := fsx.CollectFiles(docPaths, s.shouldIgnore) + files, err := fsx.CollectFiles(ctx, docPaths, s.shouldIgnore) if err != nil { return fmt.Errorf("failed to collect files: %w", err) } seenFiles := make(map[string]bool) for _, file := range files { + // Check for context cancellation + select { + case <-ctx.Done(): + return ctx.Err() + default: + } seenFiles[file] = true } diff --git a/pkg/tools/builtin/filesystem.go b/pkg/tools/builtin/filesystem.go index 938d1cec4..65f938112 100644 --- a/pkg/tools/builtin/filesystem.go +++ b/pkg/tools/builtin/filesystem.go @@ -384,14 +384,14 @@ func (t *FilesystemTool) shouldIgnorePath(path string) bool { // Handler implementations -func (t *FilesystemTool) handleDirectoryTree(_ context.Context, args DirectoryTreeArgs) (*tools.ToolCallResult, error) { +func (t *FilesystemTool) handleDirectoryTree(ctx context.Context, args DirectoryTreeArgs) (*tools.ToolCallResult, error) { resolvedPath := t.resolvePath(args.Path) isPathAllowed := func(_ string) error { return nil } - tree, err := fsx.DirectoryTree(resolvedPath, isPathAllowed, t.shouldIgnorePath, maxFiles) + tree, err := fsx.DirectoryTreeWithContext(ctx, resolvedPath, isPathAllowed, t.shouldIgnorePath, maxFiles) if err != nil { return tools.ResultError(fmt.Sprintf("Error building directory tree: %s", err)), nil } From c8a9ae8a21a9e8a0f173036f0ec121bf352fec19 Mon Sep 17 00:00:00 2001 From: Guillaume Tardif Date: Mon, 16 Mar 2026 09:44:55 +0100 Subject: [PATCH 2/2] Remove unused methods Signed-off-by: Guillaume Tardif --- pkg/fsx/collect.go | 2 +- pkg/fsx/collect_cancellation_test.go | 4 ++-- pkg/fsx/fs.go | 19 +------------------ pkg/tools/builtin/filesystem.go | 2 +- 4 files changed, 5 insertions(+), 22 deletions(-) diff --git a/pkg/fsx/collect.go b/pkg/fsx/collect.go index 612917c83..beb9a9a50 100644 --- a/pkg/fsx/collect.go +++ b/pkg/fsx/collect.go @@ -58,7 +58,7 @@ func CollectFiles(ctx context.Context, paths []string, shouldIgnore func(path st if info.IsDir() { // Use DirectoryTree to collect files from directory - tree, err := DirectoryTreeWithContext(ctx, normalized, func(string) error { return nil }, shouldIgnore, 0) + tree, err := DirectoryTree(ctx, normalized, func(string) error { return nil }, shouldIgnore, 0) if err != nil { return nil, fmt.Errorf("failed to read directory %s: %w", normalized, err) } diff --git a/pkg/fsx/collect_cancellation_test.go b/pkg/fsx/collect_cancellation_test.go index db53efe02..02a5c7aed 100644 --- a/pkg/fsx/collect_cancellation_test.go +++ b/pkg/fsx/collect_cancellation_test.go @@ -70,7 +70,7 @@ func TestDirectoryTree_ContextCancellation(t *testing.T) { // Cancel context immediately cancel() - _, err := DirectoryTreeWithContext(ctx, tmpDir, func(string) error { return nil }, nil, 0) + _, err := DirectoryTree(ctx, tmpDir, func(string) error { return nil }, nil, 0) assert.ErrorIs(t, err, context.Canceled) }) @@ -81,7 +81,7 @@ func TestDirectoryTree_ContextCancellation(t *testing.T) { // Give time for timeout to trigger time.Sleep(10 * time.Millisecond) - _, err := DirectoryTreeWithContext(ctx, tmpDir, func(string) error { return nil }, nil, 0) + _, err := DirectoryTree(ctx, tmpDir, func(string) error { return nil }, nil, 0) assert.ErrorIs(t, err, context.DeadlineExceeded) }) } diff --git a/pkg/fsx/fs.go b/pkg/fsx/fs.go index dd36ea620..12c281e41 100644 --- a/pkg/fsx/fs.go +++ b/pkg/fsx/fs.go @@ -15,13 +15,7 @@ type TreeNode struct { Children []*TreeNode `json:"children,omitempty"` } -func DirectoryTree(path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) { - itemCount := 0 - return directoryTree(context.Background(), path, isPathAllowed, shouldIgnore, maxItems, &itemCount) -} - -// DirectoryTreeWithContext is a context-aware version of DirectoryTree. -func DirectoryTreeWithContext(ctx context.Context, path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) { +func DirectoryTree(ctx context.Context, path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) { itemCount := 0 return directoryTree(ctx, path, isPathAllowed, shouldIgnore, maxItems, &itemCount) } @@ -87,17 +81,6 @@ func directoryTree(ctx context.Context, path string, isPathAllowed func(string) return node, nil } -func ListDirectory(path string, shouldIgnore func(string) bool) ([]string, error) { - tree, err := DirectoryTree(path, func(string) error { return nil }, shouldIgnore, 0) - if err != nil { - return nil, err - } - - var files []string - CollectFilesFromTree(tree, "", &files) - return files, nil -} - // CollectFilesFromTree recursively collects file paths from a DirectoryTree. // Pass basePath="" for relative paths, or a parent directory for absolute paths. func CollectFilesFromTree(node *TreeNode, basePath string, files *[]string) { diff --git a/pkg/tools/builtin/filesystem.go b/pkg/tools/builtin/filesystem.go index 65f938112..34e250a85 100644 --- a/pkg/tools/builtin/filesystem.go +++ b/pkg/tools/builtin/filesystem.go @@ -391,7 +391,7 @@ func (t *FilesystemTool) handleDirectoryTree(ctx context.Context, args Directory return nil } - tree, err := fsx.DirectoryTreeWithContext(ctx, resolvedPath, isPathAllowed, t.shouldIgnorePath, maxFiles) + tree, err := fsx.DirectoryTree(ctx, resolvedPath, isPathAllowed, t.shouldIgnorePath, maxFiles) if err != nil { return tools.ResultError(fmt.Sprintf("Error building directory tree: %s", err)), nil }