Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions cmd/internal/migrations/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package migrations
import (
"fmt"
"os"
"path/filepath"
"regexp"
"strconv"

Expand All @@ -13,22 +14,22 @@ import (
)

var (
pkgRegex = regexp.MustCompile(`(github\.com/gofiber/fiber/)(v\d+)( *?)(v[\w.-]+)`)
pkgImportRegex = regexp.MustCompile(`(?m)^(\s*(?:[\w.]+\s+)?")github\.com/gofiber/fiber/v\d+("$)`)
pkgRegex = regexp.MustCompile(`(github\.com/gofiber/fiber/)(v\d+)( *?)(v[\w.-]+)`)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Harden go.mod regex: allow tabs/any whitespace and +incompatible/build metadata.

( *?) won’t catch tabs, and v[\w.-]+ misses + (e.g., v3.0.0+incompatible). Widen the character class and whitespace to ensure robust replacement across real-world go.mod lines.

Apply this diff:

-	pkgRegex         = regexp.MustCompile(`(github\.com/gofiber/fiber/)(v\d+)( *?)(v[\w.-]+)`)
+	pkgRegex         = regexp.MustCompile(`(github\.com/gofiber/fiber/)(v\d+)(\s*)(v[\w.+-]+)`)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
pkgRegex = regexp.MustCompile(`(github\.com/gofiber/fiber/)(v\d+)( *?)(v[\w.-]+)`)
pkgRegex = regexp.MustCompile(`(github\.com/gofiber/fiber/)(v\d+)(\s*)(v[\w.+-]+)`)
🤖 Prompt for AI Agents
In cmd/internal/migrations/common.go around line 17, the current pkgRegex uses
`( *?)` which won't match tabs/other whitespace and `v[\w.-]+` which omits `+`
(needed for versions like v3.0.0+incompatible); update the regex to accept any
whitespace (e.g., `\s*?` or `\s+?`) and include `+` in the version character
class (e.g., `v[\w.+-]+`) so the pattern becomes robust for real-world go.mod
lines.

fiberImportRegex = regexp.MustCompile(`(^|")github\.com/gofiber/fiber/v\d+`)
)

func MigrateGoPkgs(cmd *cobra.Command, cwd string, _, target *semver.Version) error {
err := internal.ChangeFileContent(cwd, func(content string) string {
replacement := fmt.Sprintf("${1}github.com/gofiber/fiber/v%d${2}", target.Major())
return pkgImportRegex.ReplaceAllString(content, replacement)
replacement := fmt.Sprintf("${1}github.com/gofiber/fiber/v%d", target.Major())
return fiberImportRegex.ReplaceAllString(content, replacement)
})
if err != nil {
return fmt.Errorf("failed to migrate Go packages: %w", err)
}

// get go.mod file
modFile := "go.mod"
fileContent, err := os.ReadFile(modFile)
modFile := filepath.Join(cwd, "go.mod")
fileContent, err := os.ReadFile(modFile) // #nosec G304 -- reading module file
if err != nil {
return fmt.Errorf("read %s: %w", modFile, err)
}
Expand Down
50 changes: 50 additions & 0 deletions cmd/internal/migrations/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package migrations_test

import (
"bytes"
"os"
"path/filepath"
"testing"

semver "github.com/Masterminds/semver/v3"
"github.com/gofiber/cli/cmd/internal/migrations"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_MigrateGoPkgs(t *testing.T) {
dir, err := os.MkdirTemp("", "mgpkgs")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

mainContent := `package main
import (
fiber "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
)
func main() {
_, _ = fiber.New(), adaptor.New()
}`
file := filepath.Join(dir, "main.go")
require.NoError(t, os.WriteFile(file, []byte(mainContent), 0o600))

modContent := `module example

go 1.22

require github.com/gofiber/fiber/v2 v2.0.0`
require.NoError(t, os.WriteFile(filepath.Join(dir, "go.mod"), []byte(modContent), 0o600))

var buf bytes.Buffer
cmd := newCmd(&buf)
target := semver.MustParse("3.0.0")
require.NoError(t, migrations.MigrateGoPkgs(cmd, dir, nil, target))

content := readFile(t, file)
assert.Contains(t, content, "github.com/gofiber/fiber/v3")
assert.NotContains(t, content, "github.com/gofiber/fiber/v2")

mod := readFile(t, filepath.Join(dir, "go.mod"))
assert.Contains(t, mod, "github.com/gofiber/fiber/v3 v3.0.0")
assert.Contains(t, buf.String(), "Migrating Go packages")
}
65 changes: 57 additions & 8 deletions cmd/internal/migrations/v3/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,60 @@ func MigrateMount(cmd *cobra.Command, cwd string, _, _ *semver.Version) error {
// MigrateAddMethod adapts the Add method signature
func MigrateAddMethod(cmd *cobra.Command, cwd string, _, _ *semver.Version) error {
err := internal.ChangeFileContent(cwd, func(content string) string {
return replaceCall(content, ".Add", func(call string, args []string) string {
if len(args) < 2 {
return call
re := regexp.MustCompile(`\.Add\(`)
matches := re.FindAllStringIndex(content, -1)
if len(matches) == 0 {
return content
}

var b strings.Builder
last := 0
for _, m := range matches {
if m[0] < last {
continue
}
args[0] = fmt.Sprintf("[]string{%s}", args[0])
return fmt.Sprintf(".Add(%s)", strings.Join(args, ", "))
})

startCall := m[0]
if startCall > 0 {
if _, err := b.WriteString(content[last:startCall]); err != nil {
return content
}
}

end, inner := extractCall(content, m[1])
identStart := startCall - 1
for identStart >= 0 {
ch := content[identStart]
if !((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_') {
break
}
identStart--
}
ident := content[identStart+1 : startCall]

switch ident {
case "Header", "httpServerActiveRequests":
if _, err := b.WriteString(content[startCall:end]); err != nil {
return content
}
default:
args := splitArgs(inner)
if len(args) >= 2 {
args[0] = fmt.Sprintf("[]string{%s}", args[0])
}
if _, err := b.WriteString(".Add(" + strings.Join(args, ", ") + ")"); err != nil {
return content
}
}

last = end
}

if _, err := b.WriteString(content[last:]); err != nil {
return content
}

return b.String()
})
if err != nil {
return fmt.Errorf("failed to migrate Add method calls: %w", err)
Expand Down Expand Up @@ -414,11 +461,13 @@ func MigrateCORSConfig(cmd *cobra.Command, cwd string, _, _ *semver.Version) err

// MigrateCSRFConfig updates csrf middleware configuration fields
func MigrateCSRFConfig(cmd *cobra.Command, cwd string, _, _ *semver.Version) error {
replacer := strings.NewReplacer("Expiration:", "IdleTimeout:")
reConfig := regexp.MustCompile(`csrf\.Config{[^}]*}`)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The regular expression csrf\.Config{[^}]*} is not robust enough to handle all cases, specifically when the csrf.Config struct literal contains nested blocks with curly braces, such as an inline function literal. The [^}]* pattern will stop at the first closing brace it finds, which can lead to incorrect partial replacements and corrupt the code.

For example, consider this code:

var _ = csrf.New(csrf.Config{
    Expiration: 10 * time.Minute,
    Next: func(c *fiber.Ctx) bool {
        if c.Path() == "/login" {
            return true
        }
        return false
    },
})

The regex will only match up to the } of the Next function, and the replacement will be applied incorrectly.

A more robust approach would be to parse the struct literal by matching braces, similar to how extractCall handles parentheses for function calls. You could create a helper function extractBraceBlock to find the matching closing brace for a struct literal and then perform the replacement on the correctly identified block.

reSession := regexp.MustCompile(`\s*SessionKey:\s*[^,]+,?\n`)
reKeyLookup := regexp.MustCompile(`(\s*)KeyLookup:\s*([^,\n]+)(,?)(\n?)`)
err := internal.ChangeFileContent(cwd, func(content string) string {
content = replacer.Replace(content)
content = reConfig.ReplaceAllStringFunc(content, func(s string) string {
return strings.ReplaceAll(s, "Expiration:", "IdleTimeout:")
})
content = reSession.ReplaceAllString(content, "")

content = reKeyLookup.ReplaceAllStringFunc(content, func(s string) string {
Expand Down
56 changes: 56 additions & 0 deletions cmd/internal/migrations/v3/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,32 @@ func main() {
assert.Contains(t, buf.String(), "Migrating Add method calls")
}

func Test_MigrateAddMethod_SkipUnrelated(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "maddskip")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

file := writeTempFile(t, dir, `package main
func main() {
req.Header.Add("Authorization", "Bearer "+test.Token)
c.Response().Header.Add("X-Key", "Value")
httpServerActiveRequests.Add(1)
}`)

var buf bytes.Buffer
cmd := newCmd(&buf)
require.NoError(t, v3.MigrateAddMethod(cmd, dir, nil, nil))

content := readFile(t, file)
assert.Contains(t, content, `req.Header.Add("Authorization", "Bearer "+test.Token)`)
assert.Contains(t, content, `c.Response().Header.Add("X-Key", "Value")`)
assert.Contains(t, content, `httpServerActiveRequests.Add(1)`)
assert.NotContains(t, content, "[]string{")
assert.Contains(t, buf.String(), "Migrating Add method calls")
}

func Test_MigrateMimeConstants(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -477,6 +503,36 @@ var _ = csrf.New(csrf.Config{
assert.Contains(t, buf.String(), "Migrating CSRF middleware configs")
}

func Test_MigrateCSRFConfig_IgnoresPaseto(t *testing.T) {
t.Parallel()

dir, err := os.MkdirTemp("", "mcsrfpaseto")
require.NoError(t, err)
defer func() { require.NoError(t, os.RemoveAll(dir)) }()

file := writeTempFile(t, dir, `package main
import "github.com/o1egl/paseto"
func main() {
payload := &paseto.JSONToken{
Audience: pasetoTokenAudience,
Jti: tokenID.String(),
Subject: pasetoTokenSubject,
IssuedAt: timeNow,
Expiration: timeNow.Add(duration),
NotBefore: timeNow,
}
_ = payload
}`)

var buf bytes.Buffer
cmd := newCmd(&buf)
require.NoError(t, v3.MigrateCSRFConfig(cmd, dir, nil, nil))

content := readFile(t, file)
assert.Contains(t, content, "Expiration:")
assert.NotContains(t, content, "IdleTimeout:")
}

func Test_MigrateMonitorImport(t *testing.T) {
t.Parallel()

Expand Down
Loading