From 0c8cf9b5a138da60a32094dc038186b12f65f628 Mon Sep 17 00:00:00 2001 From: RW Date: Sat, 19 Jul 2025 21:10:33 +0200 Subject: [PATCH] migrate: handle csrf KeyLookup --- cmd/internal/migrations/v3/common.go | 34 +++++++++++++++++++++-- cmd/internal/migrations/v3/common_test.go | 25 +++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/cmd/internal/migrations/v3/common.go b/cmd/internal/migrations/v3/common.go index d82f41b..a39254b 100644 --- a/cmd/internal/migrations/v3/common.go +++ b/cmd/internal/migrations/v3/common.go @@ -194,10 +194,40 @@ func MigrateCORSConfig(cmd *cobra.Command, cwd string, _, _ *semver.Version) err // MigrateCSRFConfig updates csrf middleware configuration fields func MigrateCSRFConfig(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { replacer := strings.NewReplacer("Expiration:", "IdleTimeout:") - re := regexp.MustCompile(`\s*SessionKey:\s*[^,]+,?\n`) + reSession := regexp.MustCompile(`\s*SessionKey:\s*[^,]+,?\n`) + reKeyLookup := regexp.MustCompile(`(\s*)KeyLookup:\s*([^,\n]+)(,?)(\n?)`) err := internal.ChangeFileContent(cwd, func(content string) string { content = replacer.Replace(content) - return re.ReplaceAllString(content, "") + content = reSession.ReplaceAllString(content, "") + + content = reKeyLookup.ReplaceAllStringFunc(content, func(s string) string { + sub := reKeyLookup.FindStringSubmatch(s) + indent := sub[1] + val := strings.TrimSpace(sub[2]) + comma := sub[3] + newline := sub[4] + + if uq, err := strconv.Unquote(val); err == nil { + val = uq + } + + var extractor string + switch { + case strings.HasPrefix(val, "header:"): + extractor = fmt.Sprintf("Extractor: csrf.FromHeader(%q)", strings.TrimPrefix(val, "header:")) + case strings.HasPrefix(val, "form:"): + extractor = fmt.Sprintf("Extractor: csrf.FromForm(%q)", strings.TrimPrefix(val, "form:")) + case strings.HasPrefix(val, "query:"): + extractor = fmt.Sprintf("Extractor: csrf.FromQuery(%q)", strings.TrimPrefix(val, "query:")) + default: + // Unsupported or insecure value (e.g. cookie) - remove + return "" + } + + return fmt.Sprintf("%s%s%s%s", indent, extractor, comma, newline) + }) + + return content }) if err != nil { return fmt.Errorf("failed to migrate CSRF configs: %w", err) diff --git a/cmd/internal/migrations/v3/common_test.go b/cmd/internal/migrations/v3/common_test.go index 12f0862..c4feb64 100644 --- a/cmd/internal/migrations/v3/common_test.go +++ b/cmd/internal/migrations/v3/common_test.go @@ -401,6 +401,31 @@ var _ = csrf.New(csrf.Config{ assert.Contains(t, buf.String(), "Migrating CSRF middleware configs") } +func Test_MigrateCSRFConfig_KeyLookup(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mcsrfkl") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import ( + "github.com/gofiber/fiber/v2/middleware/csrf" +) +var _ = csrf.New(csrf.Config{ + KeyLookup: "header:X-CSRF-Token", +})`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateCSRFConfig(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.NotContains(t, content, "KeyLookup") + assert.Contains(t, content, `Extractor: csrf.FromHeader("X-CSRF-Token")`) + assert.Contains(t, buf.String(), "Migrating CSRF middleware configs") +} + func Test_MigrateMonitorImport(t *testing.T) { t.Parallel()