Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions cmd/internal/migrations/v3/ast_helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package v3

import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"path"
)

// parseGoFile parses Go source content into an AST. It returns the parsed file
// or an error if the content cannot be parsed.
func parseGoFile(content string) (*ast.File, error) {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "", content, parser.ParseComments)
if err != nil {
return nil, fmt.Errorf("parse Go file: %w", err)
}
return file, nil
}

// collectImportAliases finds all import aliases for the given import path within
// the provided file. The default alias derived from the path basename is also
// included when the import does not specify an explicit name.
func collectImportAliases(file *ast.File, importPath string) map[string]struct{} {
aliases := make(map[string]struct{})

for _, imp := range file.Imports {
if imp.Path == nil || imp.Path.Value == "" {
continue
}

if imp.Path.Value != "\""+importPath+"\"" {
continue
}

if imp.Name != nil {
Comment thread
ReneWerner87 marked this conversation as resolved.
if imp.Name.Name == "_" || imp.Name.Name == "." {
continue
}

aliases[imp.Name.Name] = struct{}{}
continue
}

aliases[path.Base(importPath)] = struct{}{}
}

return aliases
}

// collectAssignedCallIdents walks assignment statements and collects identifier
// names that are assigned the result of a call expression matching the provided
// predicate.
func collectAssignedCallIdents(file *ast.File, predicate func(*ast.CallExpr) bool) map[string]struct{} {
matches := make(map[string]struct{})

ast.Inspect(file, func(n ast.Node) bool {
assign, ok := n.(*ast.AssignStmt)
if !ok {
return true
}

if len(assign.Rhs) == 1 {
// Capture all identifiers from a single call returning multiple values.
if call, ok := assign.Rhs[0].(*ast.CallExpr); ok && predicate(call) {
for _, lhs := range assign.Lhs {
if ident, ok := lhs.(*ast.Ident); ok && ident.Name != "_" {
matches[ident.Name] = struct{}{}
}
}
}
} else {
// Map each call on the RHS to its corresponding identifier on the LHS.
for idx, rhs := range assign.Rhs {
call, ok := rhs.(*ast.CallExpr)
if !ok || !predicate(call) {
continue
}

if idx < len(assign.Lhs) {
if ident, ok := assign.Lhs[idx].(*ast.Ident); ok && ident.Name != "_" {
matches[ident.Name] = struct{}{}
}
}
}
}

return true
})

return matches
}
101 changes: 101 additions & 0 deletions cmd/internal/migrations/v3/ast_helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package v3

import (
"go/ast"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_parseGoFile_InvalidContent(t *testing.T) {
t.Parallel()

_, err := parseGoFile("package main\n func")
assert.Error(t, err)
}

func Test_collectImportAliases(t *testing.T) {
t.Parallel()

tests := map[string]struct { //nolint:govet // fieldalignment warning is not relevant for test data shapes
content string
importPath string
expected map[string]struct{}
}{
"default alias": {
importPath: "github.com/gofiber/fiber/v3/middleware/session",
content: "package main\nimport \"github.com/gofiber/fiber/v3/middleware/session\"\n",
expected: map[string]struct{}{"session": {}},
},
"explicit alias": {
importPath: "github.com/gofiber/fiber/v3/middleware/session",
content: "package main\nimport sess \"github.com/gofiber/fiber/v3/middleware/session\"\n",
expected: map[string]struct{}{"sess": {}},
},
"blank import ignored": {
importPath: "github.com/gofiber/fiber/v3/middleware/session",
content: "package main\nimport _ \"github.com/gofiber/fiber/v3/middleware/session\"\n",
expected: map[string]struct{}{},
},
"dot import ignored": {
importPath: "github.com/gofiber/fiber/v3/middleware/session",
content: "package main\nimport . \"github.com/gofiber/fiber/v3/middleware/session\"\n",
expected: map[string]struct{}{},
},
"unrelated import": {
importPath: "github.com/gofiber/fiber/v3/middleware/session",
content: "package main\nimport \"github.com/example/other\"\n",
expected: map[string]struct{}{},
},
}

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
t.Parallel()

file, err := parseGoFile(tt.content)
require.NoError(t, err)

aliases := collectImportAliases(file, tt.importPath)
assert.Equal(t, tt.expected, aliases)
})
}
}

func Test_collectAssignedCallIdents(t *testing.T) {
t.Parallel()

content := `package main

func target() (int, error) { return 0, nil }
func other() int { return 1 }

func main() {
primary, secondary := target()
single := target()
_, captured := target()
value, err := other()
first, second := other(), target()
field.Name = target()
}
`

file, err := parseGoFile(content)
require.NoError(t, err)

matches := collectAssignedCallIdents(file, func(call *ast.CallExpr) bool {
if ident, ok := call.Fun.(*ast.Ident); ok {
return ident.Name == "target"
}
return false
})

assert.Contains(t, matches, "primary")
assert.Contains(t, matches, "secondary")
assert.Contains(t, matches, "single")
Comment thread
ReneWerner87 marked this conversation as resolved.
assert.Contains(t, matches, "captured")
assert.Contains(t, matches, "second")
assert.NotContains(t, matches, "value")
assert.NotContains(t, matches, "first")
}
46 changes: 43 additions & 3 deletions cmd/internal/migrations/v3/session_release.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package v3

