From b1e9bfed5afa5e4b5e50921dd7c95300feadc6b6 Mon Sep 17 00:00:00 2001 From: RW Date: Wed, 3 Sep 2025 14:12:39 +0200 Subject: [PATCH] migrate: add multi-run test for context migration --- cmd/internal/migrations/v3/context_methods.go | 113 +++++++++--------- .../migrations/v3/context_methods_test.go | 52 +++++++- 2 files changed, 102 insertions(+), 63 deletions(-) diff --git a/cmd/internal/migrations/v3/context_methods.go b/cmd/internal/migrations/v3/context_methods.go index e265ca5..ad793eb 100644 --- a/cmd/internal/migrations/v3/context_methods.go +++ b/cmd/internal/migrations/v3/context_methods.go @@ -20,7 +20,51 @@ func MigrateContextMethods(cmd *cobra.Command, cwd string, _, _ *semver.Version) changed, err := internal.ChangeFileContent(cwd, func(content string) string { orig := content - // UserContext() removed - Ctx implements context.Context + // old Context() returned fasthttp.RequestCtx + if !strings.Contains(orig, ".SetContext(") { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "", content, parser.ParseComments) + if err == nil { + modified := false + baseIdent := func(expr ast.Expr) *ast.Ident { + for { + switch e := expr.(type) { + case *ast.Ident: + return e + case *ast.SelectorExpr: + expr = e.X + case *ast.CallExpr: + expr = e.Fun + default: + return nil + } + } + } + ast.Inspect(file, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok || sel.Sel.Name != "Context" || len(call.Args) != 0 { + return true + } + if ident := baseIdent(sel.X); ident != nil && isFiberCtx(orig, ident.Name) { + sel.Sel.Name = "RequestCtx" + modified = true + } + return true + }) + if modified { + var buf bytes.Buffer + if err := format.Node(&buf, fset, file); err == nil { + content = buf.String() + } + } + } + } + + // UserContext() replaced by Context() reUserCtx := regexp.MustCompile(`(\w+)\.UserContext\(\)`) content = reUserCtx.ReplaceAllStringFunc(content, func(match string) string { parts := reUserCtx.FindStringSubmatch(match) @@ -29,68 +73,23 @@ func MigrateContextMethods(cmd *cobra.Command, cwd string, _, _ *semver.Version) } ident := parts[1] if isFiberCtx(orig, ident) { - return ident + return ident + ".Context()" } return match }) - // old Context() returned fasthttp.RequestCtx - fset := token.NewFileSet() - file, err := parser.ParseFile(fset, "", content, parser.ParseComments) - if err == nil { - modified := false - baseIdent := func(expr ast.Expr) *ast.Ident { - for { - switch e := expr.(type) { - case *ast.Ident: - return e - case *ast.SelectorExpr: - expr = e.X - case *ast.CallExpr: - expr = e.Fun - default: - return nil - } - } - } - ast.Inspect(file, func(n ast.Node) bool { - call, ok := n.(*ast.CallExpr) - if !ok { - return true - } - sel, ok := call.Fun.(*ast.SelectorExpr) - if !ok || sel.Sel.Name != "Context" || len(call.Args) != 0 { - return true - } - if ident := baseIdent(sel.X); ident != nil && isFiberCtx(orig, ident.Name) { - sel.Sel.Name = "RequestCtx" - modified = true - } - return true - }) - if modified { - var buf bytes.Buffer - if err := format.Node(&buf, fset, file); err == nil { - content = buf.String() - } - } - } - - // SetUserContext removed - comment out the call - reSetUserCtx := regexp.MustCompile(`(?m)^(\s*)(.*?\b(\w+)\.SetUserContext\([^\n]*\).*)$`) - content = reSetUserCtx.ReplaceAllStringFunc(content, func(line string) string { - if strings.Contains(line, "TODO: SetUserContext was removed") { - return line - } - parts := reSetUserCtx.FindStringSubmatch(line) - if len(parts) != 4 { - return line + // SetUserContext renamed to SetContext + reSetUserCtx := regexp.MustCompile(`(\w+)\.SetUserContext\(`) + content = reSetUserCtx.ReplaceAllStringFunc(content, func(match string) string { + parts := reSetUserCtx.FindStringSubmatch(match) + if len(parts) != 2 { + return match } - ident := parts[3] - if !isFiberCtx(orig, ident) { - return line + ident := parts[1] + if isFiberCtx(orig, ident) { + return ident + ".SetContext(" } - return fmt.Sprintf("%s// TODO: SetUserContext was removed, please migrate manually: %s", parts[1], parts[2]) + return match }) return content diff --git a/cmd/internal/migrations/v3/context_methods_test.go b/cmd/internal/migrations/v3/context_methods_test.go index de7faa5..40723fa 100644 --- a/cmd/internal/migrations/v3/context_methods_test.go +++ b/cmd/internal/migrations/v3/context_methods_test.go @@ -35,10 +35,11 @@ func handler(c fiber.Ctx) error { require.NoError(t, v3.MigrateContextMethods(cmd, dir, nil, nil)) content := readFile(t, file) - assert.Contains(t, content, ".RequestCtx()") - assert.NotContains(t, content, ".Context()") - assert.Contains(t, content, `// TODO: SetUserContext was removed, please migrate manually: c.SetUserContext(ctx)`) - assert.Contains(t, content, "uc := c") + assert.Contains(t, content, "ctx := c.RequestCtx()") + assert.Contains(t, content, "uc := c.Context()") + assert.Contains(t, content, "c.SetContext(ctx)") + assert.NotContains(t, content, ".UserContext()") + assert.NotContains(t, content, "SetUserContext") assert.Contains(t, buf.String(), "Migrating context methods") } @@ -62,8 +63,9 @@ func handler(ctx fiber.Ctx) error { require.NoError(t, v3.MigrateContextMethods(cmd, dir, nil, nil)) content := readFile(t, file) - assert.Contains(t, content, `// TODO: SetUserContext was removed, please migrate manually: res := ctx.SetUserContext(ctx.RequestCtx())`) + assert.Contains(t, content, `res := ctx.SetContext(ctx.RequestCtx())`) assert.NotContains(t, content, `.UserContext()`) + assert.NotContains(t, content, "SetUserContext") assert.Contains(t, buf.String(), "Migrating context methods") } @@ -88,7 +90,45 @@ func handler(c fiber.Ctx) error { require.NoError(t, v3.MigrateContextMethods(cmd, dir, nil, nil)) second := readFile(t, file) assert.Equal(t, first, second) - assert.Equal(t, 1, strings.Count(second, "TODO: SetUserContext was removed")) + assert.Equal(t, 1, strings.Count(second, "SetContext(")) +} + +func Test_MigrateContextMethods_MultipleRuns(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mcmtestmulti2") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import "github.com/gofiber/fiber/v2" +func handler(c fiber.Ctx) error { + ctx := c.Context() + uc := c.UserContext() + c.SetUserContext(ctx) + c.SetUserContext(c.Context()) + _ = uc + return nil +}`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateContextMethods(cmd, dir, nil, nil)) + first := readFile(t, file) + + require.Contains(t, first, "ctx := c.RequestCtx()") + require.Contains(t, first, "uc := c.Context()") + require.Contains(t, first, "c.SetContext(ctx)") + require.Contains(t, first, "c.SetContext(c.RequestCtx())") + require.NotContains(t, first, ".UserContext()") + require.NotContains(t, first, "SetUserContext") + + require.NoError(t, v3.MigrateContextMethods(cmd, dir, nil, nil)) + second := readFile(t, file) + assert.Equal(t, first, second) + assert.Equal(t, 1, strings.Count(second, "uc := c.Context()")) + assert.Equal(t, 2, strings.Count(second, ".RequestCtx()")) + assert.Equal(t, 2, strings.Count(second, "SetContext(")) } func Test_MigrateContextMethods_SkipNonFiber(t *testing.T) {