From 080f0b81c667438c13940ec21f640634e31dfca6 Mon Sep 17 00:00:00 2001 From: RW Date: Sun, 7 Dec 2025 16:20:26 +0100 Subject: [PATCH 1/3] Improve session release migration targeting --- cmd/internal/migrations/v3/ast_helpers.go | 78 +++++++++++++++++++ cmd/internal/migrations/v3/session_release.go | 46 ++++++++++- .../migrations/v3/session_release_test.go | 76 ++++++++++++++++++ 3 files changed, 197 insertions(+), 3 deletions(-) create mode 100644 cmd/internal/migrations/v3/ast_helpers.go diff --git a/cmd/internal/migrations/v3/ast_helpers.go b/cmd/internal/migrations/v3/ast_helpers.go new file mode 100644 index 0000000..5c70502 --- /dev/null +++ b/cmd/internal/migrations/v3/ast_helpers.go @@ -0,0 +1,78 @@ +package v3 + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "path" +) + +// parseGoFile parses Go source content into an AST. It returns the parsed file +// and token.FileSet or an error if the content cannot be parsed. +func parseGoFile(content string) (*ast.File, *token.FileSet, error) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "", content, parser.ParseComments) + if err != nil { + return nil, nil, fmt.Errorf("parse Go file: %w", err) + } + return file, fset, nil +} + +// collectImportAliases finds all import aliases for the given import path within +// the provided file. The default alias derived from the path basename is also +// included when the import does not specify an explicit name. +func collectImportAliases(file *ast.File, importPath string) map[string]struct{} { + aliases := make(map[string]struct{}) + + for _, imp := range file.Imports { + if imp.Path == nil || imp.Path.Value == "" { + continue + } + + if imp.Path.Value != "\""+importPath+"\"" { + continue + } + + if imp.Name != nil { + aliases[imp.Name.Name] = struct{}{} + continue + } + + aliases[path.Base(importPath)] = struct{}{} + } + + return aliases +} + +// collectAssignedCallIdents walks assignment statements and collects identifier +// names that are assigned the result of a call expression matching the provided +// predicate. +func collectAssignedCallIdents(file *ast.File, predicate func(*ast.CallExpr) bool) map[string]struct{} { + matches := make(map[string]struct{}) + + ast.Inspect(file, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok || len(assign.Lhs) != len(assign.Rhs) { + return true + } + + for idx, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok || !predicate(call) { + continue + } + + ident, ok := assign.Lhs[idx].(*ast.Ident) + if !ok { + continue + } + + matches[ident.Name] = struct{}{} + } + + return true + }) + + return matches +} diff --git a/cmd/internal/migrations/v3/session_release.go b/cmd/internal/migrations/v3/session_release.go index da89415..3375b46 100644 --- a/cmd/internal/migrations/v3/session_release.go +++ b/cmd/internal/migrations/v3/session_release.go @@ -2,6 +2,7 @@ package v3 import ( "fmt" + "go/ast" "regexp" "strings" @@ -18,13 +19,47 @@ const releaseComment = "// Important: Manual cleanup required" // This is required in v3 for manual session lifecycle management. func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { // Match patterns like: - // sess, err := store.Get(c) - // sess, err := store.GetByID(ctx, sessionID) - // session, err := myStore.Get(c) + // + // sess, err := store.Get(c) + // sess, err := store.GetByID(ctx, sessionID) + // session, err := myStore.Get(c) + // // Capture: variable name, store variable name, method call reStoreGet := regexp.MustCompile(`(?m)^(\s*)(\w+),\s*(\w+)\s*:=\s*(\w+)\.(Get(?:ByID)?)\(`) changed, err := internal.ChangeFileContent(cwd, func(content string) string { + file, _, err := parseGoFile(content) + if err != nil { + return content + } + + sessionAliases := collectImportAliases(file, "github.com/gofiber/fiber/v3/middleware/session") + if len(sessionAliases) == 0 { + return content + } + + storeVars := collectAssignedCallIdents(file, func(call *ast.CallExpr) bool { + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return false + } + + pkgIdent, ok := sel.X.(*ast.Ident) + if !ok { + return false + } + + if _, ok = sessionAliases[pkgIdent.Name]; !ok { + return false + } + + return sel.Sel.Name == "New" || strings.HasPrefix(sel.Sel.Name, "NewStore") + }) + + if len(storeVars) == 0 { + return content + } + lines := strings.Split(content, "\n") result := make([]string, 0, len(lines)) @@ -38,6 +73,11 @@ func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version) continue } + storeVar := matches[4] + if _, ok := storeVars[storeVar]; !ok { + continue + } + indent := matches[1] sessVar := matches[2] errVar := matches[3] diff --git a/cmd/internal/migrations/v3/session_release_test.go b/cmd/internal/migrations/v3/session_release_test.go index 7fa8862..a5c9408 100644 --- a/cmd/internal/migrations/v3/session_release_test.go +++ b/cmd/internal/migrations/v3/session_release_test.go @@ -180,3 +180,79 @@ func handler(c fiber.Ctx) error { errorBlockEnd := strings.Index(result, "}") assert.Greater(t, deferIdx, errorBlockEnd, "defer should come after error block") } + +func Test_MigrateSessionRelease_IgnoresNonSessionStores(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "msessionrelease") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + content := `package main + +import ( + "github.com/other/module/session" +) + +func handler() { + store := session.New() + obj, err := store.Get() + if err != nil { + panic(err) + } + + _ = obj.Release +}` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + assert.NotContains(t, result, "defer obj.Release()") +} + +func Test_MigrateSessionRelease_HandlesAliasedImports(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "msessionrelease") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + alias "github.com/gofiber/fiber/v3/middleware/session" +) + +func handler(c fiber.Ctx) error { + store := alias.NewStore() + session, err := store.Get(c) + if err != nil { + return err + } + + session.Set("key", "value") + return session.Save() +}` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + assert.Contains(t, result, "defer session.Release() // Important: Manual cleanup required") +} From d9bd671b7c3433d1e839764597a4fa3b2f534f64 Mon Sep 17 00:00:00 2001 From: RW Date: Sun, 7 Dec 2025 19:35:27 +0100 Subject: [PATCH 2/3] Harden session migration helpers --- cmd/internal/migrations/v3/ast_helpers.go | 14 ++- .../migrations/v3/ast_helpers_test.go | 97 +++++++++++++++++++ cmd/internal/migrations/v3/session_release.go | 4 +- .../migrations/v3/session_release_test.go | 10 +- 4 files changed, 113 insertions(+), 12 deletions(-) create mode 100644 cmd/internal/migrations/v3/ast_helpers_test.go diff --git a/cmd/internal/migrations/v3/ast_helpers.go b/cmd/internal/migrations/v3/ast_helpers.go index 5c70502..07a129d 100644 --- a/cmd/internal/migrations/v3/ast_helpers.go +++ b/cmd/internal/migrations/v3/ast_helpers.go @@ -9,14 +9,14 @@ import ( ) // parseGoFile parses Go source content into an AST. It returns the parsed file -// and token.FileSet or an error if the content cannot be parsed. -func parseGoFile(content string) (*ast.File, *token.FileSet, error) { +// or an error if the content cannot be parsed. +func parseGoFile(content string) (*ast.File, error) { fset := token.NewFileSet() file, err := parser.ParseFile(fset, "", content, parser.ParseComments) if err != nil { - return nil, nil, fmt.Errorf("parse Go file: %w", err) + return nil, fmt.Errorf("parse Go file: %w", err) } - return file, fset, nil + return file, nil } // collectImportAliases finds all import aliases for the given import path within @@ -35,6 +35,10 @@ func collectImportAliases(file *ast.File, importPath string) map[string]struct{} } if imp.Name != nil { + if imp.Name.Name == "_" || imp.Name.Name == "." { + continue + } + aliases[imp.Name.Name] = struct{}{} continue } @@ -53,7 +57,7 @@ func collectAssignedCallIdents(file *ast.File, predicate func(*ast.CallExpr) boo ast.Inspect(file, func(n ast.Node) bool { assign, ok := n.(*ast.AssignStmt) - if !ok || len(assign.Lhs) != len(assign.Rhs) { + if !ok { return true } diff --git a/cmd/internal/migrations/v3/ast_helpers_test.go b/cmd/internal/migrations/v3/ast_helpers_test.go new file mode 100644 index 0000000..e35eab4 --- /dev/null +++ b/cmd/internal/migrations/v3/ast_helpers_test.go @@ -0,0 +1,97 @@ +package v3 + +import ( + "go/ast" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_parseGoFile_InvalidContent(t *testing.T) { + t.Parallel() + + _, err := parseGoFile("package main\n func") + assert.Error(t, err) +} + +func Test_collectImportAliases(t *testing.T) { + t.Parallel() + + tests := map[string]struct { //nolint:govet // fieldalignment warning is not relevant for test data shapes + content string + importPath string + expected map[string]struct{} + }{ + "default alias": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport \"github.com/gofiber/fiber/v3/middleware/session\"\n", + expected: map[string]struct{}{"session": {}}, + }, + "explicit alias": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport sess \"github.com/gofiber/fiber/v3/middleware/session\"\n", + expected: map[string]struct{}{"sess": {}}, + }, + "blank import ignored": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport _ \"github.com/gofiber/fiber/v3/middleware/session\"\n", + expected: map[string]struct{}{}, + }, + "dot import ignored": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport . \"github.com/gofiber/fiber/v3/middleware/session\"\n", + expected: map[string]struct{}{}, + }, + "unrelated import": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport \"github.com/example/other\"\n", + expected: map[string]struct{}{}, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + file, err := parseGoFile(tt.content) + require.NoError(t, err) + + aliases := collectImportAliases(file, tt.importPath) + assert.Equal(t, tt.expected, aliases) + }) + } +} + +func Test_collectAssignedCallIdents(t *testing.T) { + t.Parallel() + + content := `package main + +func target() (int, error) { return 0, nil } +func other() int { return 1 } + +func main() { + primary, secondary := target() + single := target() + _, ignored := target() + value, err := other() + field.Name = target() +} +` + + file, err := parseGoFile(content) + require.NoError(t, err) + + matches := collectAssignedCallIdents(file, func(call *ast.CallExpr) bool { + if ident, ok := call.Fun.(*ast.Ident); ok { + return ident.Name == "target" + } + return false + }) + + assert.Contains(t, matches, "primary") + assert.Contains(t, matches, "single") + assert.NotContains(t, matches, "value") + assert.NotContains(t, matches, "ignored") +} diff --git a/cmd/internal/migrations/v3/session_release.go b/cmd/internal/migrations/v3/session_release.go index 3375b46..064779f 100644 --- a/cmd/internal/migrations/v3/session_release.go +++ b/cmd/internal/migrations/v3/session_release.go @@ -28,7 +28,7 @@ func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version) reStoreGet := regexp.MustCompile(`(?m)^(\s*)(\w+),\s*(\w+)\s*:=\s*(\w+)\.(Get(?:ByID)?)\(`) changed, err := internal.ChangeFileContent(cwd, func(content string) string { - file, _, err := parseGoFile(content) + file, err := parseGoFile(content) if err != nil { return content } @@ -53,7 +53,7 @@ func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version) return false } - return sel.Sel.Name == "New" || strings.HasPrefix(sel.Sel.Name, "NewStore") + return sel.Sel.Name == "New" }) if len(storeVars) == 0 { diff --git a/cmd/internal/migrations/v3/session_release_test.go b/cmd/internal/migrations/v3/session_release_test.go index a5c9408..dbee525 100644 --- a/cmd/internal/migrations/v3/session_release_test.go +++ b/cmd/internal/migrations/v3/session_release_test.go @@ -26,7 +26,7 @@ import ( ) func handler(c fiber.Ctx) error { - store := session.NewStore() + store := session.New() sess, err := store.Get(c) if err != nil { return err @@ -66,7 +66,7 @@ import ( ) func handler(c fiber.Ctx) error { - store := session.NewStore() + store := session.New() sess, err := store.Get(c) if err != nil { return err @@ -110,7 +110,7 @@ import ( ) func backgroundTask(sessionID string) { - store := session.NewStore() + store := session.New() sess, err := store.GetByID(context.Background(), sessionID) if err != nil { return @@ -150,7 +150,7 @@ import ( ) func handler(c fiber.Ctx) error { - store := session.NewStore() + store := session.New() sess, err := store.Get(c) if err != nil { c.Status(500) @@ -233,7 +233,7 @@ import ( ) func handler(c fiber.Ctx) error { - store := alias.NewStore() + store := alias.New() session, err := store.Get(c) if err != nil { return err From 1b1b0e77aa3c824bf3f2cc2549f0526b5415f74d Mon Sep 17 00:00:00 2001 From: RW Date: Sun, 7 Dec 2025 20:03:37 +0100 Subject: [PATCH 3/3] Handle multi-value assigned calls in AST helper --- cmd/internal/migrations/v3/ast_helpers.go | 31 +++++++++++++------ .../migrations/v3/ast_helpers_test.go | 8 +++-- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/cmd/internal/migrations/v3/ast_helpers.go b/cmd/internal/migrations/v3/ast_helpers.go index 07a129d..f603e2d 100644 --- a/cmd/internal/migrations/v3/ast_helpers.go +++ b/cmd/internal/migrations/v3/ast_helpers.go @@ -61,18 +61,29 @@ func collectAssignedCallIdents(file *ast.File, predicate func(*ast.CallExpr) boo return true } - for idx, rhs := range assign.Rhs { - call, ok := rhs.(*ast.CallExpr) - if !ok || !predicate(call) { - continue + if len(assign.Rhs) == 1 { + // Capture all identifiers from a single call returning multiple values. + if call, ok := assign.Rhs[0].(*ast.CallExpr); ok && predicate(call) { + for _, lhs := range assign.Lhs { + if ident, ok := lhs.(*ast.Ident); ok && ident.Name != "_" { + matches[ident.Name] = struct{}{} + } + } } - - ident, ok := assign.Lhs[idx].(*ast.Ident) - if !ok { - continue + } else { + // Map each call on the RHS to its corresponding identifier on the LHS. + for idx, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok || !predicate(call) { + continue + } + + if idx < len(assign.Lhs) { + if ident, ok := assign.Lhs[idx].(*ast.Ident); ok && ident.Name != "_" { + matches[ident.Name] = struct{}{} + } + } } - - matches[ident.Name] = struct{}{} } return true diff --git a/cmd/internal/migrations/v3/ast_helpers_test.go b/cmd/internal/migrations/v3/ast_helpers_test.go index e35eab4..49dc94f 100644 --- a/cmd/internal/migrations/v3/ast_helpers_test.go +++ b/cmd/internal/migrations/v3/ast_helpers_test.go @@ -74,8 +74,9 @@ func other() int { return 1 } func main() { primary, secondary := target() single := target() - _, ignored := target() + _, captured := target() value, err := other() + first, second := other(), target() field.Name = target() } ` @@ -91,7 +92,10 @@ func main() { }) assert.Contains(t, matches, "primary") + assert.Contains(t, matches, "secondary") assert.Contains(t, matches, "single") + assert.Contains(t, matches, "captured") + assert.Contains(t, matches, "second") assert.NotContains(t, matches, "value") - assert.NotContains(t, matches, "ignored") + assert.NotContains(t, matches, "first") }