diff --git a/cmd/internal/migrations/v3/ast_helpers.go b/cmd/internal/migrations/v3/ast_helpers.go new file mode 100644 index 0000000..f603e2d --- /dev/null +++ b/cmd/internal/migrations/v3/ast_helpers.go @@ -0,0 +1,93 @@ +package v3 + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "path" +) + +// parseGoFile parses Go source content into an AST. It returns the parsed file +// or an error if the content cannot be parsed. +func parseGoFile(content string) (*ast.File, error) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "", content, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("parse Go file: %w", err) + } + return file, nil +} + +// collectImportAliases finds all import aliases for the given import path within +// the provided file. The default alias derived from the path basename is also +// included when the import does not specify an explicit name. +func collectImportAliases(file *ast.File, importPath string) map[string]struct{} { + aliases := make(map[string]struct{}) + + for _, imp := range file.Imports { + if imp.Path == nil || imp.Path.Value == "" { + continue + } + + if imp.Path.Value != "\""+importPath+"\"" { + continue + } + + if imp.Name != nil { + if imp.Name.Name == "_" || imp.Name.Name == "." { + continue + } + + aliases[imp.Name.Name] = struct{}{} + continue + } + + aliases[path.Base(importPath)] = struct{}{} + } + + return aliases +} + +// collectAssignedCallIdents walks assignment statements and collects identifier +// names that are assigned the result of a call expression matching the provided +// predicate. +func collectAssignedCallIdents(file *ast.File, predicate func(*ast.CallExpr) bool) map[string]struct{} { + matches := make(map[string]struct{}) + + ast.Inspect(file, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + if len(assign.Rhs) == 1 { + // Capture all identifiers from a single call returning multiple values. + if call, ok := assign.Rhs[0].(*ast.CallExpr); ok && predicate(call) { + for _, lhs := range assign.Lhs { + if ident, ok := lhs.(*ast.Ident); ok && ident.Name != "_" { + matches[ident.Name] = struct{}{} + } + } + } + } else { + // Map each call on the RHS to its corresponding identifier on the LHS. + for idx, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok || !predicate(call) { + continue + } + + if idx < len(assign.Lhs) { + if ident, ok := assign.Lhs[idx].(*ast.Ident); ok && ident.Name != "_" { + matches[ident.Name] = struct{}{} + } + } + } + } + + return true + }) + + return matches +} diff --git a/cmd/internal/migrations/v3/ast_helpers_test.go b/cmd/internal/migrations/v3/ast_helpers_test.go new file mode 100644 index 0000000..49dc94f --- /dev/null +++ b/cmd/internal/migrations/v3/ast_helpers_test.go @@ -0,0 +1,101 @@ +package v3 + +import ( + "go/ast" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_parseGoFile_InvalidContent(t *testing.T) { + t.Parallel() + + _, err := parseGoFile("package main\n func") + assert.Error(t, err) +} + +func Test_collectImportAliases(t *testing.T) { + t.Parallel() + + tests := map[string]struct { //nolint:govet // fieldalignment warning is not relevant for test data shapes + content string + importPath string + expected map[string]struct{} + }{ + "default alias": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport \"github.com/gofiber/fiber/v3/middleware/session\"\n", + expected: map[string]struct{}{"session": {}}, + }, + "explicit alias": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport sess \"github.com/gofiber/fiber/v3/middleware/session\"\n", + expected: map[string]struct{}{"sess": {}}, + }, + "blank import ignored": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport _ \"github.com/gofiber/fiber/v3/middleware/session\"\n", + expected: map[string]struct{}{}, + }, + "dot import ignored": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport . \"github.com/gofiber/fiber/v3/middleware/session\"\n", + expected: map[string]struct{}{}, + }, + "unrelated import": { + importPath: "github.com/gofiber/fiber/v3/middleware/session", + content: "package main\nimport \"github.com/example/other\"\n", + expected: map[string]struct{}{}, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + file, err := parseGoFile(tt.content) + require.NoError(t, err) + + aliases := collectImportAliases(file, tt.importPath) + assert.Equal(t, tt.expected, aliases) + }) + } +} + +func Test_collectAssignedCallIdents(t *testing.T) { + t.Parallel() + + content := `package main + +func target() (int, error) { return 0, nil } +func other() int { return 1 } + +func main() { + primary, secondary := target() + single := target() + _, captured := target() + value, err := other() + first, second := other(), target() + field.Name = target() +} +` + + file, err := parseGoFile(content) + require.NoError(t, err) + + matches := collectAssignedCallIdents(file, func(call *ast.CallExpr) bool { + if ident, ok := call.Fun.(*ast.Ident); ok { + return ident.Name == "target" + } + return false + }) + + assert.Contains(t, matches, "primary") + assert.Contains(t, matches, "secondary") + assert.Contains(t, matches, "single") + assert.Contains(t, matches, "captured") + assert.Contains(t, matches, "second") + assert.NotContains(t, matches, "value") + assert.NotContains(t, matches, "first") +} diff --git a/cmd/internal/migrations/v3/session_release.go b/cmd/internal/migrations/v3/session_release.go index da89415..064779f 100644 --- a/cmd/internal/migrations/v3/session_release.go +++ b/cmd/internal/migrations/v3/session_release.go @@ -2,6 +2,7 @@ package v3 import ( "fmt" + "go/ast" "regexp" "strings" @@ -18,13 +19,47 @@ const releaseComment = "// Important: Manual cleanup required" // This is required in v3 for manual session lifecycle management. func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { // Match patterns like: - // sess, err := store.Get(c) - // sess, err := store.GetByID(ctx, sessionID) - // session, err := myStore.Get(c) + // + // sess, err := store.Get(c) + // sess, err := store.GetByID(ctx, sessionID) + // session, err := myStore.Get(c) + // // Capture: variable name, store variable name, method call reStoreGet := regexp.MustCompile(`(?m)^(\s*)(\w+),\s*(\w+)\s*:=\s*(\w+)\.(Get(?:ByID)?)\(`) changed, err := internal.ChangeFileContent(cwd, func(content string) string { + file, err := parseGoFile(content) + if err != nil { + return content + } + + sessionAliases := collectImportAliases(file, "github.com/gofiber/fiber/v3/middleware/session") + if len(sessionAliases) == 0 { + return content + } + + storeVars := collectAssignedCallIdents(file, func(call *ast.CallExpr) bool { + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return false + } + + pkgIdent, ok := sel.X.(*ast.Ident) + if !ok { + return false + } + + if _, ok = sessionAliases[pkgIdent.Name]; !ok { + return false + } + + return sel.Sel.Name == "New" + }) + + if len(storeVars) == 0 { + return content + } + lines := strings.Split(content, "\n") result := make([]string, 0, len(lines)) @@ -38,6 +73,11 @@ func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version) continue } + storeVar := matches[4] + if _, ok := storeVars[storeVar]; !ok { + continue + } + indent := matches[1] sessVar := matches[2] errVar := matches[3] diff --git a/cmd/internal/migrations/v3/session_release_test.go b/cmd/internal/migrations/v3/session_release_test.go index 7fa8862..dbee525 100644 --- a/cmd/internal/migrations/v3/session_release_test.go +++ b/cmd/internal/migrations/v3/session_release_test.go @@ -26,7 +26,7 @@ import ( ) func handler(c fiber.Ctx) error { - store := session.NewStore() + store := session.New() sess, err := store.Get(c) if err != nil { return err @@ -66,7 +66,7 @@ import ( ) func handler(c fiber.Ctx) error { - store := session.NewStore() + store := session.New() sess, err := store.Get(c) if err != nil { return err @@ -110,7 +110,7 @@ import ( ) func backgroundTask(sessionID string) { - store := session.NewStore() + store := session.New() sess, err := store.GetByID(context.Background(), sessionID) if err != nil { return @@ -150,7 +150,7 @@ import ( ) func handler(c fiber.Ctx) error { - store := session.NewStore() + store := session.New() sess, err := store.Get(c) if err != nil { c.Status(500) @@ -180,3 +180,79 @@ func handler(c fiber.Ctx) error { errorBlockEnd := strings.Index(result, "}") assert.Greater(t, deferIdx, errorBlockEnd, "defer should come after error block") } + +func Test_MigrateSessionRelease_IgnoresNonSessionStores(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "msessionrelease") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + content := `package main + +import ( + "github.com/other/module/session" +) + +func handler() { + store := session.New() + obj, err := store.Get() + if err != nil { + panic(err) + } + + _ = obj.Release +}` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + assert.NotContains(t, result, "defer obj.Release()") +} + +func Test_MigrateSessionRelease_HandlesAliasedImports(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "msessionrelease") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + content := `package main + +import ( + "github.com/gofiber/fiber/v3" + alias "github.com/gofiber/fiber/v3/middleware/session" +) + +func handler(c fiber.Ctx) error { + store := alias.New() + session, err := store.Get(c) + if err != nil { + return err + } + + session.Set("key", "value") + return session.Save() +}` + + err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600) + require.NoError(t, err) + + cmd := &cobra.Command{} + err = MigrateSessionRelease(cmd, dir, nil, nil) + require.NoError(t, err) + + data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304 + require.NoError(t, err) + + result := string(data) + assert.Contains(t, result, "defer session.Release() // Important: Manual cleanup required") +}