From 35f2b9f7591ca811b58e101c6d702ff4a1da07cd Mon Sep 17 00:00:00 2001 From: RW Date: Sun, 24 Aug 2025 18:54:40 +0200 Subject: [PATCH 1/2] fix: handle cross-file middleware context keys --- .../migrations/v3/middleware_locals.go | 64 ++++++++++++------- .../migrations/v3/middleware_locals_test.go | 36 +++++++++++ 2 files changed, 78 insertions(+), 22 deletions(-) diff --git a/cmd/internal/migrations/v3/middleware_locals.go b/cmd/internal/migrations/v3/middleware_locals.go index b553e8b..2f6d2a2 100644 --- a/cmd/internal/migrations/v3/middleware_locals.go +++ b/cmd/internal/migrations/v3/middleware_locals.go @@ -11,30 +11,32 @@ import ( ) func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { - changed, err := internal.ChangeFileContent(cwd, func(content string) string { - ctxMap := map[string]string{ - "requestid": "requestid.FromContext(%s)", - "csrf": "csrf.TokenFromContext(%s)", - "csrf_handler": "csrf.HandlerFromContext(%s)", - "session": "session.FromContext(%s)", - "username": "basicauth.UsernameFromContext(%s)", - "password": "basicauth.PasswordFromContext(%s)", - "token": "keyauth.TokenFromContext(%s)", - } + ctxMap := map[string]string{ + "requestid": "requestid.FromContext(%s)", + "csrf": "csrf.TokenFromContext(%s)", + "csrf_handler": "csrf.HandlerFromContext(%s)", + "session": "session.FromContext(%s)", + "username": "basicauth.UsernameFromContext(%s)", + "password": "basicauth.PasswordFromContext(%s)", + "token": "keyauth.TokenFromContext(%s)", + } - extractors := []struct { - pkg string - field string - replFmt string - }{ - {"csrf", "ContextKey", "csrf.TokenFromContext(%s)"}, - {"keyauth", "ContextKey", "keyauth.TokenFromContext(%s)"}, - {"session", "ContextKey", "session.FromContext(%s)"}, - {"basicauth", "ContextUsername", "basicauth.UsernameFromContext(%s)"}, - {"basicauth", "ContextPassword", "basicauth.PasswordFromContext(%s)"}, - } + extractors := []struct { + pkg string + field string + replFmt string + }{ + {"csrf", "ContextKey", "csrf.TokenFromContext(%s)"}, + {"keyauth", "ContextKey", "keyauth.TokenFromContext(%s)"}, + {"session", "ContextKey", "session.FromContext(%s)"}, + {"basicauth", "ContextUsername", "basicauth.UsernameFromContext(%s)"}, + {"basicauth", "ContextPassword", "basicauth.PasswordFromContext(%s)"}, + } + + reImport := regexp.MustCompile(`(?m)^\s*(?:import\s+)?(?:([\w\.]+)\s+)?"github\.com/gofiber/fiber/(?:v2|v3)/middleware/([\w]+)"`) - reImport := regexp.MustCompile(`(?m)^\s*(?:import\s+)?(?:([\w\.]+)\s+)?"github\.com/gofiber/fiber/(?:v2|v3)/middleware/([\w]+)"`) + // first pass: collect context key mappings across all files + _, err := internal.ChangeFileContent(cwd, func(content string) string { imports := map[string]string{} for _, m := range reImport.FindAllStringSubmatch(content, -1) { alias := m[1] @@ -58,6 +60,24 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio } } + return content + }) + if err != nil { + return fmt.Errorf("failed to gather middleware locals: %w", err) + } + + // second pass: perform replacements and clean up + changed, err := internal.ChangeFileContent(cwd, func(content string) string { + imports := map[string]string{} + for _, m := range reImport.FindAllStringSubmatch(content, -1) { + alias := m[1] + pkg := m[2] + if alias == "" { + alias = pkg + } + imports[alias] = pkg + } + reLocals := regexp.MustCompile(`(\w+)\.Locals\("([^"]+)"\)`) content = reLocals.ReplaceAllStringFunc(content, func(s string) string { sub := reLocals.FindStringSubmatch(s) diff --git a/cmd/internal/migrations/v3/middleware_locals_test.go b/cmd/internal/migrations/v3/middleware_locals_test.go index 6954075..f1e2b1c 100644 --- a/cmd/internal/migrations/v3/middleware_locals_test.go +++ b/cmd/internal/migrations/v3/middleware_locals_test.go @@ -3,6 +3,7 @@ package v3_test import ( "bytes" "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -181,3 +182,38 @@ var ConfigDefault = Config{ content := readFile(t, file) assert.Contains(t, content, "ContextKey: DefaultContextKey") } + +func Test_MigrateMiddlewareLocals_CustomContextKeyAcrossFiles(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mcustomctxfiles") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + // config in separate file + cfg := `package main +import "github.com/gofiber/fiber/v2/middleware/csrf" + +var _ = csrf.New(csrf.Config{ContextKey: "token"})` + cfgPath := filepath.Join(dir, "config.go") + require.NoError(t, os.WriteFile(cfgPath, []byte(cfg), 0o600)) + + handler := `package main +import "github.com/gofiber/fiber/v2" + +func handler(c fiber.Ctx) error { + token := c.Locals("token").(string) + _ = token + return nil +}` + file := writeTempFile(t, dir, handler) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.Contains(t, content, `token := csrf.TokenFromContext(c)`) + cfgContent := readFile(t, cfgPath) + assert.NotContains(t, cfgContent, "ContextKey") +} From 1f0168074aa353968663dcdbb50c957314c22b73 Mon Sep 17 00:00:00 2001 From: RW Date: Sun, 24 Aug 2025 19:42:44 +0200 Subject: [PATCH 2/2] fix: disambiguate middleware locals by imports --- .../migrations/v3/middleware_locals.go | 83 ++++++++++++------- .../migrations/v3/middleware_locals_test.go | 68 ++++++++++++++- 2 files changed, 121 insertions(+), 30 deletions(-) diff --git a/cmd/internal/migrations/v3/middleware_locals.go b/cmd/internal/migrations/v3/middleware_locals.go index 2f6d2a2..481fe0e 100644 --- a/cmd/internal/migrations/v3/middleware_locals.go +++ b/cmd/internal/migrations/v3/middleware_locals.go @@ -10,15 +10,47 @@ import ( "github.com/gofiber/cli/cmd/internal" ) +type ctxRepl struct { + pkg string + replFmt string +} + +func parseMiddlewareImports(content string, reImport *regexp.Regexp) map[string]string { + imports := map[string]string{} + for _, m := range reImport.FindAllStringSubmatch(content, -1) { + alias := m[1] + pkg := m[2] + if alias == "" { + alias = pkg + } + imports[alias] = pkg + } + return imports +} + func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { - ctxMap := map[string]string{ - "requestid": "requestid.FromContext(%s)", - "csrf": "csrf.TokenFromContext(%s)", - "csrf_handler": "csrf.HandlerFromContext(%s)", - "session": "session.FromContext(%s)", - "username": "basicauth.UsernameFromContext(%s)", - "password": "basicauth.PasswordFromContext(%s)", - "token": "keyauth.TokenFromContext(%s)", + ctxMap := map[string][]ctxRepl{ + "requestid": { + {pkg: "requestid", replFmt: "requestid.FromContext(%s)"}, + }, + "csrf": { + {pkg: "csrf", replFmt: "csrf.TokenFromContext(%s)"}, + }, + "csrf_handler": { + {pkg: "csrf", replFmt: "csrf.HandlerFromContext(%s)"}, + }, + "session": { + {pkg: "session", replFmt: "session.FromContext(%s)"}, + }, + "username": { + {pkg: "basicauth", replFmt: "basicauth.UsernameFromContext(%s)"}, + }, + "password": { + {pkg: "basicauth", replFmt: "basicauth.PasswordFromContext(%s)"}, + }, + "token": { + {pkg: "keyauth", replFmt: "keyauth.TokenFromContext(%s)"}, + }, } extractors := []struct { @@ -37,15 +69,7 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio // first pass: collect context key mappings across all files _, err := internal.ChangeFileContent(cwd, func(content string) string { - imports := map[string]string{} - for _, m := range reImport.FindAllStringSubmatch(content, -1) { - alias := m[1] - pkg := m[2] - if alias == "" { - alias = pkg - } - imports[alias] = pkg - } + imports := parseMiddlewareImports(content, reImport) for alias, pkg := range imports { for _, e := range extractors { @@ -55,7 +79,7 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio re := regexp.MustCompile(alias + `\.Config{[^}]*` + e.field + `:\s*"([^"]+)"`) matches := re.FindAllStringSubmatch(content, -1) for _, m := range matches { - ctxMap[m[1]] = e.replFmt + ctxMap[m[1]] = append(ctxMap[m[1]], ctxRepl{pkg: e.pkg, replFmt: e.replFmt}) } } } @@ -68,23 +92,24 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio // second pass: perform replacements and clean up changed, err := internal.ChangeFileContent(cwd, func(content string) string { - imports := map[string]string{} - for _, m := range reImport.FindAllStringSubmatch(content, -1) { - alias := m[1] - pkg := m[2] - if alias == "" { - alias = pkg - } - imports[alias] = pkg - } + imports := parseMiddlewareImports(content, reImport) reLocals := regexp.MustCompile(`(\w+)\.Locals\("([^"]+)"\)`) content = reLocals.ReplaceAllStringFunc(content, func(s string) string { sub := reLocals.FindStringSubmatch(s) ctx := sub[1] key := sub[2] - if fmtStr, ok := ctxMap[key]; ok { - return fmt.Sprintf(fmtStr, ctx) + if repls, ok := ctxMap[key]; ok { + if len(repls) == 1 { + return fmt.Sprintf(repls[0].replFmt, ctx) + } + for _, r := range repls { + for _, pkg := range imports { + if pkg == r.pkg { + return fmt.Sprintf(r.replFmt, ctx) + } + } + } } return s }) diff --git a/cmd/internal/migrations/v3/middleware_locals_test.go b/cmd/internal/migrations/v3/middleware_locals_test.go index f1e2b1c..c592e24 100644 --- a/cmd/internal/migrations/v3/middleware_locals_test.go +++ b/cmd/internal/migrations/v3/middleware_locals_test.go @@ -199,7 +199,10 @@ var _ = csrf.New(csrf.Config{ContextKey: "token"})` require.NoError(t, os.WriteFile(cfgPath, []byte(cfg), 0o600)) handler := `package main -import "github.com/gofiber/fiber/v2" +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/csrf" +) func handler(c fiber.Ctx) error { token := c.Locals("token").(string) @@ -217,3 +220,66 @@ func handler(c fiber.Ctx) error { cfgContent := readFile(t, cfgPath) assert.NotContains(t, cfgContent, "ContextKey") } + +func Test_MigrateMiddlewareLocals_SameContextKeyDifferentPackages(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mctxdup") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + kaCfg := `package main +import "github.com/gofiber/fiber/v2/middleware/keyauth" + +var _ = keyauth.New(keyauth.Config{ContextKey: "token"})` + kaCfgPath := filepath.Join(dir, "keyauth.go") + require.NoError(t, os.WriteFile(kaCfgPath, []byte(kaCfg), 0o600)) + + csrfCfg := `package main +import "github.com/gofiber/fiber/v2/middleware/csrf" + +var _ = csrf.New(csrf.Config{ContextKey: "token"})` + csrfCfgPath := filepath.Join(dir, "csrf.go") + require.NoError(t, os.WriteFile(csrfCfgPath, []byte(csrfCfg), 0o600)) + + kaHandler := `package main +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/keyauth" +) + +func handlerKA(c fiber.Ctx) error { + token := c.Locals("token").(string) + _ = token + return nil +}` + kaHandlerPath := filepath.Join(dir, "ka_handler.go") + require.NoError(t, os.WriteFile(kaHandlerPath, []byte(kaHandler), 0o600)) + + csrfHandler := `package main +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/csrf" +) + +func handlerCSRF(c fiber.Ctx) error { + token := c.Locals("token").(string) + _ = token + return nil +}` + csrfHandlerPath := filepath.Join(dir, "csrf_handler.go") + require.NoError(t, os.WriteFile(csrfHandlerPath, []byte(csrfHandler), 0o600)) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil)) + + kaContent := readFile(t, kaHandlerPath) + assert.Contains(t, kaContent, `token := keyauth.TokenFromContext(c)`) + + csrfContent := readFile(t, csrfHandlerPath) + assert.Contains(t, csrfContent, `token := csrf.TokenFromContext(c)`) + + assert.NotContains(t, readFile(t, kaCfgPath), "ContextKey") + assert.NotContains(t, readFile(t, csrfCfgPath), "ContextKey") +}