diff --git a/cmd/internal/migrations/lists.go b/cmd/internal/migrations/lists.go index 500e97e..dc190d9 100644 --- a/cmd/internal/migrations/lists.go +++ b/cmd/internal/migrations/lists.go @@ -57,6 +57,7 @@ var Migrations = []Migration{ v3migrations.MigrateEnvVarConfig, v3migrations.MigrateSessionConfig, v3migrations.MigrateSessionExtractor, + v3migrations.MigrateKeyAuthConfig, v3migrations.MigrateTimeoutConfig, v3migrations.MigrateBasicauthAuthorizer, v3migrations.MigrateBasicauthConfig, diff --git a/cmd/internal/migrations/v3/common.go b/cmd/internal/migrations/v3/common.go index ffb217d..4aa0813 100644 --- a/cmd/internal/migrations/v3/common.go +++ b/cmd/internal/migrations/v3/common.go @@ -1004,3 +1004,87 @@ func MigrateSessionExtractor(cmd *cobra.Command, cwd string, _, _ *semver.Versio cmd.Println("Migrating session KeyLookup config") return nil } + +// MigrateKeyAuthConfig updates keyauth middleware configuration to use Extractor +// instead of KeyLookup/AuthScheme and removes the deprecated fields. +func MigrateKeyAuthConfig(cmd *cobra.Command, cwd string, _, _ *semver.Version) error { + reConfig := regexp.MustCompile(`keyauth\.Config{[^}]*}`) + reKeyLookup := regexp.MustCompile(`(?m)(\s*)KeyLookup:\s*("[^"]+")(,?)(\n?)`) + reAuthScheme := regexp.MustCompile(`(?m)\s*AuthScheme:\s*([^,\n]+)`) + + err := internal.ChangeFileContent(cwd, func(content string) string { + return reConfig.ReplaceAllStringFunc(content, func(cfg string) string { + keyMatch := reKeyLookup.FindStringSubmatch(cfg) + if len(keyMatch) < 5 { + // remove AuthScheme if present + return removeConfigField(cfg, "AuthScheme") + } + + indent := keyMatch[1] + val := strings.TrimSpace(keyMatch[2]) + comma := keyMatch[3] + newline := keyMatch[4] + + if uq, err := strconv.Unquote(val); err == nil { + val = uq + } + + scheme := "Bearer" + if am := reAuthScheme.FindStringSubmatch(cfg); len(am) > 1 { + scheme = strings.TrimSpace(am[1]) + if uq, err := strconv.Unquote(scheme); err == nil { + scheme = uq + } + } + + parts := strings.Split(val, ",") + var extractors []string + for _, p := range parts { + p = strings.TrimSpace(p) + switch { + case strings.HasPrefix(p, "header:"): + header := strings.TrimPrefix(p, "header:") + if strings.EqualFold(header, "Authorization") { + extractors = append(extractors, fmt.Sprintf("keyauth.FromAuthHeader(%q, %q)", header, scheme)) + } else { + extractors = append(extractors, fmt.Sprintf("keyauth.FromHeader(%q)", header)) + } + case strings.HasPrefix(p, "query:"): + extractors = append(extractors, fmt.Sprintf("keyauth.FromQuery(%q)", strings.TrimPrefix(p, "query:"))) + case strings.HasPrefix(p, "param:"): + extractors = append(extractors, fmt.Sprintf("keyauth.FromParam(%q)", strings.TrimPrefix(p, "param:"))) + case strings.HasPrefix(p, "form:"): + extractors = append(extractors, fmt.Sprintf("keyauth.FromForm(%q)", strings.TrimPrefix(p, "form:"))) + case strings.HasPrefix(p, "cookie:"): + extractors = append(extractors, fmt.Sprintf("keyauth.FromCookie(%q)", strings.TrimPrefix(p, "cookie:"))) + default: + // unrecognized source; remove field and return + cfg = removeConfigField(cfg, "AuthScheme") + return removeConfigField(cfg, "KeyLookup") + } + } + + extractor := "" + if len(extractors) == 1 { + extractor = extractors[0] + } else if len(extractors) > 1 { + extractor = fmt.Sprintf("keyauth.Chain(%s)", strings.Join(extractors, ", ")) + } + + cfg = removeConfigField(cfg, "AuthScheme") + if extractor == "" { + return removeConfigField(cfg, "KeyLookup") + } + + newField := fmt.Sprintf("%sExtractor: %s%s%s", indent, extractor, comma, newline) + cfg = reKeyLookup.ReplaceAllString(cfg, newField) + return cfg + }) + }) + if err != nil { + return fmt.Errorf("failed to migrate keyauth configs: %w", err) + } + + cmd.Println("Migrating keyauth middleware configs") + return nil +} diff --git a/cmd/internal/migrations/v3/common_test.go b/cmd/internal/migrations/v3/common_test.go index 55e1e06..39fc6ad 100644 --- a/cmd/internal/migrations/v3/common_test.go +++ b/cmd/internal/migrations/v3/common_test.go @@ -1300,3 +1300,99 @@ var _ = session.New(session.Config{ assert.NotContains(t, content, "KeyLookup") assert.Contains(t, buf.String(), "Migrating session KeyLookup config") } + +func Test_MigrateKeyAuthConfig_HeaderAuth(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mkeyauth_header") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import "github.com/gofiber/fiber/v2/middleware/keyauth" +var _ = keyauth.New(keyauth.Config{ + KeyLookup: "header:Authorization", + AuthScheme: "Bearer", +})`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateKeyAuthConfig(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.NotContains(t, content, "KeyLookup") + assert.NotContains(t, content, "AuthScheme") + assert.Contains(t, content, `Extractor: keyauth.FromAuthHeader("Authorization", "Bearer")`) + assert.Contains(t, buf.String(), "Migrating keyauth middleware configs") +} + +func Test_MigrateKeyAuthConfig_Cookie(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mkeyauth_cookie") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import "github.com/gofiber/fiber/v2/middleware/keyauth" +var _ = keyauth.New(keyauth.Config{ + KeyLookup: "cookie:token", +})`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateKeyAuthConfig(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.NotContains(t, content, "KeyLookup") + assert.Contains(t, content, `Extractor: keyauth.FromCookie("token")`) + assert.Contains(t, buf.String(), "Migrating keyauth middleware configs") +} + +func Test_MigrateKeyAuthConfig_Chain(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mkeyauth_chain") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import "github.com/gofiber/fiber/v2/middleware/keyauth" +var _ = keyauth.New(keyauth.Config{ + KeyLookup: "query:token,header:X-API-Key", +})`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateKeyAuthConfig(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.NotContains(t, content, "KeyLookup") + assert.Contains(t, content, `Extractor: keyauth.Chain(keyauth.FromQuery("token"), keyauth.FromHeader("X-API-Key"))`) + assert.Contains(t, buf.String(), "Migrating keyauth middleware configs") +} + +func Test_MigrateKeyAuthConfig_Unknown(t *testing.T) { + t.Parallel() + + dir, err := os.MkdirTemp("", "mkeyauth_unknown") + require.NoError(t, err) + defer func() { require.NoError(t, os.RemoveAll(dir)) }() + + file := writeTempFile(t, dir, `package main +import "github.com/gofiber/fiber/v2/middleware/keyauth" +var _ = keyauth.New(keyauth.Config{ + KeyLookup: "unknown:token", + AuthScheme: "Bearer", +})`) + + var buf bytes.Buffer + cmd := newCmd(&buf) + require.NoError(t, v3.MigrateKeyAuthConfig(cmd, dir, nil, nil)) + + content := readFile(t, file) + assert.NotContains(t, content, "KeyLookup") + assert.NotContains(t, content, "AuthScheme") + assert.NotContains(t, content, "Extractor") + assert.Contains(t, buf.String(), "Migrating keyauth middleware configs") +}