Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 47 additions & 14 deletions cmd/internal/migrations/v3/context_methods.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package v3

import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"regexp"
"strings"

Expand Down Expand Up @@ -29,6 +34,48 @@ func MigrateContextMethods(cmd *cobra.Command, cwd string, _, _ *semver.Version)
return match
})

// old Context() returned fasthttp.RequestCtx
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "", content, parser.ParseComments)
if err == nil {
modified := false
baseIdent := func(expr ast.Expr) *ast.Ident {
for {
switch e := expr.(type) {
case *ast.Ident:
return e
case *ast.SelectorExpr:
expr = e.X
case *ast.CallExpr:
expr = e.Fun
default:
return nil
}
}
}
ast.Inspect(file, func(n ast.Node) bool {
call, ok := n.(*ast.CallExpr)
if !ok {
return true
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok || sel.Sel.Name != "Context" || len(call.Args) != 0 {
return true
}
if ident := baseIdent(sel.X); ident != nil && isFiberCtx(orig, ident.Name) {
sel.Sel.Name = "RequestCtx"
modified = true
}
return true
})
if modified {
var buf bytes.Buffer
if err := format.Node(&buf, fset, file); err == nil {
content = buf.String()
}
}
}

// SetUserContext removed - comment out the call
reSetUserCtx := regexp.MustCompile(`(?m)^(\s*)(.*?\b(\w+)\.SetUserContext\([^\n]*\).*)$`)
content = reSetUserCtx.ReplaceAllStringFunc(content, func(line string) string {
Expand All @@ -46,20 +93,6 @@ func MigrateContextMethods(cmd *cobra.Command, cwd string, _, _ *semver.Version)
return fmt.Sprintf("%s// TODO: SetUserContext was removed, please migrate manually: %s", parts[1], parts[2])
})

// old Context() returned fasthttp.RequestCtx
reReqCtx := regexp.MustCompile(`(\w+)\.Context\(\)`)
content = reReqCtx.ReplaceAllStringFunc(content, func(match string) string {
parts := reReqCtx.FindStringSubmatch(match)
if len(parts) != 2 {
return match
}
ident := parts[1]
if isFiberCtx(orig, ident) {
return ident + ".RequestCtx()"
}
return match
})

return content
})
if err != nil {
Expand Down
77 changes: 77 additions & 0 deletions cmd/internal/migrations/v3/context_methods_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ func Test_MigrateContextMethods_SkipNonFiber(t *testing.T) {

file := writeTempFile(t, dir, `package main
type ctx struct{}

func (ctx) UserContext() {}

func (ctx) SetUserContext(any) {}
func (ctx) Context() {}
func handler(c ctx) {
Expand All @@ -119,3 +121,78 @@ func handler(c ctx) {
assert.Contains(t, content, "c.Context()")
assert.NotContains(t, buf.String(), "Migrating context methods")
}

func Test_MigrateContextMethods_RequestCtxChained(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "mcmtestchain")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

file := writeTempFile(t, dir, `package main
import "github.com/gofiber/fiber/v2"
func handler(c fiber.Ctx) error {
c.Status(fiber.StatusOK).Context().SetBodyStreamWriter(nil)
return nil
}`)

var buf bytes.Buffer
cmd := newCmd(&buf)
require.NoError(t, v3.MigrateContextMethods(cmd, dir, nil, nil))

content := readFile(t, file)
assert.Contains(t, content, ".RequestCtx().SetBodyStreamWriter")
assert.NotContains(t, content, ".Context().SetBodyStreamWriter")
assert.Contains(t, buf.String(), "Migrating context methods")
}

func Test_MigrateContextMethods_MultipleContextCalls(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "mcmtestmulti")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

file := writeTempFile(t, dir, `package main
import "github.com/gofiber/fiber/v2"
func handler(c fiber.Ctx) error {
rc := c.Context()
c.Type("json").Status(fiber.StatusOK).Context().SetBodyStreamWriter(nil)
_ = rc
return nil
}`)

var buf bytes.Buffer
cmd := newCmd(&buf)
require.NoError(t, v3.MigrateContextMethods(cmd, dir, nil, nil))

content := readFile(t, file)
assert.Equal(t, 2, strings.Count(content, ".RequestCtx()"))
assert.NotContains(t, content, ".Context()")
assert.Contains(t, buf.String(), "Migrating context methods")
}

func Test_MigrateContextMethods_RequestCtxNested(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "mcmtestnested")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

file := writeTempFile(t, dir, `package main
import "github.com/gofiber/fiber/v2"
func handler(c fiber.Ctx) error {
c.Foo(bar(1)).Context()
c.Foo("a)b").Context()
return nil
}`)

var buf bytes.Buffer
cmd := newCmd(&buf)
require.NoError(t, v3.MigrateContextMethods(cmd, dir, nil, nil))

content := readFile(t, file)
assert.Equal(t, 2, strings.Count(content, ".RequestCtx()"))
assert.NotContains(t, content, ".Context()")
assert.Contains(t, buf.String(), "Migrating context methods")
}
Loading