Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/internal/migrations/lists.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ var Migrations = []Migration{
v3migrations.MigrateCORSConfig,
v3migrations.MigrateCSRFConfig,
v3migrations.MigrateMonitorImport,
v3migrations.MigrateSwaggerPackages,
v3migrations.MigrateContribPackages,
v3migrations.MigrateUtilsImport,
v3migrations.MigrateHealthcheckConfig,
Expand Down
107 changes: 107 additions & 0 deletions cmd/internal/migrations/v3/contrib_versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package v3

import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"

"golang.org/x/sync/singleflight"
)

const contribV3ProxyPrefix = "https://proxy.golang.org/github.com/gofiber/contrib/v3/"

var (
contribV3VersionMu sync.Mutex
contribV3VersionCache = make(map[string]string)
contribV3VersionFetcher = fetchContribV3Version
contribV3VersionGroup singleflight.Group
contribHTTPClient = &http.Client{}
)

func contribV3Version(module string) (string, error) {
contribV3VersionMu.Lock()
if v, ok := contribV3VersionCache[module]; ok {
contribV3VersionMu.Unlock()
return v, nil
}
fetcher := contribV3VersionFetcher
contribV3VersionMu.Unlock()

res, err, _ := contribV3VersionGroup.Do(module, func() (any, error) {
v, fetchErr := fetcher(module)
if fetchErr != nil {
return "", fetchErr
}

contribV3VersionMu.Lock()
contribV3VersionCache[module] = v
contribV3VersionMu.Unlock()
return v, nil
})
if err != nil {
return "", fmt.Errorf("fetch contrib version: %w", err)
}

v, ok := res.(string)
if !ok {
return "", fmt.Errorf("unexpected contrib version type %T", res)
}

return v, nil
}

func fetchContribV3Version(module string) (version string, err error) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

url := contribV3ProxyPrefix + module + "/@latest"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return "", fmt.Errorf("create request: %w", err)
}

res, err := contribHTTPClient.Do(req)
if err != nil {
return "", fmt.Errorf("fetch latest version: %w", err)
}
defer func() {
if cerr := res.Body.Close(); cerr != nil && err == nil {
err = cerr
}
}()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if res.StatusCode != http.StatusOK {
return "", fmt.Errorf("fetch latest version: unexpected status %d", res.StatusCode)
}

var data struct {
Version string `json:"Version"` //nolint:tagliatelle // field name defined by proxy
}
if err := json.NewDecoder(res.Body).Decode(&data); err != nil {
return "", fmt.Errorf("parse latest version: %w", err)
}
if data.Version == "" {
return "", fmt.Errorf("latest version not found for %s", module)
}

return data.Version, nil
}

// SetContribV3VersionFetcher overrides the function used to fetch contrib module versions.
// It resets the cached versions and returns a restore function to revert the change.
func SetContribV3VersionFetcher(fn func(string) (string, error)) func() {
contribV3VersionMu.Lock()
prev := contribV3VersionFetcher
contribV3VersionFetcher = fn
contribV3VersionCache = make(map[string]string)
contribV3VersionMu.Unlock()
return func() {
contribV3VersionMu.Lock()
contribV3VersionFetcher = prev
contribV3VersionCache = make(map[string]string)
contribV3VersionMu.Unlock()
}
}
197 changes: 197 additions & 0 deletions cmd/internal/migrations/v3/swagger_packages.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package v3

import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io/fs"
"os"
"path/filepath"
"regexp"
"strings"

semver "github.com/Masterminds/semver/v3"
"github.com/spf13/cobra"

"github.com/gofiber/cli/cmd/internal"
)

const (
contribSwaggerOld = "github.com/gofiber/contrib/swagger"
contribSwaggerNew = "github.com/gofiber/contrib/v3/swaggo"
fiberSwaggerOld = "github.com/gofiber/swagger"
fiberSwaggerNew = "github.com/gofiber/contrib/v3/swaggerui"
goModVersionPattern = `v[a-zA-Z0-9.+-]+`
)

func MigrateSwaggerPackages(cmd *cobra.Command, cwd string, _, _ *semver.Version) error {
changedImports, err := internal.ChangeFileContent(cwd, func(content string) string {
updated, changed := rewriteSwaggerImports(content)
if changed {
return updated
}
return content
})
if err != nil {
return fmt.Errorf("failed to migrate swagger imports: %w", err)
}

modChanged, err := migrateSwaggerModules(cwd)
if err != nil {
return err
}

if !changedImports && !modChanged {
return nil
}

cmd.Println("Migrating swagger packages")
return nil
}