import (
"fmt"
"go/ast"
"regexp"
"strings"

Expand All @@ -18,13 +19,47 @@ const releaseComment = "// Important: Manual cleanup required"
// This is required in v3 for manual session lifecycle management.
func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version) error {
// Match patterns like:
// sess, err := store.Get(c)
// sess, err := store.GetByID(ctx, sessionID)
// session, err := myStore.Get(c)
//
// sess, err := store.Get(c)
// sess, err := store.GetByID(ctx, sessionID)
// session, err := myStore.Get(c)
//
// Capture: variable name, store variable name, method call
reStoreGet := regexp.MustCompile(`(?m)^(\s*)(\w+),\s*(\w+)\s*:=\s*(\w+)\.(Get(?:ByID)?)\(`)

changed, err := internal.ChangeFileContent(cwd, func(content string) string {
file, err := parseGoFile(content)
if err != nil {
return content
}

sessionAliases := collectImportAliases(file, "github.com/gofiber/fiber/v3/middleware/session")
if len(sessionAliases) == 0 {
return content
}

storeVars := collectAssignedCallIdents(file, func(call *ast.CallExpr) bool {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return false
}

pkgIdent, ok := sel.X.(*ast.Ident)
if !ok {
return false
}

if _, ok = sessionAliases[pkgIdent.Name]; !ok {
return false
}

return sel.Sel.Name == "New"
})

if len(storeVars) == 0 {
return content
}

lines := strings.Split(content, "\n")
result := make([]string, 0, len(lines))

Expand All @@ -38,6 +73,11 @@ func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version)
continue
}

storeVar := matches[4]
if _, ok := storeVars[storeVar]; !ok {
continue
}

indent := matches[1]
sessVar := matches[2]
errVar := matches[3]
Expand Down
84 changes: 80 additions & 4 deletions cmd/internal/migrations/v3/session_release_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func handler(c fiber.Ctx) error {
store := session.NewStore()
store := session.New()
sess, err := store.Get(c)
if err != nil {
return err
Expand Down Expand Up @@ -66,7 +66,7 @@ import (
)

func handler(c fiber.Ctx) error {
store := session.NewStore()
store := session.New()
sess, err := store.Get(c)
if err != nil {
return err
Expand Down Expand Up @@ -110,7 +110,7 @@ import (
)

func backgroundTask(sessionID string) {
store := session.NewStore()
store := session.New()
sess, err := store.GetByID(context.Background(), sessionID)
if err != nil {
return
Expand Down Expand Up @@ -150,7 +150,7 @@ import (
)

func handler(c fiber.Ctx) error {
store := session.NewStore()
store := session.New()
sess, err := store.Get(c)
if err != nil {
c.Status(500)
Expand Down Expand Up @@ -180,3 +180,79 @@ func handler(c fiber.Ctx) error {
errorBlockEnd := strings.Index(result, "}")
assert.Greater(t, deferIdx, errorBlockEnd, "defer should come after error block")
}

func Test_MigrateSessionRelease_IgnoresNonSessionStores(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "msessionrelease")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

content := `package main

import (
"github.com/other/module/session"
)

func handler() {
store := session.New()
obj, err := store.Get()
if err != nil {
panic(err)
}

_ = obj.Release
}`

err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600)
require.NoError(t, err)

cmd := &cobra.Command{}
err = MigrateSessionRelease(cmd, dir, nil, nil)
require.NoError(t, err)

data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304
require.NoError(t, err)

result := string(data)
assert.NotContains(t, result, "defer obj.Release()")
}

func Test_MigrateSessionRelease_HandlesAliasedImports(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "msessionrelease")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

content := `package main

import (
"github.com/gofiber/fiber/v3"
alias "github.com/gofiber/fiber/v3/middleware/session"
)

func handler(c fiber.Ctx) error {
store := alias.New()
session, err := store.Get(c)
if err != nil {
return err
}

session.Set("key", "value")
return session.Save()
}`

err = os.WriteFile(filepath.Join(dir, "main.go"), []byte(content), 0o600)
require.NoError(t, err)

cmd := &cobra.Command{}
err = MigrateSessionRelease(cmd, dir, nil, nil)
require.NoError(t, err)

data, err := os.ReadFile(filepath.Join(dir, "main.go")) // #nosec G304
require.NoError(t, err)

result := string(data)
assert.Contains(t, result, "defer session.Release() // Important: Manual cleanup required")
}
Loading