Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions pkg/fsx/collect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fsx

import (
"context"
"fmt"
"os"
"path/filepath"
Expand All @@ -13,11 +14,18 @@ import (
// Supports glob patterns (via doublestar), directories, and individual files.
// Skips paths that don't exist instead of returning an error.
// Optional shouldIgnore filter can exclude files/directories (return true to skip).
func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string, error) {
// Respects context cancellation.
func CollectFiles(ctx context.Context, paths []string, shouldIgnore func(path string) bool) ([]string, error) {
var files []string
seen := make(map[string]bool)

for _, pattern := range paths {
// Check for context cancellation
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
expanded, err := expandPattern(pattern)
if err != nil {
return nil, err
Expand All @@ -27,6 +35,12 @@ func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string
}

for _, entry := range expanded {
// Check for context cancellation
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
normalized := normalizePath(entry)

// Check if this path should be ignored
Expand All @@ -44,14 +58,21 @@ func CollectFiles(paths []string, shouldIgnore func(path string) bool) ([]string

if info.IsDir() {
// Use DirectoryTree to collect files from directory
tree, err := DirectoryTree(normalized, func(string) error { return nil }, shouldIgnore, 0)
tree, err := DirectoryTree(ctx, normalized, func(string) error { return nil }, shouldIgnore, 0)
if err != nil {
return nil, fmt.Errorf("failed to read directory %s: %w", normalized, err)
}
// Traverse tree and collect absolute file paths
var dirFiles []string
CollectFilesFromTree(tree, filepath.Dir(normalized), &dirFiles)
for _, f := range dirFiles {
// Check for context cancellation
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

absPath := normalizePath(f)
if !seen[absPath] {
files = append(files, absPath)
Expand Down
87 changes: 87 additions & 0 deletions pkg/fsx/collect_cancellation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package fsx

import (
"context"
"fmt"
"os"
"path/filepath"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCollectFiles_ContextCancellation(t *testing.T) {
t.Parallel()

tmpDir := t.TempDir()

// Create a large directory structure to ensure context cancellation has time to kick in
for i := range 100 {
subDir := filepath.Join(tmpDir, "dir", "subdir", "deepdir", fmt.Sprintf("dir%d", i))
require.NoError(t, os.MkdirAll(subDir, 0o755))
for j := range 10 {
filePath := filepath.Join(subDir, fmt.Sprintf("file%d.txt", j))
require.NoError(t, os.WriteFile(filePath, []byte("test content"), 0o644))
}
}

t.Run("respects context cancellation", func(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())

// Cancel context immediately
cancel()

_, err := CollectFiles(ctx, []string{tmpDir}, nil)
assert.ErrorIs(t, err, context.Canceled)
})

t.Run("respects context timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Nanosecond)
defer cancel()

// Give time for timeout to trigger
time.Sleep(10 * time.Millisecond)

_, err := CollectFiles(ctx, []string{tmpDir}, nil)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
}

func TestDirectoryTree_ContextCancellation(t *testing.T) {
t.Parallel()

tmpDir := t.TempDir()

// Create a large directory structure
for i := range 100 {
subDir := filepath.Join(tmpDir, "dir", "subdir", fmt.Sprintf("dir%d", i))
require.NoError(t, os.MkdirAll(subDir, 0o755))
for j := range 10 {
filePath := filepath.Join(subDir, fmt.Sprintf("file%d.txt", j))
require.NoError(t, os.WriteFile(filePath, []byte("test content"), 0o644))
}
}

t.Run("respects context cancellation", func(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())

// Cancel context immediately
cancel()

_, err := DirectoryTree(ctx, tmpDir, func(string) error { return nil }, nil, 0)
assert.ErrorIs(t, err, context.Canceled)
})

t.Run("respects context timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Nanosecond)
defer cancel()

// Give time for timeout to trigger
time.Sleep(10 * time.Millisecond)

_, err := DirectoryTree(ctx, tmpDir, func(string) error { return nil }, nil, 0)
assert.ErrorIs(t, err, context.DeadlineExceeded)
})
}
20 changes: 10 additions & 10 deletions pkg/fsx/collect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) {
t.Run("no filter collects all files", func(t *testing.T) {
t.Parallel()

got, err := CollectFiles([]string{tmpDir}, nil)
got, err := CollectFiles(t.Context(), []string{tmpDir}, nil)
require.NoError(t, err)
assert.Len(t, got, 5)
})
Expand All @@ -66,7 +66,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) {
return base == "vendor" || base == "node_modules"
}

got, err := CollectFiles([]string{tmpDir}, shouldIgnore)
got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore)
require.NoError(t, err)

// Should only have src/*.go and build/output.bin
Expand All @@ -87,7 +87,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) {
return strings.HasSuffix(path, ".bin")
}

got, err := CollectFiles([]string{tmpDir}, shouldIgnore)
got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore)
require.NoError(t, err)

assert.Len(t, got, 4)
Expand All @@ -104,7 +104,7 @@ func TestCollectFiles_WithShouldIgnoreFilter(t *testing.T) {
return strings.Contains(path, "vendor")
}

got, err := CollectFiles([]string{tmpDir}, shouldIgnore)
got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore)
require.NoError(t, err)

for _, f := range got {
Expand Down Expand Up @@ -148,7 +148,7 @@ func TestCollectFiles_GitDirectoryExclusion(t *testing.T) {
t.Run("without filter includes .git", func(t *testing.T) {
t.Parallel()

got, err := CollectFiles([]string{tmpDir}, nil)
got, err := CollectFiles(t.Context(), []string{tmpDir}, nil)
require.NoError(t, err)

// Should include .git files
Expand Down Expand Up @@ -177,7 +177,7 @@ func TestCollectFiles_GitDirectoryExclusion(t *testing.T) {
strings.HasPrefix(normalized, ".git/")
}

got, err := CollectFiles([]string{tmpDir}, shouldIgnore)
got, err := CollectFiles(t.Context(), []string{tmpDir}, shouldIgnore)
require.NoError(t, err)

// Should only have src/main.go
Expand Down Expand Up @@ -229,7 +229,7 @@ func TestCollectFiles_GlobsWithFilter(t *testing.T) {
return strings.HasSuffix(path, "_test.go")
}

got, err := CollectFiles([]string{filepath.Join(tmpDir, "pkg", "**", "*.go")}, shouldIgnore)
got, err := CollectFiles(t.Context(), []string{filepath.Join(tmpDir, "pkg", "**", "*.go")}, shouldIgnore)
require.NoError(t, err)

// Should only have non-test .go files
Expand Down Expand Up @@ -329,7 +329,7 @@ func TestCollectFiles_Deduplication(t *testing.T) {
tmpDir, // Will also include test.go
}

got, err := CollectFiles(patterns, nil)
got, err := CollectFiles(t.Context(), patterns, nil)
require.NoError(t, err)

// Should only have one entry
Expand All @@ -352,7 +352,7 @@ func TestCollectFiles_NonExistentPaths(t *testing.T) {
filepath.Join(tmpDir, "also", "missing", "file.go"),
}

got, err := CollectFiles(patterns, nil)
got, err := CollectFiles(t.Context(), patterns, nil)
require.NoError(t, err)

// Should only have the real file
Expand All @@ -371,7 +371,7 @@ func TestCollectFiles_SortedOutput(t *testing.T) {
require.NoError(t, os.WriteFile(filepath.Join(tmpDir, f), []byte("package test"), 0o644))
}

got, err := CollectFiles([]string{tmpDir}, nil)
got, err := CollectFiles(t.Context(), []string{tmpDir}, nil)
require.NoError(t, err)

// Verify we got all files
Expand Down
32 changes: 17 additions & 15 deletions pkg/fsx/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@ type TreeNode struct {
Children []*TreeNode `json:"children,omitempty"`
}

func DirectoryTree(path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) {
func DirectoryTree(ctx context.Context, path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int) (*TreeNode, error) {
itemCount := 0
return directoryTree(path, isPathAllowed, shouldIgnore, maxItems, &itemCount)
return directoryTree(ctx, path, isPathAllowed, shouldIgnore, maxItems, &itemCount)
}

func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int, itemCount *int) (*TreeNode, error) {
func directoryTree(ctx context.Context, path string, isPathAllowed func(string) error, shouldIgnore func(string) bool, maxItems int, itemCount *int) (*TreeNode, error) {
// Check for context cancellation
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
if maxItems > 0 && *itemCount >= maxItems {
return nil, nil
}
Expand All @@ -47,6 +53,13 @@ func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore f
}

for _, entry := range entries {
// Check for context cancellation
select {
case <-ctx.Done():
return node, ctx.Err()
default:
}

childPath := filepath.Join(path, entry.Name())
if err := isPathAllowed(childPath); err != nil {
continue // Skip disallowed paths
Expand All @@ -57,7 +70,7 @@ func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore f
continue
}

childNode, err := directoryTree(childPath, isPathAllowed, shouldIgnore, maxItems, itemCount)
childNode, err := directoryTree(ctx, childPath, isPathAllowed, shouldIgnore, maxItems, itemCount)
if err != nil || childNode == nil {
continue
}
Expand All @@ -68,17 +81,6 @@ func directoryTree(path string, isPathAllowed func(string) error, shouldIgnore f
return node, nil
}

// ListDirectory builds the directory tree rooted at path and flattens it into
// a list of file paths. Every path is treated as allowed, no item limit is
// applied (maxItems is 0), and shouldIgnore is forwarded to DirectoryTree
// unchanged. An empty base path is passed to CollectFilesFromTree, which
// requests relative paths.
func ListDirectory(path string, shouldIgnore func(string) bool) ([]string, error) {
	allowAll := func(string) error { return nil }

	root, err := DirectoryTree(path, allowAll, shouldIgnore, 0)
	if err != nil {
		return nil, err
	}

	var collected []string
	CollectFilesFromTree(root, "", &collected)
	return collected, nil
}

// CollectFilesFromTree recursively collects file paths from a DirectoryTree.
// Pass basePath="" for relative paths, or a parent directory for absolute paths.
func CollectFilesFromTree(node *TreeNode, basePath string, files *[]string) {
Expand Down
Loading
Loading