diff --git a/cmd/internal/migrations/v3/middleware_locals.go b/cmd/internal/migrations/v3/middleware_locals.go index 9d43cdc..a13245f 100644 --- a/cmd/internal/migrations/v3/middleware_locals.go +++ b/cmd/internal/migrations/v3/middleware_locals.go @@ -12,29 +12,54 @@ import ( func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { changed, err := internal.ChangeFileContent(cwd, func(content string) string { - replacements := []struct { - re *regexp.Regexp - repl string + ctxMap := map[string]string{ + "requestid": "requestid.FromContext(%s)", + "csrf": "csrf.TokenFromContext(%s)", + "csrf_handler": "csrf.HandlerFromContext(%s)", + "session": "session.FromContext(%s)", + "username": "basicauth.UsernameFromContext(%s)", + "password": "basicauth.PasswordFromContext(%s)", + "token": "keyauth.TokenFromContext(%s)", + } + + extractors := []struct { + pkg string + field string + replFmt string }{ - {regexp.MustCompile(`(\w+)\.Locals\("requestid"\)`), `requestid.FromContext($1)`}, - {regexp.MustCompile(`(\w+)\.Locals\("csrf"\)`), `csrf.TokenFromContext($1)`}, - {regexp.MustCompile(`(\w+)\.Locals\("csrf_handler"\)`), `csrf.HandlerFromContext($1)`}, - {regexp.MustCompile(`(\w+)\.Locals\("session"\)`), `session.FromContext($1)`}, - {regexp.MustCompile(`(\w+)\.Locals\("username"\)`), `basicauth.UsernameFromContext($1)`}, - {regexp.MustCompile(`(\w+)\.Locals\("password"\)`), `basicauth.PasswordFromContext($1)`}, - {regexp.MustCompile(`(\w+)\.Locals\("token"\)`), `keyauth.TokenFromContext($1)`}, + {"csrf", "ContextKey", "csrf.TokenFromContext(%s)"}, + {"keyauth", "ContextKey", "keyauth.TokenFromContext(%s)"}, + {"session", "ContextKey", "session.FromContext(%s)"}, + {"basicauth", "ContextUsername", "basicauth.UsernameFromContext(%s)"}, + {"basicauth", "ContextPassword", "basicauth.PasswordFromContext(%s)"}, } - for _, r := range replacements { - content = r.re.ReplaceAllString(content, r.repl) + + for _, e := range extractors { + re := regexp.MustCompile(e.pkg + `\.Config{[^}]*` + e.field + `:\s*"([^"]+)"`) + matches := re.FindAllStringSubmatch(content, -1) + for _, m := range matches { + ctxMap[m[1]] = e.replFmt + } } + reLocals := regexp.MustCompile(`(\w+)\.Locals\("([^"]+)"\)`) + content = reLocals.ReplaceAllStringFunc(content, func(s string) string { + sub := reLocals.FindStringSubmatch(s) + ctx := sub[1] + key := sub[2] + if fmtStr, ok := ctxMap[key]; ok { + return fmt.Sprintf(fmtStr, ctx) + } + return s + }) + reTypeAssert := regexp.MustCompile(`([\w\.]+FromContext\([^\)]+\))\.\([^\)]+\)`) content = reTypeAssert.ReplaceAllString(content, "$1") reComma := regexp.MustCompile(`(\w+)\s*,\s*\w+\s*:=\s*([\w\.]+FromContext\([^\)]+\))`) content = reComma.ReplaceAllString(content, "$1 := $2") - reCtxKey := regexp.MustCompile(`\s*ContextKey:\s*[^,}\n]+,?`) + reCtxKey := regexp.MustCompile(`\s*Context(?:Username|Password|Key):\s*[^,}\n]+,?`) content = reCtxKey.ReplaceAllString(content, "") return content diff --git a/cmd/internal/migrations/v3/middleware_locals_test.go b/cmd/internal/migrations/v3/middleware_locals_test.go index 8544dfd..f303e74 100644 --- a/cmd/internal/migrations/v3/middleware_locals_test.go +++ b/cmd/internal/migrations/v3/middleware_locals_test.go @@ -69,3 +69,33 @@ func main() { content := readFile(t, file) assert.NotContains(t, content, "ContextKey") } + +func Test_MigrateMiddlewareLocals_CustomContextKey(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mcustomctx") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/csrf" +) + +var _ = csrf.New(csrf.Config{ContextKey: "token"}) + +func handler(c fiber.Ctx) error { + token := c.Locals("token").(string) + _ = token + return nil +}`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.Contains(t, content, `token := csrf.TokenFromContext(c)`) + assert.NotContains(t, content, "ContextKey") +} diff --git a/cmd/migrate_test.go b/cmd/migrate_test.go index 20509b6..cf0fd93 100644 --- a/cmd/migrate_test.go +++ b/cmd/migrate_test.go @@ -39,6 +39,7 @@ func Test_Migrate_V2_to_V3(t *testing.T) { import ( "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/monitor" + "github.com/gofiber/fiber/v2/middleware/csrf" ) func handler(c *fiber.Ctx) error { @@ -50,7 +51,9 @@ func handler(c *fiber.Ctx) error { ctx := c.Context() uc := c.UserContext() c.SetUserContext(uc) + csrfToken := c.Locals("token").(string) _ = ctx + _ = csrfToken return c.Bind(fiber.Map{}) } @@ -60,6 +63,7 @@ func main() { Prefork: true, Network: "tcp", }) + app.Use(csrf.New(csrf.Config{ContextKey: "token"})) app.Static("/", "./public") app.Add(fiber.MethodGet, "/foo", handler) app.Mount("/api", app) @@ -92,6 +96,8 @@ func main() { at.Contains(content, ".Redirect().To(\"/foo\")") at.Contains(content, ".Redirect().Back()") at.Contains(content, "fiber.Params[int](c, \"id\"") + at.Contains(content, "csrf.TokenFromContext(c)") + at.NotContains(content, "ContextKey") at.Contains(content, ".Use(\"/api\", app)") at.Contains(content, ".Listen(") at.Contains(content, "MIMETextJavaScript")