diff --git a/cmd/internal/migrations/v3/middleware_locals.go b/cmd/internal/migrations/v3/middleware_locals.go index fba203b..1d7421b 100644 --- a/cmd/internal/migrations/v3/middleware_locals.go +++ b/cmd/internal/migrations/v3/middleware_locals.go @@ -155,7 +155,7 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio reTypeAssert := regexp.MustCompile(`([\w\.]+FromContext\([^\)]+\))\.\([^\)]+\)`) content = reTypeAssert.ReplaceAllString(content, "$1") - reComma := regexp.MustCompile(`(\w+)\s*,\s*(\w+)\s*:=\s*([\w\.]+FromContext\([^\)]+\))`) + reComma := regexp.MustCompile(`(\w+)\s*,\s*(\w+)\s*:=\s*([\w\.]+FromContext\([^\)]+\))(?:\s*,\s*true)*`) content = reComma.ReplaceAllString(content, "$1, $2 := $3, true") for alias := range imports { diff --git a/cmd/internal/migrations/v3/middleware_locals_test.go b/cmd/internal/migrations/v3/middleware_locals_test.go index b892f7b..6b04001 100644 --- a/cmd/internal/migrations/v3/middleware_locals_test.go +++ b/cmd/internal/migrations/v3/middleware_locals_test.go @@ -43,6 +43,38 @@ func handler(c fiber.Ctx) error { assert.Contains(t, buf.String(), "Migrating middleware locals") } +func Test_MigrateMiddlewareLocals_Idempotent(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mlocals-idem") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import "github.com/gofiber/fiber/v2" +func handler(c fiber.Ctx) error { + csrfToken, ok := c.Locals("csrf").(string) + _ = csrfToken + _ = ok + return nil +}`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil)) + + contentAfterFirstRun := readFile(t, file) + assert.Contains(t, contentAfterFirstRun, `csrfToken, ok := csrf.TokenFromContext(c), true`) + assert.Contains(t, buf.String(), "Migrating middleware locals") + + buf.Reset() + require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil)) + + contentAfterSecondRun := readFile(t, file) + assert.Equal(t, contentAfterFirstRun, contentAfterSecondRun) + assert.Empty(t, buf.String()) +} + func Test_MigrateMiddlewareLocals_ContextKey(t *testing.T) { t.Parallel()