From d76b4bf05efbb79420103df00ce7d24231f54db5 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Sat, 3 Jan 2026 01:42:55 -0500 Subject: [PATCH 1/4] Refine cache control parsing with utils.EqualFold --- middleware/cache/cache.go | 371 +++++++++++++++++--------- middleware/cache/cache_test.go | 4 +- middleware/cache/manager.go | 7 +- middleware/cache/manager_msgp.go | 277 +++++++++++++++---- middleware/cache/manager_msgp_test.go | 113 ++++++++ 5 files changed, 602 insertions(+), 170 deletions(-) diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index 9c25fd35770..809c50ca859 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -9,10 +9,8 @@ import ( "errors" "fmt" "math" - "net/http" "slices" "sort" - "strconv" "strings" "sync" "sync/atomic" @@ -74,7 +72,14 @@ type requestCacheDirectives struct { } var ignoreHeaders = map[string]struct{}{ + "Age": {}, + "Cache-Control": {}, // already stored explicitly by the cache manager "Connection": {}, + "Content-Encoding": {}, // already stored explicitly by the cache manager + "Content-Type": {}, // already stored explicitly by the cache manager + "Date": {}, + "ETag": {}, // already stored explicitly by the cache manager + "Expires": {}, // already stored explicitly by the cache manager "Keep-Alive": {}, "Proxy-Authenticate": {}, "Proxy-Authorization": {}, @@ -82,8 +87,6 @@ var ignoreHeaders = map[string]struct{}{ "Trailers": {}, "Transfer-Encoding": {}, "Upgrade": {}, - "Content-Type": {}, // already stored explicitly by the cache manager - "Content-Encoding": {}, // already stored explicitly by the cache manager } var cacheableStatusCodes = map[int]struct{}{ @@ -192,7 +195,7 @@ func New(config ...Config) fiber.Handler { // Return new handler return func(c fiber.Ctx) error { hasAuthorization := len(c.Request().Header.Peek(fiber.HeaderAuthorization)) > 0 - reqCacheControl := utils.UnsafeString(c.Request().Header.Peek(fiber.HeaderCacheControl)) + reqCacheControl := c.Request().Header.Peek(fiber.HeaderCacheControl) reqDirectives := parseRequestCacheControl(reqCacheControl) if !reqDirectives.noCache { reqPragma := utils.UnsafeString(c.Request().Header.Peek(fiber.HeaderPragma)) @@ -258,6 +261,19 @@ func New(config ...Config) fiber.Handler { // Lock entry mux.Lock() + locked := true + unlock := func() { + if locked { + mux.Unlock() + locked = false + } + } + relock := func() { + if !locked { + mux.Lock() + locked = true + } + } // Get timestamp ts := atomic.LoadUint64(×tamp) @@ -284,19 +300,22 @@ func New(config ...Config) fiber.Handler { } if e != nil && e.ttl == 0 && e.exp != 0 && ts >= e.exp { + unlock() if err := deleteKey(reqCtx, key); err != nil { if cfg.Storage != nil { manager.release(e) } - mux.Unlock() + relock() + unlock() return fmt.Errorf("cache: failed to delete expired key %q: %w", maskKey(key), err) } + relock() removeHeapEntry(key, e.heapidx) if cfg.Storage != nil { manager.release(e) } e = nil - mux.Unlock() + unlock() c.Set(cfg.CacheHeader, cacheUnreachable) goto continueRequest } @@ -304,7 +323,7 @@ func New(config ...Config) fiber.Handler { if e != nil { entryHasPrivate := e != nil && e.private if !entryHasPrivate && cfg.StoreResponseHeaders && len(e.headers) > 0 { - if cc, ok := e.headers[fiber.HeaderCacheControl]; ok && hasDirective(utils.UnsafeString(cc), privateDirective) { + if cc, ok := lookupCachedHeader(e.headers, fiber.HeaderCacheControl); ok && hasDirective(utils.UnsafeString(cc), privateDirective) { entryHasPrivate = true } } @@ -334,7 +353,7 @@ func New(config ...Config) fiber.Handler { handleMinFresh(ts) if revalidate { - mux.Unlock() + unlock() c.Set(cfg.CacheHeader, cacheUnreachable) if reqDirectives.onlyIfCached { return c.SendStatus(fiber.StatusGatewayTimeout) @@ -346,31 +365,37 @@ func New(config ...Config) fiber.Handler { switch { case entryExpired && !allowStale: + unlock() if err := deleteKey(reqCtx, key); err != nil { if e != nil { manager.release(e) } - mux.Unlock() + relock() + unlock() return fmt.Errorf("cache: failed to delete expired key %q: %w", maskKey(key), err) } + relock() idx := e.heapidx manager.release(e) removeHeapEntry(key, idx) e = nil case entryHasPrivate: + unlock() if err := deleteKey(reqCtx, key); err != nil { if e != nil { manager.release(e) } - mux.Unlock() + relock() + unlock() return fmt.Errorf("cache: failed to delete private response for key %q: %w", maskKey(key), err) } + relock() removeHeapEntry(key, e.heapidx) if cfg.Storage != nil && e != nil { manager.release(e) } e = nil - mux.Unlock() + unlock() c.Set(cfg.CacheHeader, cacheUnreachable) if reqDirectives.onlyIfCached { return c.SendStatus(fiber.StatusGatewayTimeout) @@ -382,7 +407,7 @@ func New(config ...Config) fiber.Handler { if cfg.Storage != nil { manager.release(e) } - mux.Unlock() + unlock() c.Set(cfg.CacheHeader, cacheUnreachable) return c.Next() } @@ -390,13 +415,15 @@ func New(config ...Config) fiber.Handler { // Separate body value to avoid msgp serialization // We can store raw bytes with Storage 👍 if cfg.Storage != nil { + unlock() rawBody, err := manager.getRaw(reqCtx, key+"_body") if err != nil { manager.release(e) - mux.Unlock() return cacheBodyFetchError(maskKey, key, err) } e.body = rawBody + } else { + unlock() } // Set response headers from cache c.Response().SetBodyRaw(e.body) @@ -415,10 +442,11 @@ func New(config ...Config) fiber.Handler { c.Response().Header.SetBytesV(fiber.HeaderETag, e.etag) } e.date = clampDateSeconds(e.date, ts) - dateStr := secondsToTime(e.date).Format(http.TimeFormat) - c.Response().Header.Set(fiber.HeaderDate, dateStr) - for k, v := range e.headers { - c.Response().Header.SetBytesV(k, v) + dateValue := fasthttp.AppendHTTPDate(nil, secondsToTime(e.date)) + c.Response().Header.SetBytesV(fiber.HeaderDate, dateValue) + for i := range e.headers { + h := e.headers[i] + c.Response().Header.SetBytesKV(h.key, h.value) } // Set Cache-Control header if not disabled and not already set if !cfg.DisableCacheControl && len(c.Response().Header.Peek(fiber.HeaderCacheControl)) == 0 { @@ -445,8 +473,6 @@ func New(config ...Config) fiber.Handler { manager.release(e) } - mux.Unlock() - // Return response return nil default: @@ -455,7 +481,7 @@ func New(config ...Config) fiber.Handler { } if e == nil && revalidate { - mux.Unlock() + unlock() c.Set(cfg.CacheHeader, cacheUnreachable) if reqDirectives.onlyIfCached { return c.SendStatus(fiber.StatusGatewayTimeout) @@ -464,13 +490,13 @@ func New(config ...Config) fiber.Handler { } if e == nil && reqDirectives.onlyIfCached { - mux.Unlock() + unlock() c.Set(cfg.CacheHeader, cacheUnreachable) return c.SendStatus(fiber.StatusGatewayTimeout) } // make sure we're not blocking concurrent requests - do unlock - mux.Unlock() + unlock() continueRequest: // Continue stack, return err to Fiber if exist @@ -478,28 +504,28 @@ func New(config ...Config) fiber.Handler { return err } - cacheControl := utils.UnsafeString(c.Response().Header.Peek(fiber.HeaderCacheControl)) + cacheControlBytes := c.Response().Header.Peek(fiber.HeaderCacheControl) + respCacheControl := parseResponseCacheControl(cacheControlBytes) varyHeader := utils.UnsafeString(c.Response().Header.Peek(fiber.HeaderVary)) - hasPrivate := hasDirective(cacheControl, privateDirective) - hasNoCache := hasDirective(cacheControl, noCache) + hasPrivate := respCacheControl.hasPrivate + hasNoCache := respCacheControl.hasNoCache varyNames, varyHasStar := parseVary(varyHeader) // Respect server cache-control: no-store - if hasDirective(cacheControl, noStore) { + if respCacheControl.hasNoStore { c.Set(cfg.CacheHeader, cacheUnreachable) return nil } if hasPrivate || hasNoCache || varyHasStar { if e != nil { - mux.Lock() if err := deleteKey(reqCtx, key); err != nil { if cfg.Storage != nil { manager.release(e) } - mux.Unlock() return fmt.Errorf("cache: failed to delete cached response for key %q: %w", maskKey(key), err) } + mux.Lock() removeHeapEntry(key, e.heapidx) if cfg.Storage != nil { manager.release(e) @@ -529,7 +555,7 @@ func New(config ...Config) fiber.Handler { } } - isSharedCacheAllowed := allowsSharedCache(cacheControl) + isSharedCacheAllowed := allowsSharedCacheDirectives(respCacheControl) if hasAuthorization && !isSharedCacheAllowed { c.Set(cfg.CacheHeader, cacheUnreachable) return nil @@ -543,10 +569,6 @@ func New(config ...Config) fiber.Handler { return nil } - // lock entry back and unlock on finish - mux.Lock() - defer mux.Unlock() - // Don't cache response if Next returns true if cfg.Next != nil && cfg.Next(c) { c.Set(cfg.CacheHeader, cacheUnreachable) @@ -560,14 +582,21 @@ func New(config ...Config) fiber.Handler { return nil } - // Remove oldest to make room for new + // Remove oldest to make room for new without holding the lock during storage I/O. if cfg.MaxBytes > 0 { - for storedBytes+bodySize > cfg.MaxBytes { + for { + mux.Lock() + if storedBytes+bodySize <= cfg.MaxBytes { + mux.Unlock() + break + } keyToRemove, size := heap.removeFirst() + storedBytes -= size + mux.Unlock() + if err := deleteKey(reqCtx, keyToRemove); err != nil { return fmt.Errorf("cache: failed to delete key %q while evicting: %w", maskKey(keyToRemove), err) } - storedBytes -= size } } @@ -578,7 +607,7 @@ func New(config ...Config) fiber.Handler { e.ctype = utils.CopyBytes(c.Response().Header.ContentType()) e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding)) e.private = false - e.cacheControl = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderCacheControl)) + e.cacheControl = utils.CopyBytes(cacheControlBytes) e.expires = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderExpires)) e.etag = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderETag)) e.date = 0 @@ -600,36 +629,39 @@ func New(config ...Config) fiber.Handler { dateHeader := c.Response().Header.Peek(fiber.HeaderDate) parsedDate, _ := parseHTTPDate(dateHeader) e.date = clampDateSeconds(parsedDate, nowUnix) - dateStr := secondsToTime(e.date).Format(http.TimeFormat) - c.Response().Header.Set(fiber.HeaderDate, dateStr) + dateBytes := fasthttp.AppendHTTPDate(nil, secondsToTime(e.date)) + c.Response().Header.SetBytesV(fiber.HeaderDate, dateBytes) // Store all response headers // (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1) if cfg.StoreResponseHeaders { - e.headers = make(map[string][]byte) - for key, value := range c.Response().Header.All() { - // create real copy - keyS := string(key) - if _, ok := ignoreHeaders[keyS]; !ok { - e.headers[keyS] = utils.CopyBytes(value) + allHeaders := c.Response().Header.All() + e.headers = e.headers[:0] + for key, value := range allHeaders { + keyStr := string(key) + if _, ok := ignoreHeaders[keyStr]; ok { + continue } + + e.headers = append(e.headers, cachedHeader{ + key: utils.CopyBytes(utils.UnsafeBytes(keyStr)), + value: utils.CopyBytes(value), + }) } } expirationSource := expirationSourceConfig expiresParseError := false - mustRevalidate := false + mustRevalidate := respCacheControl.mustRevalidate || respCacheControl.proxyRevalidate // default cache expiration expiration := cfg.Expiration - if sharedCacheMode { - if v, ok := parseSMaxAge(cacheControl); ok { - expiration = v - expirationSource = expirationSourceSMaxAge - } + if sharedCacheMode && respCacheControl.sMaxAgeSet { + expiration = secondsToDuration(respCacheControl.sMaxAge) + expirationSource = expirationSourceSMaxAge } if expirationSource == expirationSourceConfig { - if v, ok := parseMaxAge(cacheControl); ok { - expiration = v + if respCacheControl.maxAgeSet { + expiration = secondsToDuration(respCacheControl.maxAge) expirationSource = expirationSourceMaxAge } else if expiresBytes := c.Response().Header.Peek(fiber.HeaderExpires); len(expiresBytes) > 0 { expiresAt, err := fasthttp.ParseHTTPDate(expiresBytes) @@ -642,7 +674,6 @@ func New(config ...Config) fiber.Handler { expirationSource = expirationSourceExpires } } - mustRevalidate = hasDirective(cacheControl, "must-revalidate") || hasDirective(cacheControl, "proxy-revalidate") // Calculate expiration by response header or other setting if cfg.ExpirationGenerator != nil { expiration = cfg.ExpirationGenerator(c, &cfg) @@ -661,6 +692,7 @@ func New(config ...Config) fiber.Handler { return nil } + ts = atomic.LoadUint64(×tamp) responseTS := max(ts, nowUnix) maxAgeSeconds := uint64(time.Duration(math.MaxInt64) / time.Second) @@ -703,16 +735,20 @@ func New(config ...Config) fiber.Handler { // Store entry in heap var heapIdx int if cfg.MaxBytes > 0 { + mux.Lock() heapIdx = heap.put(key, e.exp, bodySize) e.heapidx = heapIdx storedBytes += bodySize + mux.Unlock() } cleanupOnStoreError := func(ctx context.Context, releaseEntry, rawStored bool) error { var cleanupErr error if cfg.MaxBytes > 0 { + mux.Lock() _, size := heap.remove(heapIdx) storedBytes -= size + mux.Unlock() } if releaseEntry { manager.release(e) @@ -788,72 +824,164 @@ func cacheBodyFetchError(mask func(string) string, key string, err error) error return err } -// parseMaxAge extracts the max-age directive from a Cache-Control header. -func parseMaxAge(cc string) (time.Duration, bool) { - for part := range strings.SplitSeq(cc, ",") { - part = utils.TrimSpace(utils.ToLower(part)) - if after, ok := strings.CutPrefix(part, "max-age="); ok { - if sec, err := strconv.Atoi(after); err == nil { - return time.Duration(sec) * time.Second, true +func parseUintDirective(val []byte) (uint64, bool) { + if len(val) == 0 { + return 0, false + } + parsed, err := fasthttp.ParseUint(val) + if err != nil || parsed < 0 { + return 0, false + } + return uint64(parsed), true +} + +func parseCacheControlDirectives(cc []byte, fn func(key, value []byte)) { + for i := 0; i < len(cc); { + // skip leading separators/spaces + for i < len(cc) && (cc[i] == ' ' || cc[i] == ',') { + i++ + } + if i >= len(cc) { + break + } + + start := i + for i < len(cc) && cc[i] != ',' { + i++ + } + partEnd := i + for partEnd > start && cc[partEnd-1] == ' ' { + partEnd-- + } + + keyStart := start + for keyStart < partEnd && cc[keyStart] == ' ' { + keyStart++ + } + if keyStart >= partEnd { + continue + } + + keyEnd := keyStart + for keyEnd < partEnd && cc[keyEnd] != '=' { + keyEnd++ + } + key := cc[keyStart:keyEnd] + + var value []byte + if keyEnd < partEnd && cc[keyEnd] == '=' { + valueStart := keyEnd + 1 + for valueStart < partEnd && cc[valueStart] == ' ' { + valueStart++ + } + valueEnd := partEnd + for valueEnd > valueStart && cc[valueEnd-1] == ' ' { + valueEnd-- + } + if valueStart <= valueEnd { + value = cc[valueStart:valueEnd] } } + + fn(key, value) + i++ // skip comma } - return 0, false } -func parseSMaxAge(cc string) (time.Duration, bool) { - for part := range strings.SplitSeq(cc, ",") { - part = utils.TrimSpace(utils.ToLower(part)) - if after, ok := strings.CutPrefix(part, "s-maxage="); ok { - if sec, err := strconv.Atoi(after); err == nil { - return time.Duration(sec) * time.Second, true +type responseCacheControl struct { + maxAge uint64 + sMaxAge uint64 + maxAgeSet bool + sMaxAgeSet bool + hasNoCache bool + hasNoStore bool + hasPrivate bool + hasPublic bool + mustRevalidate bool + proxyRevalidate bool +} + +func parseResponseCacheControl(cc []byte) responseCacheControl { + parsed := responseCacheControl{} + parseCacheControlDirectives(cc, func(key, value []byte) { + switch { + case utils.EqualFold(utils.UnsafeString(key), noStore): + parsed.hasNoStore = true + case utils.EqualFold(utils.UnsafeString(key), noCache): + parsed.hasNoCache = true + case utils.EqualFold(utils.UnsafeString(key), privateDirective): + parsed.hasPrivate = true + case utils.EqualFold(utils.UnsafeString(key), "public"): + parsed.hasPublic = true + case utils.EqualFold(utils.UnsafeString(key), "max-age"): + if v, ok := parseUintDirective(value); ok { + parsed.maxAgeSet = true + parsed.maxAge = v + } + case utils.EqualFold(utils.UnsafeString(key), "s-maxage"): + if v, ok := parseUintDirective(value); ok { + parsed.sMaxAgeSet = true + parsed.sMaxAge = v } + case utils.EqualFold(utils.UnsafeString(key), "must-revalidate"): + parsed.mustRevalidate = true + case utils.EqualFold(utils.UnsafeString(key), "proxy-revalidate"): + parsed.proxyRevalidate = true + default: + // ignore unknown directives } - } + }) + return parsed +} - return 0, false +// parseMaxAge extracts the max-age directive from a Cache-Control header. +func parseMaxAge(cc string) (time.Duration, bool) { + parsed := parseResponseCacheControl(utils.UnsafeBytes(cc)) + if !parsed.maxAgeSet { + return 0, false + } + return secondsToDuration(parsed.maxAge), true } -func parseRequestCacheControl(cc string) requestCacheDirectives { +func parseRequestCacheControl(cc []byte) requestCacheDirectives { directives := requestCacheDirectives{} - - for part := range strings.SplitSeq(cc, ",") { - part = utils.TrimSpace(utils.ToLower(part)) + parseCacheControlDirectives(cc, func(key, value []byte) { switch { - case part == "": - continue - case part == noStore: + case utils.EqualFold(utils.UnsafeString(key), noStore): directives.noStore = true - case part == noCache: + case utils.EqualFold(utils.UnsafeString(key), noCache): directives.noCache = true - case part == "only-if-cached": + case utils.EqualFold(utils.UnsafeString(key), "only-if-cached"): directives.onlyIfCached = true - case strings.HasPrefix(part, "max-age="): - if sec, err := strconv.Atoi(strings.TrimPrefix(part, "max-age=")); err == nil && sec >= 0 { + case utils.EqualFold(utils.UnsafeString(key), "max-age"): + if sec, ok := parseUintDirective(value); ok { directives.maxAgeSet = true - directives.maxAge = uint64(sec) + directives.maxAge = sec } - case part == "max-stale": + case utils.EqualFold(utils.UnsafeString(key), "max-stale"): directives.maxStaleSet = true - directives.maxStaleAny = true - case strings.HasPrefix(part, "max-stale="): - if sec, err := strconv.Atoi(strings.TrimPrefix(part, "max-stale=")); err == nil && sec >= 0 { - directives.maxStaleSet = true - directives.maxStale = uint64(sec) + directives.maxStaleAny = len(value) == 0 + if !directives.maxStaleAny { + if sec, ok := parseUintDirective(value); ok { + directives.maxStale = sec + } } - case strings.HasPrefix(part, "min-fresh="): - if sec, err := strconv.Atoi(strings.TrimPrefix(part, "min-fresh=")); err == nil && sec >= 0 { + case utils.EqualFold(utils.UnsafeString(key), "min-fresh"): + if sec, ok := parseUintDirective(value); ok { directives.minFreshSet = true - directives.minFresh = uint64(sec) + directives.minFresh = sec } default: - continue + // ignore unknown directives } - } - + }) return directives } +func parseRequestCacheControlString(cc string) requestCacheDirectives { + return parseRequestCacheControl(utils.UnsafeBytes(cc)) +} + func cachedResponseAge(e *item, now uint64) uint64 { clampedDate := clampDateSeconds(e.date, now) @@ -910,8 +1038,20 @@ func isHeuristicFreshness(e *item, cfg *Config, entryAge uint64) bool { return cfg.Expiration > 0 } +func lookupCachedHeader(headers []cachedHeader, name string) ([]byte, bool) { + for i := range headers { + if utils.EqualFold(utils.UnsafeString(headers[i].key), name) { + return headers[i].value, true + } + } + return nil, false +} + func parseHTTPDate(dateBytes []byte) (uint64, bool) { - parsedDate, err := http.ParseTime(utils.UnsafeString(dateBytes)) + if len(dateBytes) == 0 { + return 0, false + } + parsedDate, err := fasthttp.ParseHTTPDate(dateBytes) if err != nil { return 0, false } @@ -948,6 +1088,14 @@ func secondsToTime(sec uint64) time.Time { return time.Unix(clamped, 0).UTC() } +func secondsToDuration(sec uint64) time.Duration { + const maxSeconds = uint64(math.MaxInt64) / uint64(time.Second) + if sec > maxSeconds { + return time.Duration(math.MaxInt64) + } + return time.Duration(sec) * time.Second +} + func parseVary(vary string) ([]string, bool) { names := make([]string, 0, 8) for part := range strings.SplitSeq(vary, ",") { @@ -1031,30 +1179,11 @@ func loadVaryManifest(ctx context.Context, manager *manager, manifestKey string) return names, len(names) > 0, nil } -func allowsSharedCache(cc string) bool { - shareable := false - - for part := range strings.SplitSeq(cc, ",") { - part = utils.TrimSpace(utils.ToLower(part)) - switch { - case part == "": - continue - case part == "private": - return false - case part == "public": - shareable = true - case strings.HasPrefix(part, "s-maxage="): - shareable = true - case part == "must-revalidate": - shareable = true - case part == "proxy-revalidate": - shareable = true - default: - continue - } +func allowsSharedCacheDirectives(cc responseCacheControl) bool { + if cc.hasPrivate { + return false } - - if shareable { + if cc.hasPublic || cc.sMaxAgeSet || cc.mustRevalidate || cc.proxyRevalidate { return true } @@ -1064,6 +1193,10 @@ func allowsSharedCache(cc string) bool { return false } +func allowsSharedCache(cc string) bool { + return allowsSharedCacheDirectives(parseResponseCacheControl(utils.UnsafeBytes(cc))) +} + func makeHashAuthFunc(hexBufPool *sync.Pool) func([]byte) string { return func(authHeader []byte) string { sum := sha256.Sum256(authHeader) diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 65e977e230f..e27cebfb728 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -1961,7 +1961,7 @@ func Test_CacheMaxStaleServesStaleResponse(t *testing.T) { req.Header.Set(fiber.HeaderCacheControl, "max-stale=5") resp, err = app.Test(req) require.NoError(t, err) - require.Equalf(t, cacheHit, resp.Header.Get("X-Cache"), "dirs=%+v Age=%s count=%d", parseRequestCacheControl("max-stale=5"), resp.Header.Get(fiber.HeaderAge), count) + require.Equalf(t, cacheHit, resp.Header.Get("X-Cache"), "dirs=%+v Age=%s count=%d", parseRequestCacheControlString("max-stale=5"), resp.Header.Get(fiber.HeaderAge), count) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "1", string(body)) @@ -2516,7 +2516,7 @@ func Test_CacheMinFreshForcesRevalidation(t *testing.T) { req.Header.Set(fiber.HeaderCacheControl, "min-fresh=10") resp, err = app.Test(req) require.NoError(t, err) - require.Equalf(t, cacheMiss, resp.Header.Get("X-Cache"), "dirs=%+v Age=%s count=%d", parseRequestCacheControl("min-fresh=10"), resp.Header.Get(fiber.HeaderAge), count) + require.Equalf(t, cacheMiss, resp.Header.Get("X-Cache"), "dirs=%+v Age=%s count=%d", parseRequestCacheControlString("min-fresh=10"), resp.Header.Get(fiber.HeaderAge), count) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "2", string(body)) diff --git a/middleware/cache/manager.go b/middleware/cache/manager.go index 3fa2acb855a..ad7a8fb5369 100644 --- a/middleware/cache/manager.go +++ b/middleware/cache/manager.go @@ -15,7 +15,7 @@ import ( // //go:generate msgp -o=manager_msgp.go -tests=true -unexported type item struct { - headers map[string][]byte + headers []cachedHeader body []byte ctype []byte cencoding []byte @@ -35,6 +35,11 @@ type item struct { heapidx int } +type cachedHeader struct { + key []byte + value []byte +} + //msgp:ignore manager type manager struct { pool sync.Pool diff --git a/middleware/cache/manager_msgp.go b/middleware/cache/manager_msgp.go index 5317020ebe0..6022bd8832c 100644 --- a/middleware/cache/manager_msgp.go +++ b/middleware/cache/manager_msgp.go @@ -6,6 +6,134 @@ import ( "github.com/tinylib/msgp/msgp" ) +// DecodeMsg implements msgp.Decodable +func (z *cachedHeader) DecodeMsg(dc *msgp.Reader) (err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, err = dc.ReadMapHeader() + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "key": + z.key, err = dc.ReadBytes(z.key) + if err != nil { + err = msgp.WrapError(err, "key") + return + } + case "value": + z.value, err = dc.ReadBytes(z.value) + if err != nil { + err = msgp.WrapError(err, "value") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + return +} + +// EncodeMsg implements msgp.Encodable +func (z *cachedHeader) EncodeMsg(en *msgp.Writer) (err error) { + // map header, size 2 + // write "key" + err = en.Append(0x82, 0xa3, 0x6b, 0x65, 0x79) + if err != nil { + return + } + err = en.WriteBytes(z.key) + if err != nil { + err = msgp.WrapError(err, "key") + return + } + // write "value" + err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65) + if err != nil { + return + } + err = en.WriteBytes(z.value) + if err != nil { + err = msgp.WrapError(err, "value") + return + } + return +} + +// MarshalMsg implements msgp.Marshaler +func (z *cachedHeader) MarshalMsg(b []byte) (o []byte, err error) { + o = msgp.Require(b, z.Msgsize()) + // map header, size 2 + // string "key" + o = append(o, 0x82, 0xa3, 0x6b, 0x65, 0x79) + o = msgp.AppendBytes(o, z.key) + // string "value" + o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65) + o = msgp.AppendBytes(o, z.value) + return +} + +// UnmarshalMsg implements msgp.Unmarshaler +func (z *cachedHeader) UnmarshalMsg(bts []byte) (o []byte, err error) { + var field []byte + _ = field + var zb0001 uint32 + zb0001, bts, err = msgp.ReadMapHeaderBytes(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + for zb0001 > 0 { + zb0001-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + switch msgp.UnsafeString(field) { + case "key": + z.key, bts, err = msgp.ReadBytesBytes(bts, z.key) + if err != nil { + err = msgp.WrapError(err, "key") + return + } + case "value": + z.value, bts, err = msgp.ReadBytesBytes(bts, z.value) + if err != nil { + err = msgp.WrapError(err, "value") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err) + return + } + } + } + o = bts + return +} + +// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message +func (z *cachedHeader) Msgsize() (s int) { + s = 1 + 4 + msgp.BytesPrefixSize + len(z.key) + 6 + msgp.BytesPrefixSize + len(z.value) + return +} + // DecodeMsg implements msgp.Decodable func (z *item) DecodeMsg(dc *msgp.Reader) (err error) { var field []byte @@ -26,31 +154,51 @@ func (z *item) DecodeMsg(dc *msgp.Reader) (err error) { switch msgp.UnsafeString(field) { case "headers": var zb0002 uint32 - zb0002, err = dc.ReadMapHeader() + zb0002, err = dc.ReadArrayHeader() if err != nil { err = msgp.WrapError(err, "headers") return } - if z.headers == nil { - z.headers = make(map[string][]byte, zb0002) - } else if len(z.headers) > 0 { - clear(z.headers) + if cap(z.headers) >= int(zb0002) { + z.headers = (z.headers)[:zb0002] + } else { + z.headers = make([]cachedHeader, zb0002) } - for zb0002 > 0 { - zb0002-- - var za0001 string - za0001, err = dc.ReadString() - if err != nil { - err = msgp.WrapError(err, "headers") - return - } - var za0002 []byte - za0002, err = dc.ReadBytes(za0002) + for za0001 := range z.headers { + var zb0003 uint32 + zb0003, err = dc.ReadMapHeader() if err != nil { err = msgp.WrapError(err, "headers", za0001) return } - z.headers[za0001] = za0002 + for zb0003 > 0 { + zb0003-- + field, err = dc.ReadMapKeyPtr() + if err != nil { + err = msgp.WrapError(err, "headers", za0001) + return + } + switch msgp.UnsafeString(field) { + case "key": + z.headers[za0001].key, err = dc.ReadBytes(z.headers[za0001].key) + if err != nil { + err = msgp.WrapError(err, "headers", za0001, "key") + return + } + case "value": + z.headers[za0001].value, err = dc.ReadBytes(z.headers[za0001].value) + if err != nil { + err = msgp.WrapError(err, "headers", za0001, "value") + return + } + default: + err = dc.Skip() + if err != nil { + err = msgp.WrapError(err, "headers", za0001) + return + } + } + } } case "body": z.body, err = dc.ReadBytes(z.body) @@ -167,20 +315,31 @@ func (z *item) EncodeMsg(en *msgp.Writer) (err error) { if err != nil { return } - err = en.WriteMapHeader(uint32(len(z.headers))) + err = en.WriteArrayHeader(uint32(len(z.headers))) if err != nil { err = msgp.WrapError(err, "headers") return } - for za0001, za0002 := range z.headers { - err = en.WriteString(za0001) + for za0001 := range z.headers { + // map header, size 2 + // write "key" + err = en.Append(0x82, 0xa3, 0x6b, 0x65, 0x79) if err != nil { - err = msgp.WrapError(err, "headers") return } - err = en.WriteBytes(za0002) + err = en.WriteBytes(z.headers[za0001].key) if err != nil { - err = msgp.WrapError(err, "headers", za0001) + err = msgp.WrapError(err, "headers", za0001, "key") + return + } + // write "value" + err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65) + if err != nil { + return + } + err = en.WriteBytes(z.headers[za0001].value) + if err != nil { + err = msgp.WrapError(err, "headers", za0001, "value") return } } @@ -353,10 +512,15 @@ func (z *item) MarshalMsg(b []byte) (o []byte, err error) { // map header, size 17 // string "headers" o = append(o, 0xde, 0x0, 0x11, 0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73) - o = msgp.AppendMapHeader(o, uint32(len(z.headers))) - for za0001, za0002 := range z.headers { - o = msgp.AppendString(o, za0001) - o = msgp.AppendBytes(o, za0002) + o = msgp.AppendArrayHeader(o, uint32(len(z.headers))) + for za0001 := range z.headers { + // map header, size 2 + // string "key" + o = append(o, 0x82, 0xa3, 0x6b, 0x65, 0x79) + o = msgp.AppendBytes(o, z.headers[za0001].key) + // string "value" + o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65) + o = msgp.AppendBytes(o, z.headers[za0001].value) } // string "body" o = append(o, 0xa4, 0x62, 0x6f, 0x64, 0x79) @@ -429,31 +593,51 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) { switch msgp.UnsafeString(field) { case "headers": var zb0002 uint32 - zb0002, bts, err = msgp.ReadMapHeaderBytes(bts) + zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts) if err != nil { err = msgp.WrapError(err, "headers") return } - if z.headers == nil { - z.headers = make(map[string][]byte, zb0002) - } else if len(z.headers) > 0 { - clear(z.headers) + if cap(z.headers) >= int(zb0002) { + z.headers = (z.headers)[:zb0002] + } else { + z.headers = make([]cachedHeader, zb0002) } - for zb0002 > 0 { - var za0002 []byte - zb0002-- - var za0001 string - za0001, bts, err = msgp.ReadStringBytes(bts) - if err != nil { - err = msgp.WrapError(err, "headers") - return - } - za0002, bts, err = msgp.ReadBytesBytes(bts, za0002) + for za0001 := range z.headers { + var zb0003 uint32 + zb0003, bts, err = msgp.ReadMapHeaderBytes(bts) if err != nil { err = msgp.WrapError(err, "headers", za0001) return } - z.headers[za0001] = za0002 + for zb0003 > 0 { + zb0003-- + field, bts, err = msgp.ReadMapKeyZC(bts) + if err != nil { + err = msgp.WrapError(err, "headers", za0001) + return + } + switch msgp.UnsafeString(field) { + case "key": + z.headers[za0001].key, bts, err = msgp.ReadBytesBytes(bts, z.headers[za0001].key) + if err != nil { + err = msgp.WrapError(err, "headers", za0001, "key") + return + } + case "value": + z.headers[za0001].value, bts, err = msgp.ReadBytesBytes(bts, z.headers[za0001].value) + if err != nil { + err = msgp.WrapError(err, "headers", za0001, "value") + return + } + default: + bts, err = msgp.Skip(bts) + if err != nil { + err = msgp.WrapError(err, "headers", za0001) + return + } + } + } } case "body": z.body, bts, err = msgp.ReadBytesBytes(bts, z.body) @@ -565,12 +749,9 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) { // Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message func (z *item) Msgsize() (s int) { - s = 3 + 8 + msgp.MapHeaderSize - if z.headers != nil { - for za0001, za0002 := range z.headers { - _ = za0002 - s += msgp.StringPrefixSize + len(za0001) + msgp.BytesPrefixSize + len(za0002) - } + s = 3 + 8 + msgp.ArrayHeaderSize + for za0001 := range z.headers { + s += 1 + 4 + msgp.BytesPrefixSize + len(z.headers[za0001].key) + 6 + msgp.BytesPrefixSize + len(z.headers[za0001].value) } s += 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 13 + msgp.BytesPrefixSize + len(z.cacheControl) + 8 + msgp.BytesPrefixSize + len(z.expires) + 5 + msgp.BytesPrefixSize + len(z.etag) + 5 + msgp.Uint64Size + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 4 + msgp.Uint64Size + 4 + msgp.Uint64Size + 16 + msgp.BoolSize + 11 + msgp.BoolSize + 10 + msgp.BoolSize + 8 + msgp.BoolSize + 8 + msgp.IntSize return diff --git a/middleware/cache/manager_msgp_test.go b/middleware/cache/manager_msgp_test.go index 789b808a588..11ff508e098 100644 --- a/middleware/cache/manager_msgp_test.go +++ b/middleware/cache/manager_msgp_test.go @@ -9,6 +9,119 @@ import ( "github.com/tinylib/msgp/msgp" ) +func TestMarshalUnmarshalcachedHeader(t *testing.T) { + v := cachedHeader{} + bts, err := v.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + left, err := v.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left) + } + + left, err = msgp.Skip(bts) + if err != nil { + t.Fatal(err) + } + if len(left) > 0 { + t.Errorf("%d bytes left over after Skip(): %q", len(left), left) + } +} + +func BenchmarkMarshalMsgcachedHeader(b *testing.B) { + v := cachedHeader{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.MarshalMsg(nil) + } +} + +func BenchmarkAppendMsgcachedHeader(b *testing.B) { + v := cachedHeader{} + bts := make([]byte, 0, v.Msgsize()) + bts, _ = v.MarshalMsg(bts[0:0]) + b.SetBytes(int64(len(bts))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bts, _ = v.MarshalMsg(bts[0:0]) + } +} + +func BenchmarkUnmarshalcachedHeader(b *testing.B) { + v := cachedHeader{} + bts, _ := v.MarshalMsg(nil) + b.ReportAllocs() + b.SetBytes(int64(len(bts))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := v.UnmarshalMsg(bts) + if err != nil { + b.Fatal(err) + } + } +} + +func TestEncodeDecodecachedHeader(t *testing.T) { + v := cachedHeader{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + + m := v.Msgsize() + if buf.Len() > m { + t.Log("WARNING: TestEncodeDecodecachedHeader Msgsize() is inaccurate") + } + + vn := cachedHeader{} + err := msgp.Decode(&buf, &vn) + if err != nil { + t.Error(err) + } + + buf.Reset() + msgp.Encode(&buf, &v) + err = msgp.NewReader(&buf).Skip() + if err != nil { + t.Error(err) + } +} + +func BenchmarkEncodecachedHeader(b *testing.B) { + v := cachedHeader{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + en := msgp.NewWriter(msgp.Nowhere) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v.EncodeMsg(en) + } + en.Flush() +} + +func BenchmarkDecodecachedHeader(b *testing.B) { + v := cachedHeader{} + var buf bytes.Buffer + msgp.Encode(&buf, &v) + b.SetBytes(int64(buf.Len())) + rd := msgp.NewEndlessReader(buf.Bytes(), b) + dc := msgp.NewReader(rd) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := v.DecodeMsg(dc) + if err != nil { + b.Fatal(err) + } + } +} + func TestMarshalUnmarshalitem(t *testing.T) { v := item{} bts, err := v.MarshalMsg(nil) From 27505d23f1967e175b85c1f891d4c61934bed310 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Sat, 3 Jan 2026 12:13:12 -0500 Subject: [PATCH 2/4] Simplify error paths while keeping heap protection --- middleware/cache/cache.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index 809c50ca859..8c8c0b49992 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -305,8 +305,6 @@ func New(config ...Config) fiber.Handler { if cfg.Storage != nil { manager.release(e) } - relock() - unlock() return fmt.Errorf("cache: failed to delete expired key %q: %w", maskKey(key), err) } relock() @@ -370,8 +368,6 @@ func New(config ...Config) fiber.Handler { if e != nil { manager.release(e) } - relock() - unlock() return fmt.Errorf("cache: failed to delete expired key %q: %w", maskKey(key), err) } relock() @@ -385,8 +381,6 @@ func New(config ...Config) fiber.Handler { if e != nil { manager.release(e) } - relock() - unlock() return fmt.Errorf("cache: failed to delete private response for key %q: %w", maskKey(key), err) } relock() From c5ad490cf190a01656784659c4ed0a3748230776 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Sat, 3 Jan 2026 12:39:34 -0500 Subject: [PATCH 3/4] Fix proxy header benchmarks --- ctx_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ctx_test.go b/ctx_test.go index 0db737f6aa3..c9dcc6da7b7 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -2713,7 +2713,7 @@ func Benchmark_Ctx_IPs_v6_With_IP_Validation(b *testing.B) { } func Benchmark_Ctx_IP_With_ProxyHeader(b *testing.B) { - app := New(Config{ProxyHeader: HeaderXForwardedFor}) + app := New(Config{ProxyHeader: HeaderXForwardedFor, TrustProxy: true}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1") var res string @@ -2725,7 +2725,7 @@ func Benchmark_Ctx_IP_With_ProxyHeader(b *testing.B) { } func Benchmark_Ctx_IP_With_ProxyHeader_and_IP_Validation(b *testing.B) { - app := New(Config{ProxyHeader: HeaderXForwardedFor, EnableIPValidation: true}) + app := New(Config{ProxyHeader: HeaderXForwardedFor, TrustProxy: true, EnableIPValidation: true}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1") var res string From 4b8608c4200e8f858f0a57b6de94414d78a5f281 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Sat, 3 Jan 2026 13:19:14 -0500 Subject: [PATCH 4/4] Fix proxy header benchmarks --- ctx_test.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/ctx_test.go b/ctx_test.go index c9dcc6da7b7..ee9f4a910ea 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -2713,8 +2713,16 @@ func Benchmark_Ctx_IPs_v6_With_IP_Validation(b *testing.B) { } func Benchmark_Ctx_IP_With_ProxyHeader(b *testing.B) { - app := New(Config{ProxyHeader: HeaderXForwardedFor, TrustProxy: true}) - c := app.AcquireCtx(&fasthttp.RequestCtx{}) + app := New(Config{ + ProxyHeader: HeaderXForwardedFor, + TrustProxy: true, + TrustProxyConfig: TrustProxyConfig{ + Loopback: true, + }, + }) + fastCtx := &fasthttp.RequestCtx{} + fastCtx.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1")})) + c := app.AcquireCtx(fastCtx) c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1") var res string b.ReportAllocs() @@ -2725,8 +2733,17 @@ func Benchmark_Ctx_IP_With_ProxyHeader(b *testing.B) { } func Benchmark_Ctx_IP_With_ProxyHeader_and_IP_Validation(b *testing.B) { - app := New(Config{ProxyHeader: HeaderXForwardedFor, TrustProxy: true, EnableIPValidation: true}) - c := app.AcquireCtx(&fasthttp.RequestCtx{}) + app := New(Config{ + ProxyHeader: HeaderXForwardedFor, + TrustProxy: true, + TrustProxyConfig: TrustProxyConfig{ + Loopback: true, + }, + EnableIPValidation: true, + }) + fastCtx := &fasthttp.RequestCtx{} + fastCtx.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1")})) + c := app.AcquireCtx(fastCtx) c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1") var res string b.ReportAllocs()