diff --git a/cmd/internal/migrations/common.go b/cmd/internal/migrations/common.go index 0569b42..a7b7597 100644 --- a/cmd/internal/migrations/common.go +++ b/cmd/internal/migrations/common.go @@ -3,6 +3,7 @@ package migrations import ( "fmt" "os" + "path/filepath" "regexp" "strconv" @@ -13,22 +14,22 @@ import ( ) var ( - pkgRegex = regexp.MustCompile(`(github\.com/gofiber/fiber/)(v\d+)( *?)(v[\w.-]+)`) - pkgImportRegex = regexp.MustCompile(`(?m)^(\s*(?:[\w.]+\s+)?")github\.com/gofiber/fiber/v\d+("$)`) + pkgRegex = regexp.MustCompile(`(github\.com/gofiber/fiber/)(v\d+)( *?)(v[\w.-]+)`) + fiberImportRegex = regexp.MustCompile(`(^|")github\.com/gofiber/fiber/v\d+`) ) func MigrateGoPkgs(cmd *cobra.Command, cwd string, _, target *semver.Version) error { err := internal.ChangeFileContent(cwd, func(content string) string { - replacement := fmt.Sprintf("${1}github.com/gofiber/fiber/v%d${2}", target.Major()) - return pkgImportRegex.ReplaceAllString(content, replacement) + replacement := fmt.Sprintf("${1}github.com/gofiber/fiber/v%d", target.Major()) + return fiberImportRegex.ReplaceAllString(content, replacement) }) if err != nil { return fmt.Errorf("failed to migrate Go packages: %w", err) } // get go.mod file - modFile := "go.mod" - fileContent, err := os.ReadFile(modFile) + modFile := filepath.Join(cwd, "go.mod") + fileContent, err := os.ReadFile(modFile) // #nosec G304 -- reading module file if err != nil { return fmt.Errorf("read %s: %w", modFile, err) } diff --git a/cmd/internal/migrations/common_test.go b/cmd/internal/migrations/common_test.go new file mode 100644 index 0000000..169be79 --- /dev/null +++ b/cmd/internal/migrations/common_test.go @@ -0,0 +1,50 @@ +package migrations_test + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + semver "github.com/Masterminds/semver/v3" + "github.com/gofiber/cli/cmd/internal/migrations" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_MigrateGoPkgs(t *testing.T) { + dir, err := os.MkdirTemp("", "mgpkgs") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + mainContent := `package main +import ( + fiber "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/adaptor" +) +func main() { + _, _ = fiber.New(), adaptor.New() +}` + file := filepath.Join(dir, "main.go") + require.NoError(t, os.WriteFile(file, []byte(mainContent), 0o600)) + + modContent := `module example + +go 1.22 + +require github.com/gofiber/fiber/v2 v2.0.0` + require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte(modContent), 0o600)) + + var buf bytes.Buffer + cmd := newCmd(&buf) + target := semver.MustParse("3.0.0") + require.NoError(t, migrations.MigrateGoPkgs(cmd, dir, nil, target)) + + content := readFile(t, file) + assert.Contains(t, content, "github.com/gofiber/fiber/v3") + assert.NotContains(t, content, "github.com/gofiber/fiber/v2") + + mod := readFile(t, filepath.Join(dir, "go.mod")) + assert.Contains(t, mod, "github.com/gofiber/fiber/v3 v3.0.0") + assert.Contains(t, buf.String(), "Migrating Go packages") +} diff --git a/cmd/internal/migrations/v3/common.go b/cmd/internal/migrations/v3/common.go index f7b29b8..8ba3a63 100644 --- a/cmd/internal/migrations/v3/common.go +++ b/cmd/internal/migrations/v3/common.go @@ -359,13 +359,60 @@ func MigrateMount(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { // MigrateAddMethod adapts the Add method signature func MigrateAddMethod(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { err := internal.ChangeFileContent(cwd, func(content string) string { - return replaceCall(content, ".Add", func(call string, args []string) string { - if len(args) < 2 { - return call + re := regexp.MustCompile(`\.Add\(`) + matches := re.FindAllStringIndex(content, -1) + if len(matches) == 0 { + return content + } + + var b strings.Builder + last := 0 + for _, m := range matches { + if m[0] < last { + continue } - args[0] = fmt.Sprintf("[]string{%s}", args[0]) - return fmt.Sprintf(".Add(%s)", strings.Join(args, ", ")) - }) + + startCall := m[0] + if startCall > 0 { + if _, err := b.WriteString(content[last:startCall]); err != nil { + return content + } + } + + end, inner := extractCall(content, m[1]) + identStart := startCall - 1 + for identStart >= 0 { + ch := content[identStart] + if !((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_') { + break + } + identStart-- + } + ident := content[identStart+1 : startCall] + + switch ident { + case "Header", "httpServerActiveRequests": + if _, err := b.WriteString(content[startCall:end]); err != nil { + return content + } + default: + args := splitArgs(inner) + if len(args) >= 2 { + args[0] = fmt.Sprintf("[]string{%s}", args[0]) + } + if _, err := b.WriteString(".Add(" + strings.Join(args, ", ") + ")"); err != nil { + return content + } + } + + last = end + } + + if _, err := b.WriteString(content[last:]); err != nil { + return content + } + + return b.String() }) if err != nil { return fmt.Errorf("failed to migrate Add method calls: %w", err) @@ -414,11 +461,13 @@ func MigrateCORSConfig(cmd *cobra.Command, cwd string, _, _ *semver.Version) err // MigrateCSRFConfig updates csrf middleware configuration fields func MigrateCSRFConfig(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { - replacer := strings.NewReplacer("Expiration:", "IdleTimeout:") + reConfig := regexp.MustCompile(`csrf\.Config{[^}]*}`) reSession := regexp.MustCompile(`\s*SessionKey:\s*[^,]+,?\n`) reKeyLookup := regexp.MustCompile(`(\s*)KeyLookup:\s*([^,\n]+)(,?)(\n?)`) err := internal.ChangeFileContent(cwd, func(content string) string { - content = replacer.Replace(content) + content = reConfig.ReplaceAllStringFunc(content, func(s string) string { + return strings.ReplaceAll(s, "Expiration:", "IdleTimeout:") + }) content = reSession.ReplaceAllString(content, "") content = reKeyLookup.ReplaceAllStringFunc(content, func(s string) string { diff --git a/cmd/internal/migrations/v3/common_test.go b/cmd/internal/migrations/v3/common_test.go index d324fff..80dd8fb 100644 --- a/cmd/internal/migrations/v3/common_test.go +++ b/cmd/internal/migrations/v3/common_test.go @@ -298,6 +298,32 @@ func main() { assert.Contains(t, buf.String(), "Migrating Add method calls") } +func Test_MigrateAddMethod_SkipUnrelated(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "maddskip") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +func main() { + req.Header.Add("Authorization", "Bearer "+test.Token) + c.Response().Header.Add("X-Key", "Value") + httpServerActiveRequests.Add(1) +}`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateAddMethod(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.Contains(t, content, `req.Header.Add("Authorization", "Bearer "+test.Token)`) + assert.Contains(t, content, `c.Response().Header.Add("X-Key", "Value")`) + assert.Contains(t, content, `httpServerActiveRequests.Add(1)`) + assert.NotContains(t, content, "[]string{") + assert.Contains(t, buf.String(), "Migrating Add method calls") +} + func Test_MigrateMimeConstants(t *testing.T) { t.Parallel() @@ -477,6 +503,36 @@ var _ = csrf.New(csrf.Config{ assert.Contains(t, buf.String(), "Migrating CSRF middleware configs") } +func Test_MigrateCSRFConfig_IgnoresPaseto(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mcsrfpaseto") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import "github.com/o1egl/paseto" +func main() { + payload := &paseto.JSONToken{ + Audience: pasetoTokenAudience, + Jti: tokenID.String(), + Subject: pasetoTokenSubject, + IssuedAt: timeNow, + Expiration: timeNow.Add(duration), + NotBefore: timeNow, + } + _ = payload +}`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateCSRFConfig(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.Contains(t, content, "Expiration:") + assert.NotContains(t, content, "IdleTimeout:") +} + func Test_MigrateMonitorImport(t *testing.T) { t.Parallel()