diff --git a/cmd/internal/migrations/v3/middleware_locals.go b/cmd/internal/migrations/v3/middleware_locals.go index b553e8b..481fe0e 100644 --- a/cmd/internal/migrations/v3/middleware_locals.go +++ b/cmd/internal/migrations/v3/middleware_locals.go @@ -10,40 +10,66 @@ import ( "github.com/gofiber/cli/cmd/internal" ) -func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { - changed, err := internal.ChangeFileContent(cwd, func(content string) string { - ctxMap := map[string]string{ - "requestid": "requestid.FromContext(%s)", - "csrf": "csrf.TokenFromContext(%s)", - "csrf_handler": "csrf.HandlerFromContext(%s)", - "session": "session.FromContext(%s)", - "username": "basicauth.UsernameFromContext(%s)", - "password": "basicauth.PasswordFromContext(%s)", - "token": "keyauth.TokenFromContext(%s)", - } +type ctxRepl struct { + pkg string + replFmt string +} - extractors := []struct { - pkg string - field string - replFmt string - }{ - {"csrf", "ContextKey", "csrf.TokenFromContext(%s)"}, - {"keyauth", "ContextKey", "keyauth.TokenFromContext(%s)"}, - {"session", "ContextKey", "session.FromContext(%s)"}, - {"basicauth", "ContextUsername", "basicauth.UsernameFromContext(%s)"}, - {"basicauth", "ContextPassword", "basicauth.PasswordFromContext(%s)"}, +func parseMiddlewareImports(content string, reImport *regexp.Regexp) map[string]string { + imports := map[string]string{} + for _, m := range reImport.FindAllStringSubmatch(content, -1) { + alias := m[1] + pkg := m[2] + if alias == "" { + alias = pkg } + imports[alias] = pkg + } + return imports +} - reImport := regexp.MustCompile(`(?m)^\s*(?:import\s+)?(?:([\w\.]+)\s+)?"github\.com/gofiber/fiber/(?:v2|v3)/middleware/([\w]+)"`) - imports := map[string]string{} - for _, m := range reImport.FindAllStringSubmatch(content, -1) { - alias := m[1] - pkg := m[2] - if alias == "" { - alias = pkg - } - imports[alias] = pkg - } +func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { + ctxMap := map[string][]ctxRepl{ + "requestid": { + {pkg: "requestid", replFmt: "requestid.FromContext(%s)"}, + }, + "csrf": { + {pkg: "csrf", replFmt: "csrf.TokenFromContext(%s)"}, + }, + "csrf_handler": { + {pkg: "csrf", replFmt: "csrf.HandlerFromContext(%s)"}, + }, + "session": { + {pkg: "session", replFmt: "session.FromContext(%s)"}, + }, + "username": { + {pkg: "basicauth", replFmt: "basicauth.UsernameFromContext(%s)"}, + }, + "password": { + {pkg: "basicauth", replFmt: "basicauth.PasswordFromContext(%s)"}, + }, + "token": { + {pkg: "keyauth", replFmt: "keyauth.TokenFromContext(%s)"}, + }, + } + + extractors := []struct { + pkg string + field string + replFmt string + }{ + {"csrf", "ContextKey", "csrf.TokenFromContext(%s)"}, + {"keyauth", "ContextKey", "keyauth.TokenFromContext(%s)"}, + {"session", "ContextKey", "session.FromContext(%s)"}, + {"basicauth", "ContextUsername", "basicauth.UsernameFromContext(%s)"}, + {"basicauth", "ContextPassword", "basicauth.PasswordFromContext(%s)"}, + } + + reImport := regexp.MustCompile(`(?m)^\s*(?:import\s+)?(?:([\w\.]+)\s+)?"github\.com/gofiber/fiber/(?:v2|v3)/middleware/([\w]+)"`) + + // first pass: collect context key mappings across all files + _, err := internal.ChangeFileContent(cwd, func(content string) string { + imports := parseMiddlewareImports(content, reImport) for alias, pkg := range imports { for _, e := range extractors { @@ -53,18 +79,37 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio re := regexp.MustCompile(alias + `\.Config{[^}]*` + e.field + `:\s*"([^"]+)"`) matches := re.FindAllStringSubmatch(content, -1) for _, m := range matches { - ctxMap[m[1]] = e.replFmt + ctxMap[m[1]] = append(ctxMap[m[1]], ctxRepl{pkg: e.pkg, replFmt: e.replFmt}) } } } + return content + }) + if err != nil { + return fmt.Errorf("failed to gather middleware locals: %w", err) + } + + // second pass: perform replacements and clean up + changed, err := internal.ChangeFileContent(cwd, func(content string) string { + imports := parseMiddlewareImports(content, reImport) + reLocals := regexp.MustCompile(`(\w+)\.Locals\("([^"]+)"\)`) content = reLocals.ReplaceAllStringFunc(content, func(s string) string { sub := reLocals.FindStringSubmatch(s) ctx := sub[1] key := sub[2] - if fmtStr, ok := ctxMap[key]; ok { - return fmt.Sprintf(fmtStr, ctx) + if repls, ok := ctxMap[key]; ok { + if len(repls) == 1 { + return fmt.Sprintf(repls[0].replFmt, ctx) + } + for _, r := range repls { + for _, pkg := range imports { + if pkg == r.pkg { + return fmt.Sprintf(r.replFmt, ctx) + } + } + } } return s }) diff --git a/cmd/internal/migrations/v3/middleware_locals_test.go b/cmd/internal/migrations/v3/middleware_locals_test.go index 6954075..c592e24 100644 --- a/cmd/internal/migrations/v3/middleware_locals_test.go +++ b/cmd/internal/migrations/v3/middleware_locals_test.go @@ -3,6 +3,7 @@ package v3_test import ( "bytes" "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -181,3 +182,104 @@ var ConfigDefault = Config{ content := readFile(t, file) assert.Contains(t, content, "ContextKey: DefaultContextKey") } + +func Test_MigrateMiddlewareLocals_CustomContextKeyAcrossFiles(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mcustomctxfiles") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // config in separate file + cfg := `package main +import "github.com/gofiber/fiber/v2/middleware/csrf" + +var _ = csrf.New(csrf.Config{ContextKey: "token"})` + cfgPath := filepath.Join(dir, "config.go") + require.NoError(t, os.WriteFile(cfgPath, []byte(cfg), 0o600)) + + handler := `package main +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/csrf" +) + +func handler(c fiber.Ctx) error { + token := c.Locals("token").(string) + _ = token + return nil +}` + file := writeTempFile(t, dir, handler) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.Contains(t, content, `token := csrf.TokenFromContext(c)`) + cfgContent := readFile(t, cfgPath) + assert.NotContains(t, cfgContent, "ContextKey") +} + +func Test_MigrateMiddlewareLocals_SameContextKeyDifferentPackages(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mctxdup") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + kaCfg := `package main +import "github.com/gofiber/fiber/v2/middleware/keyauth" + +var _ = keyauth.New(keyauth.Config{ContextKey: "token"})` + kaCfgPath := filepath.Join(dir, "keyauth.go") + require.NoError(t, os.WriteFile(kaCfgPath, []byte(kaCfg), 0o600)) + + csrfCfg := `package main +import "github.com/gofiber/fiber/v2/middleware/csrf" + +var _ = csrf.New(csrf.Config{ContextKey: "token"})` + csrfCfgPath := filepath.Join(dir, "csrf.go") + require.NoError(t, os.WriteFile(csrfCfgPath, []byte(csrfCfg), 0o600)) + + kaHandler := `package main +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/keyauth" +) + +func handlerKA(c fiber.Ctx) error { + token := c.Locals("token").(string) + _ = token + return nil +}` + kaHandlerPath := filepath.Join(dir, "ka_handler.go") + require.NoError(t, os.WriteFile(kaHandlerPath, []byte(kaHandler), 0o600)) + + csrfHandler := `package main +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/csrf" +) + +func handlerCSRF(c fiber.Ctx) error { + token := c.Locals("token").(string) + _ = token + return nil +}` + csrfHandlerPath := filepath.Join(dir, "csrf_handler.go") + require.NoError(t, os.WriteFile(csrfHandlerPath, []byte(csrfHandler), 0o600)) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil)) + + kaContent := readFile(t, kaHandlerPath) + assert.Contains(t, kaContent, `token := keyauth.TokenFromContext(c)`) + + csrfContent := readFile(t, csrfHandlerPath) + assert.Contains(t, csrfContent, `token := csrf.TokenFromContext(c)`) + + assert.NotContains(t, readFile(t, kaCfgPath), "ContextKey") + assert.NotContains(t, readFile(t, csrfCfgPath), "ContextKey") +}