Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 79 additions & 34 deletions cmd/internal/migrations/v3/middleware_locals.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,66 @@ import (
"github.com/gofiber/cli/cmd/internal"
)

func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Version) error {
changed, err := internal.ChangeFileContent(cwd, func(content string) string {
ctxMap := map[string]string{
"requestid": "requestid.FromContext(%s)",
"csrf": "csrf.TokenFromContext(%s)",
"csrf_handler": "csrf.HandlerFromContext(%s)",
"session": "session.FromContext(%s)",
"username": "basicauth.UsernameFromContext(%s)",
"password": "basicauth.PasswordFromContext(%s)",
"token": "keyauth.TokenFromContext(%s)",
}
type ctxRepl struct {
pkg string
replFmt string
}

extractors := []struct {
pkg string
field string
replFmt string
}{
{"csrf", "ContextKey", "csrf.TokenFromContext(%s)"},
{"keyauth", "ContextKey", "keyauth.TokenFromContext(%s)"},
{"session", "ContextKey", "session.FromContext(%s)"},
{"basicauth", "ContextUsername", "basicauth.UsernameFromContext(%s)"},
{"basicauth", "ContextPassword", "basicauth.PasswordFromContext(%s)"},
func parseMiddlewareImports(content string, reImport *regexp.Regexp) map[string]string {
imports := map[string]string{}
for _, m := range reImport.FindAllStringSubmatch(content, -1) {
alias := m[1]
pkg := m[2]
if alias == "" {
alias = pkg
}
imports[alias] = pkg
}
return imports
}

reImport := regexp.MustCompile(`(?m)^\s*(?:import\s+)?(?:([\w\.]+)\s+)?"github\.com/gofiber/fiber/(?:v2|v3)/middleware/([\w]+)"`)
imports := map[string]string{}
for _, m := range reImport.FindAllStringSubmatch(content, -1) {
alias := m[1]
pkg := m[2]
if alias == "" {
alias = pkg
}
imports[alias] = pkg
}
func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Version) error {
ctxMap := map[string][]ctxRepl{
"requestid": {
{pkg: "requestid", replFmt: "requestid.FromContext(%s)"},
},
"csrf": {
{pkg: "csrf", replFmt: "csrf.TokenFromContext(%s)"},
},
"csrf_handler": {
{pkg: "csrf", replFmt: "csrf.HandlerFromContext(%s)"},
},
"session": {
{pkg: "session", replFmt: "session.FromContext(%s)"},
},
"username": {
{pkg: "basicauth", replFmt: "basicauth.UsernameFromContext(%s)"},
},
"password": {
{pkg: "basicauth", replFmt: "basicauth.PasswordFromContext(%s)"},
},
"token": {
{pkg: "keyauth", replFmt: "keyauth.TokenFromContext(%s)"},
},
}

extractors := []struct {
pkg string
field string
replFmt string
}{
{"csrf", "ContextKey", "csrf.TokenFromContext(%s)"},
{"keyauth", "ContextKey", "keyauth.TokenFromContext(%s)"},
{"session", "ContextKey", "session.FromContext(%s)"},
{"basicauth", "ContextUsername", "basicauth.UsernameFromContext(%s)"},
{"basicauth", "ContextPassword", "basicauth.PasswordFromContext(%s)"},
}

reImport := regexp.MustCompile(`(?m)^\s*(?:import\s+)?(?:([\w\.]+)\s+)?"github\.com/gofiber/fiber/(?:v2|v3)/middleware/([\w]+)"`)

// first pass: collect context key mappings across all files
_, err := internal.ChangeFileContent(cwd, func(content string) string {
imports := parseMiddlewareImports(content, reImport)

for alias, pkg := range imports {
for _, e := range extractors {
Expand All @@ -53,18 +79,37 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
re := regexp.MustCompile(alias + `\.Config{[^}]*` + e.field + `:\s*"([^"]+)"`)
matches := re.FindAllStringSubmatch(content, -1)
for _, m := range matches {
ctxMap[m[1]] = e.replFmt
ctxMap[m[1]] = append(ctxMap[m[1]], ctxRepl{pkg: e.pkg, replFmt: e.replFmt})
}
}
}

return content
})
if err != nil {
return fmt.Errorf("failed to gather middleware locals: %w", err)
}

// second pass: perform replacements and clean up
changed, err := internal.ChangeFileContent(cwd, func(content string) string {
imports := parseMiddlewareImports(content, reImport)

reLocals := regexp.MustCompile(`(\w+)\.Locals\("([^"]+)"\)`)
content = reLocals.ReplaceAllStringFunc(content, func(s string) string {
sub := reLocals.FindStringSubmatch(s)
ctx := sub[1]
key := sub[2]
if fmtStr, ok := ctxMap[key]; ok {
return fmt.Sprintf(fmtStr, ctx)
if repls, ok := ctxMap[key]; ok {
if len(repls) == 1 {
return fmt.Sprintf(repls[0].replFmt, ctx)
}
for _, r := range repls {
for _, pkg := range imports {
if pkg == r.pkg {
return fmt.Sprintf(r.replFmt, ctx)
}
}
}
}
return s
})
Expand Down
102 changes: 102 additions & 0 deletions cmd/internal/migrations/v3/middleware_locals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package v3_test
import (
"bytes"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -181,3 +182,104 @@ var ConfigDefault = Config{
content := readFile(t, file)
assert.Contains(t, content, "ContextKey: DefaultContextKey")
}

func Test_MigrateMiddlewareLocals_CustomContextKeyAcrossFiles(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "mcustomctxfiles")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

// config in separate file
cfg := `package main
import "github.com/gofiber/fiber/v2/middleware/csrf"

var _ = csrf.New(csrf.Config{ContextKey: "token"})`
cfgPath := filepath.Join(dir, "config.go")
require.NoError(t, os.WriteFile(cfgPath, []byte(cfg), 0o600))

handler := `package main
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/csrf"
)

func handler(c fiber.Ctx) error {
token := c.Locals("token").(string)
_ = token
return nil
}`
file := writeTempFile(t, dir, handler)

var buf bytes.Buffer
cmd := newCmd(&buf)
require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil))

content := readFile(t, file)
assert.Contains(t, content, `token := csrf.TokenFromContext(c)`)
cfgContent := readFile(t, cfgPath)
assert.NotContains(t, cfgContent, "ContextKey")
}

func Test_MigrateMiddlewareLocals_SameContextKeyDifferentPackages(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "mctxdup")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

kaCfg := `package main
import "github.com/gofiber/fiber/v2/middleware/keyauth"

var _ = keyauth.New(keyauth.Config{ContextKey: "token"})`
kaCfgPath := filepath.Join(dir, "keyauth.go")
require.NoError(t, os.WriteFile(kaCfgPath, []byte(kaCfg), 0o600))

csrfCfg := `package main
import "github.com/gofiber/fiber/v2/middleware/csrf"

var _ = csrf.New(csrf.Config{ContextKey: "token"})`
csrfCfgPath := filepath.Join(dir, "csrf.go")
require.NoError(t, os.WriteFile(csrfCfgPath, []byte(csrfCfg), 0o600))

kaHandler := `package main
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth"
)

func handlerKA(c fiber.Ctx) error {
token := c.Locals("token").(string)
_ = token
return nil
}`
kaHandlerPath := filepath.Join(dir, "ka_handler.go")
require.NoError(t, os.WriteFile(kaHandlerPath, []byte(kaHandler), 0o600))

csrfHandler := `package main
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/csrf"
)

func handlerCSRF(c fiber.Ctx) error {
token := c.Locals("token").(string)
_ = token
return nil
}`
csrfHandlerPath := filepath.Join(dir, "csrf_handler.go")
require.NoError(t, os.WriteFile(csrfHandlerPath, []byte(csrfHandler), 0o600))

var buf bytes.Buffer
cmd := newCmd(&buf)
require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil))

kaContent := readFile(t, kaHandlerPath)
assert.Contains(t, kaContent, `token := keyauth.TokenFromContext(c)`)

csrfContent := readFile(t, csrfHandlerPath)
assert.Contains(t, csrfContent, `token := csrf.TokenFromContext(c)`)

assert.NotContains(t, readFile(t, kaCfgPath), "ContextKey")
assert.NotContains(t, readFile(t, csrfCfgPath), "ContextKey")
}
Loading