From 5964359ef69815882c1904b2440ac0ed595a4839 Mon Sep 17 00:00:00 2001 From: Pu Junsong Date: Mon, 15 Sep 2025 11:56:02 +0800 Subject: [PATCH] fix issue #61 #66 --- lang/golang/parser/utils.go | 81 +++++++++++++++++++++++++++- lang/golang/parser/utils_test.go | 90 ++++++++++++++++++++++++++++++++ lang/golang/writer/write.go | 21 +++++++- 3 files changed, 190 insertions(+), 2 deletions(-) diff --git a/lang/golang/parser/utils.go b/lang/golang/parser/utils.go index dedfbf7d..87edce81 100644 --- a/lang/golang/parser/utils.go +++ b/lang/golang/parser/utils.go @@ -17,14 +17,17 @@ package parser import ( "bufio" "bytes" + "container/list" "fmt" "go/ast" + "go/build" "go/types" "io" "os" "path" "regexp" "strings" + "sync" "github.com/Knetic/govaluate" . "github.com/cloudwego/abcoder/lang/uniast" @@ -49,8 +52,84 @@ func (c cache) Visited(val interface{}) bool { return ok } +type cacheEntry struct { + key string + value bool +} + +// PackageCache 缓存 importPath 是否是 system package +type PackageCache struct { + lock sync.Mutex + cache map[string]*list.Element + lru *list.List + lruCapacity int +} + +func NewPackageCache(lruCapacity int) *PackageCache { + return &PackageCache{ + cache: make(map[string]*list.Element), + lru: list.New(), + lruCapacity: lruCapacity, + } +} + +// get retrieves a value from the cache. +func (pc *PackageCache) get(key string) (bool, bool) { + pc.lock.Lock() + defer pc.lock.Unlock() + if elem, ok := pc.cache[key]; ok { + pc.lru.MoveToFront(elem) + return elem.Value.(*cacheEntry).value, true + } + return false, false +} + +// set adds a value to the cache. +func (pc *PackageCache) set(key string, value bool) { + pc.lock.Lock() + defer pc.lock.Unlock() + + if elem, ok := pc.cache[key]; ok { + pc.lru.MoveToFront(elem) + elem.Value.(*cacheEntry).value = value + return + } + + if pc.lru.Len() >= pc.lruCapacity { + oldest := pc.lru.Back() + if oldest != nil { + pc.lru.Remove(oldest) + delete(pc.cache, oldest.Value.(*cacheEntry).key) + } + } + + elem := pc.lru.PushFront(&cacheEntry{key: key, value: value}) + pc.cache[key] = elem +} + +// IsStandardPackage 检查一个包是否为标准库,并使用内部缓存。 +func (pc *PackageCache) IsStandardPackage(path string) bool { + if isStd, found := pc.get(path); found { + return isStd + } + + pkg, err := build.Import(path, "", build.FindOnly) + if err != nil { + // Cannot find the package, assume it's not a standard package + pc.set(path, false) + return false + } + + isStd := pkg.Goroot + pc.set(path, isStd) + return isStd +} + +// stdlibCache 缓存 importPath 是否是 system package, 10000 个缓存 +var stdlibCache = NewPackageCache(10000) + func isSysPkg(importPath string) bool { - return !strings.Contains(strings.Split(importPath, "/")[0], ".") + return stdlibCache.IsStandardPackage(importPath) } var ( diff --git a/lang/golang/parser/utils_test.go b/lang/golang/parser/utils_test.go index fdd7a0e2..301e6a1d 100644 --- a/lang/golang/parser/utils_test.go +++ b/lang/golang/parser/utils_test.go @@ -21,8 +21,11 @@ import ( "go/token" "go/types" "slices" + "sync" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) @@ -195,3 +198,90 @@ var f func() (*http.Request, error)`, }) } } + +func resetGlobals() { + // 重置包缓存 + stdlibCache = NewPackageCache(10000) +} + +func Test_isSysPkg(t *testing.T) { + // 测试在 `go env GOROOT` 可以成功执行时的行为 + t.Run("Group: Happy Path - GOROOT is found", func(t *testing.T) { + resetGlobals() + + testCases := []struct { + name string + importPath string + want bool + }{ + {"standard library package", "fmt", true}, + {"nested standard library package", "net/http", true}, + {"third-party package", "github.com/google/uuid", false}, + {"extended library package", "golang.org/x/sync/errgroup", false}, + {"local-like package name", "myproject/utils", false}, + {"non-existent package", "non/existent/package", false}, + {"root-level package with dot", "gopkg.in/yaml.v2", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if got := isSysPkg(tc.importPath); got != tc.want { + t.Errorf("isSysPkg(%q) = %v, want %v", tc.importPath, got, tc.want) + } + }) + } + }) + + // 测试并发调用时的行为 + t.Run("Group: Concurrency Test", func(t *testing.T) { + resetGlobals() + var wg sync.WaitGroup + numGoroutines := 50 + numOpsPerGoroutine := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < numOpsPerGoroutine; j++ { + isSysPkg("fmt") + isSysPkg("github.com/cloudwego/abcoder") + isSysPkg("net/http") + isSysPkg("a/b/c") + } + }() + } + wg.Wait() + }) + + // 测试 LRU 缓存的驱逐策略 + t.Run("Group: LRU Eviction Test", func(t *testing.T) { + resetGlobals() + stdlibCache.lruCapacity = 2 + + // 1. 填满 Cache + isSysPkg("fmt") + isSysPkg("os") + assert.Equal(t, 2, stdlibCache.lru.Len(), "Cache should be full") + + // 2. 访问 "fmt" 使它最近被使用 + isSysPkg("fmt") + assert.Equal(t, "fmt", stdlibCache.lru.Front().Value.(*cacheEntry).key, "fmt should be the most recently used") + + // 3. 访问 "net" 使它最近被使用 + isSysPkg("net") // "os" should be evicted + assert.Equal(t, 2, stdlibCache.lru.Len(), "Cache size should remain at capacity") + + // 4. "fmt" 应该在 Cache 中 + _, foundFmt := stdlibCache.get("fmt") + assert.True(t, foundFmt, "fmt should still be in the cache") + + // 5. "net" 应该在 Cache 中 + _, foundNet := stdlibCache.get("net") + assert.True(t, foundNet, "net should be in the cache") + + // 6. "os" 不应该在 Cache 中 + _, foundOs := stdlibCache.get("os") + assert.False(t, foundOs, "os should have been evicted from the cache") + }) +} diff --git a/lang/golang/writer/write.go b/lang/golang/writer/write.go index 89fd6624..77118eca 100644 --- a/lang/golang/writer/write.go +++ b/lang/golang/writer/write.go @@ -35,6 +35,7 @@ import ( ) var _ uniast.Writer = (*Writer)(nil) +var testPkgPathRegex = regexp.MustCompile(`^(.+?) \[(.+)\]$`) type Options struct { // RepoDir string @@ -81,6 +82,22 @@ func (w *Writer) WriteRepo(repo *uniast.Repository, outDir string) error { return nil } +// sanitizePkgPath sanitize the package path, remove the suffix in brackets +func sanitizePkgPath(pkgPath string) string { + matches := testPkgPathRegex.FindStringSubmatch(pkgPath) + // matches should be 3 elements: + // 1. The full string + // 2. The package name + // 3. The content inside the brackets + if len(matches) == 3 { + packageName := matches[1] + testName := matches[2] + if testName == packageName+".test" { + return packageName + } + } + return pkgPath +} func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir string) error { mod := repo.Modules[modPath] if mod == nil { @@ -94,7 +111,9 @@ func (w *Writer) WriteModule(repo *uniast.Repository, modPath string, outDir str outdir := filepath.Join(outDir, mod.Dir) for dir, pkg := range w.visited { - rel := strings.TrimPrefix(dir, mod.Name) + // sanitize the package path + cleanDir := sanitizePkgPath(dir) + rel := strings.TrimPrefix(cleanDir, mod.Name) pkgDir := filepath.Join(outdir, rel) if err := os.MkdirAll(pkgDir, 0755); err != nil { return fmt.Errorf("mkdir %s failed: %v", pkgDir, err)