func migrateSwaggerModules(cwd string) (bool, error) {
modChanged := false
var swaggoVersion, swaggerUIVersion string

walkErr := filepath.WalkDir(cwd, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if d.IsDir() {
if d.Name() == "vendor" {
return filepath.SkipDir
}
return nil
}
if d.Name() != "go.mod" {
return nil
}

info, err := d.Info()
if err != nil {
return fmt.Errorf("stat %s: %w", path, err)
}

b, err := os.ReadFile(path) // #nosec G304 -- reading module files
if err != nil {
return fmt.Errorf("read %s: %w", path, err)
}
content := string(b)

needsSwaggo := strings.Contains(content, contribSwaggerOld) || strings.Contains(content, contribSwaggerNew)
needsSwaggerUI := strings.Contains(content, fiberSwaggerOld) || strings.Contains(content, fiberSwaggerNew)
if !needsSwaggo && !needsSwaggerUI {
return nil
}

if needsSwaggo && swaggoVersion == "" {
swaggoVersion, err = contribV3Version("swaggo")
if err != nil {
return fmt.Errorf("fetch swaggo version: %w", err)
}
}
if needsSwaggerUI && swaggerUIVersion == "" {
swaggerUIVersion, err = contribV3Version("swaggerui")
if err != nil {
return fmt.Errorf("fetch swaggerui version: %w", err)
}
}

updated := content
if needsSwaggo {
updated = updateGoModModule(updated, contribSwaggerOld, contribSwaggerNew, swaggoVersion)
}
if needsSwaggerUI {
updated = updateGoModModule(updated, fiberSwaggerOld, fiberSwaggerNew, swaggerUIVersion)
}

if updated == content {
return nil
}

if err := os.WriteFile(path, []byte(updated), info.Mode().Perm()); err != nil {
return fmt.Errorf("write %s: %w", path, err)
}
modChanged = true
return nil
})
if walkErr != nil {
return false, fmt.Errorf("failed to migrate swagger go.mod entries: %w", walkErr)
}

return modChanged, nil
}

func rewriteSwaggerImports(content string) (string, bool) {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, "", content, parser.ParseComments)
if err != nil {
return content, false
}

changed := false
for _, imp := range f.Imports {
path := strings.Trim(imp.Path.Value, "\"`")
var (
newPath string
wasMigrated bool
)

switch path {
case contribSwaggerOld:
newPath = contribSwaggerNew
wasMigrated = true
case fiberSwaggerOld:
newPath = fiberSwaggerNew
wasMigrated = true
case contribSwaggerNew, fiberSwaggerNew:
newPath = path
default:
continue
}

if path != newPath {
imp.Path.Value = fmt.Sprintf("%q", newPath)
changed = true
}

if wasMigrated && (imp.Name == nil || imp.Name.Name == "") {
imp.Name = ast.NewIdent("swagger")
changed = true
}
}

if !changed {
return content, false
}

var buf bytes.Buffer
if err := format.Node(&buf, fset, f); err != nil {
return content, false
}

return buf.String(), true
}

func updateGoModModule(content, oldPath, newPath, version string) string {
if version == "" {
return content
}

reRequireOld := regexp.MustCompile(fmt.Sprintf(`(?m)^(\s*(?:require\s+)?)%s\s+%s`, regexp.QuoteMeta(oldPath), goModVersionPattern))
content = reRequireOld.ReplaceAllString(content, fmt.Sprintf(`${1}%s %s`, newPath, version))

reRequireNew := regexp.MustCompile(fmt.Sprintf(`(?m)^(\s*(?:require\s+)?)%s\s+%s`, regexp.QuoteMeta(newPath), goModVersionPattern))
content = reRequireNew.ReplaceAllString(content, fmt.Sprintf(`${1}%s %s`, newPath, version))

reReplaceOld := regexp.MustCompile(fmt.Sprintf(`(?m)^(\s*replace\s+)%s(\s+%s)?(\s+=>\s+)`, regexp.QuoteMeta(oldPath), goModVersionPattern))
content = reReplaceOld.ReplaceAllString(content, fmt.Sprintf(`${1}%s${2}${3}`, newPath))

reReplaceNew := regexp.MustCompile(fmt.Sprintf(`(?m)^(\s*replace\s+)%s(\s+%s)?(\s+=>\s+)`, regexp.QuoteMeta(newPath), goModVersionPattern))
content = reReplaceNew.ReplaceAllString(content, fmt.Sprintf(`${1}%s${2}${3}`, newPath))

return content
}
Loading
Loading