From 4c12080f366456392fad3914418fe268b8dcd6e5 Mon Sep 17 00:00:00 2001 From: RW Date: Sun, 30 Nov 2025 21:50:37 +0100 Subject: [PATCH 1/2] Fix JWT extractor migration loop --- cmd/internal/migrations/v3/common.go | 137 +++++++++++++++--- .../migrations/v3/jwt_extractor_test.go | 123 ++++++++++++++++ 2 files changed, 239 insertions(+), 21 deletions(-) diff --git a/cmd/internal/migrations/v3/common.go b/cmd/internal/migrations/v3/common.go index dd0771e..966c50f 100644 --- a/cmd/internal/migrations/v3/common.go +++ b/cmd/internal/migrations/v3/common.go @@ -166,48 +166,143 @@ func replaceField(src, field string, fn func(indent, val, comma, comment, newlin } func replaceFieldImpl(src, field string, unquote bool, fn func(indent, val, comma, comment, newline string) string) string { - re := regexp.MustCompile(`(?m)^(\s*)` + regexp.QuoteMeta(field) + `:\s*([^\n]+)(\n?)`) - return re.ReplaceAllStringFunc(src, func(s string) string { - sub := re.FindStringSubmatch(s) - indent := sub[1] - val := strings.TrimSpace(sub[2]) - newline := sub[3] - - comment := "" - if idx := strings.Index(val, "//"); idx >= 0 { - comment = strings.TrimSpace(val[idx:]) - val = strings.TrimSpace(val[:idx]) - } else if idx := strings.Index(val, "/*"); idx >= 0 { - comment = strings.TrimSpace(val[idx:]) - val = strings.TrimSpace(val[:idx]) + re := regexp.MustCompile(regexp.QuoteMeta(field) + `:\s*`) + var b strings.Builder + pos := 0 + + for { + loc := re.FindStringIndex(src[pos:]) + if loc == nil { + break + } + loc[0] += pos + loc[1] += pos + + start := loc[0] + valStart := loc[1] + + prefix := "" + prefixStart := start + if prefixStart > 0 && (src[prefixStart-1] == '{' || src[prefixStart-1] == ',') { + prefix = string(src[prefixStart-1]) + prefixStart-- + } + + indentStart := prefixStart + for indentStart > 0 && (src[indentStart-1] == ' ' || src[indentStart-1] == '\t') { + indentStart-- + } + indent := src[indentStart:prefixStart] + + if _, err := b.WriteString(src[pos:indentStart]); err != nil { + return src } + i := valStart + depth := 0 + inString := false comma := "" + newline := "" + for i < len(src) { + ch := src[i] + if inString { + if ch == '\\' && i+1 < len(src) { + i += 2 + continue + } + if ch == '"' { + inString = false + } + i++ + continue + } + + switch ch { + case '"': + inString = true + case '(', '{', '[': + depth++ + case ')', ']': + if depth > 0 { + depth-- + } + case '}': + if depth == 0 { + goto endValue + } + depth-- + case ',': + if depth == 0 { + comma = "," + suffixStart := i + i = skipCommaSuffix(src, i) + if strings.Contains(src[suffixStart:i], "\n") { + newline = "\n" + } + goto endValue + } + case '\n': + if depth == 0 { + newline = "\n" + i++ + goto endValue + } + default: + } + i++ + } + + endValue: + end := i + val := strings.TrimSpace(src[valStart:end]) + val, comment := ExtractCommentAndValue(val) + if strings.HasSuffix(val, ",") { - comma = "," + if comma == "" { + comma = "," + } val = strings.TrimSpace(strings.TrimSuffix(val, ",")) } if unquote { uq, err := strconv.Unquote(val) if err != nil { + replacement := fmt.Sprintf("%s%s// TODO: migrate %s: %s", prefix, indent, field, val) if comment != "" { - return fmt.Sprintf("%s// TODO: migrate %s: %s %s%s", indent, field, val, comment, newline) + replacement = fmt.Sprintf("%s %s", replacement, comment) + } + replacement += newline + if _, err := b.WriteString(replacement); err != nil { + return src } - return fmt.Sprintf("%s// TODO: migrate %s: %s%s", indent, field, val, newline) + pos = end + continue } val = uq } repl := fn(indent, val, comma, comment, newline) + var replacement string if repl == "" { if comment != "" { - return fmt.Sprintf("%s%s%s", indent, comment, newline) + replacement = fmt.Sprintf("%s%s%s%s", prefix, indent, comment, newline) + } else { + replacement = prefix + newline } - return newline + } else { + replacement = prefix + repl } - return repl - }) + + if _, err := b.WriteString(replacement); err != nil { + return src + } + pos = end + } + + if _, err := b.WriteString(src[pos:]); err != nil { + return src + } + return b.String() } func collectAliases(content string, reImport *regexp.Regexp, defaults []string) []string { diff --git a/cmd/internal/migrations/v3/jwt_extractor_test.go b/cmd/internal/migrations/v3/jwt_extractor_test.go index e42a64c..b0af6a4 100644 --- a/cmd/internal/migrations/v3/jwt_extractor_test.go +++ b/cmd/internal/migrations/v3/jwt_extractor_test.go @@ -94,6 +94,29 @@ var _ = authjwt.New(authjwt.Config{ assert.Contains(t, buf.String(), "Migrating jwt middleware configs") } +func Test_MigrateJWTExtractor_InlineConfig(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mjwt_inline") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import jwtware "github.com/gofiber/contrib/jwt" + +var _ = jwtware.New(jwtware.Config{TokenLookup: "cookie:jwt"})`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateJWTExtractor(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.NotContains(t, content, "TokenLookup") + assert.Regexp(t, `Extractor:\s*extractors.FromCookie\("jwt"\)`, content) + assert.Contains(t, content, `"github.com/gofiber/fiber/v3/extractors"`) + assert.Contains(t, buf.String(), "Migrating jwt middleware configs") +} + func Test_MigrateJWTExtractor_ImportWithComment(t *testing.T) { t.Parallel() @@ -124,6 +147,76 @@ func JWTMiddleware() fiber.Handler { assert.Contains(t, buf.String(), "Migrating jwt middleware configs") } +func Test_MigrateJWTExtractor_FiberV2Middleware(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mjwt_fiber_v2") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package auth + +import ( + "os" + "strconv" + + jwtware "github.com/gofiber/contrib/jwt" + "github.com/gofiber/fiber/v2" + "github.com/golang-jwt/jwt/v5" +) + +// JWT error message. +func jwtError(c *fiber.Ctx, err error) error { + if err.Error() == "Missing or malformed JWT" { + return c.Status(fiber.StatusBadRequest).JSON(&fiber.Map{ + "status": "error", + "message": "Missing or malformed JWT!", + }) + } + + return c.Status(fiber.StatusUnauthorized).JSON(&fiber.Map{ + "status": "error", + "message": "Invalid or expired JWT!", + }) +} + +// Guards a specific endpoint in the API. +func JWTMiddleware() fiber.Handler { + return jwtware.New(jwtware.Config{ + ErrorHandler: jwtError, + SigningKey: jwtware.SigningKey{Key: []byte(os.Getenv("JWT_SECRET"))}, + TokenLookup: "cookie:jwt", + }) +} + +// Gets user data (their ID) from the JWT middleware. Should be executed after calling 'JWTMiddleware()'. +func GetDataFromJWT(c *fiber.Ctx) error { + jwtData := c.Locals("user").(*jwt.Token) + claims := jwtData.Claims.(jwt.MapClaims) + parsedUserID := claims["uid"].(string) + userID, err := strconv.Atoi(parsedUserID) + if err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(&fiber.Map{ + "status": "fail", + "message": err.Error(), + }) + } + + c.Locals("currentUser", userID) + return c.Next() +}`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateJWTExtractor(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.NotContains(t, content, "TokenLookup") + assert.Regexp(t, `Extractor:\s*extractors.FromCookie\("jwt"\)`, content) + assert.Contains(t, content, `"github.com/gofiber/fiber/v3/extractors"`) + assert.Contains(t, buf.String(), "Migrating jwt middleware configs") +} + func Test_MigrateJWTExtractor_LegacyImportPath(t *testing.T) { t.Parallel() @@ -157,6 +250,36 @@ func JWTMiddleware() fiber.Handler { assert.Contains(t, buf.String(), "Migrating jwt middleware configs") } +func Test_MigrateJWTExtractor_PointerConfig(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mjwt_pointer") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import ( + jwtware "github.com/gofiber/contrib/jwt" + "github.com/gofiber/fiber/v2" +) + +func JWTMiddleware() fiber.Handler { + return jwtware.New(&jwtware.Config{ + TokenLookup: "cookie:jwt", + }) +}`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateJWTExtractor(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.NotContains(t, content, "TokenLookup") + assert.Regexp(t, `Extractor:\s*extractors.FromCookie\("jwt"\)`, content) + assert.Contains(t, content, `"github.com/gofiber/fiber/v3/extractors"`) + assert.Contains(t, buf.String(), "Migrating jwt middleware configs") +} + func Test_MigrateJWTExtractor_SkipUnrelatedPackage(t *testing.T) { t.Parallel() From cefa1a69e6229fdd14718c9019b14d8cbb5968db Mon Sep 17 00:00:00 2001 From: RW Date: Mon, 1 Dec 2025 08:25:30 +0100 Subject: [PATCH 2/2] Remove unnecessary WriteString error handling --- cmd/internal/migrations/v3/common.go | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/cmd/internal/migrations/v3/common.go b/cmd/internal/migrations/v3/common.go index 966c50f..00594b3 100644 --- a/cmd/internal/migrations/v3/common.go +++ b/cmd/internal/migrations/v3/common.go @@ -194,9 +194,7 @@ func replaceFieldImpl(src, field string, unquote bool, fn func(indent, val, comm } indent := src[indentStart:prefixStart] - if _, err := b.WriteString(src[pos:indentStart]); err != nil { - return src - } + b.WriteString(src[pos:indentStart]) //nolint:errcheck // WriteString never returns an error i := valStart depth := 0 @@ -272,9 +270,7 @@ func replaceFieldImpl(src, field string, unquote bool, fn func(indent, val, comm replacement = fmt.Sprintf("%s %s", replacement, comment) } replacement += newline - if _, err := b.WriteString(replacement); err != nil { - return src - } + b.WriteString(replacement) //nolint:errcheck // WriteString never returns an error pos = end continue } @@ -293,15 +289,11 @@ func replaceFieldImpl(src, field string, unquote bool, fn func(indent, val, comm replacement = prefix + repl } - if _, err := b.WriteString(replacement); err != nil { - return src - } + b.WriteString(replacement) //nolint:errcheck // WriteString never returns an error pos = end } - if _, err := b.WriteString(src[pos:]); err != nil { - return src - } + b.WriteString(src[pos:]) //nolint:errcheck // WriteString never returns an error return b.String() }