From 97530f7a1a415e1cf700c31db13e7e8dc09b9f86 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Tue, 28 Apr 2026 00:30:51 -0400 Subject: [PATCH 1/9] =?UTF-8?q?=F0=9F=94=A5=20feat:=20add=20msgpack=20cbor?= =?UTF-8?q?=20xml=20helpers=20to=20SharedState?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 44 ++- docs/api/app.md | 12 + docs/api/state.md | 116 ++++++++ docs/whats_new.md | 1 + shared_state.go | 295 ++++++++++++++++++++ shared_state_test.go | 651 +++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 1118 insertions(+), 1 deletion(-) create mode 100644 shared_state.go create mode 100644 shared_state_test.go diff --git a/app.go b/app.go index 1f78f941835..57e66f815b0 100644 --- a/app.go +++ b/app.go @@ -90,6 +90,8 @@ type App struct { mountFields *mountFields // state management state *State + // shared state management (prefork-safe, storage-backed) + sharedState *SharedState // Route stack divided by HTTP methods stack [][]*Route // customConstraints is a list of external constraints @@ -279,6 +281,19 @@ type Config struct { //nolint:govet // Aligning the struct fields is not necessa // Default: nil AppName string `json:"app_name"` + // SharedStorage configures storage-backed shared state that can be used + // safely across prefork workers and processes. + // + // Default: nil + SharedStorage Storage `json:"-"` + + // SharedStatePrefix customizes the namespace prefix for keys written to + // SharedStorage. If empty, Fiber derives a prefixed namespace using + // AppName (when set) or an internal default. + // + // Default: "" + SharedStatePrefix string `json:"shared_state_prefix"` + // StreamRequestBody enables request body streaming, // and calls the handler sooner when given body is // larger than the current limit. @@ -638,6 +653,26 @@ func New(config ...Config) *App { if app.config.XMLDecoder == nil { app.config.XMLDecoder = xml.Unmarshal } + + sharedStatePrefix := app.config.SharedStatePrefix + if sharedStatePrefix == "" { + sharedStatePrefix = defaultSharedStatePrefix + if app.config.AppName != "" { + sharedStatePrefix += app.config.AppName + "-" + } + } + app.sharedState = newSharedState( + app.config.SharedStorage, + sharedStatePrefix, + app.config.JSONEncoder, + app.config.JSONDecoder, + app.config.MsgPackEncoder, + app.config.MsgPackDecoder, + app.config.CBOREncoder, + app.config.CBORDecoder, + app.config.XMLEncoder, + app.config.XMLDecoder, + ) if len(app.config.RequestMethods) == 0 { app.config.RequestMethods = DefaultMethods } @@ -1172,11 +1207,18 @@ func (app *App) Hooks() *Hooks { return app.hooks } -// State returns the state struct to store global data in order to share it between handlers. +// State returns the in-process state struct to store global data between handlers. +// State is process-local and is not shared across prefork workers. func (app *App) State() *State { return app.state } +// SharedState returns storage-backed shared state. +// SharedState is prefork-safe when Config.SharedStorage is configured. +func (app *App) SharedState() *SharedState { + return app.sharedState +} + var ErrTestGotEmptyResponse = errors.New("test: got empty response") // TestConfig is a struct holding Test settings diff --git a/docs/api/app.md b/docs/api/app.md index 67a3ae7da9c..4ce6ba9a38b 100644 --- a/docs/api/app.md +++ b/docs/api/app.md @@ -80,6 +80,18 @@ func main() { } ``` +### State / SharedState + +`State()` returns in-process state (local to the current process). +`SharedState()` returns storage-backed state intended for prefork/multi-process sharing. + +```go title="Signature" +func (app *App) State() *State +func (app *App) SharedState() *SharedState +``` + +See [State Management](./state.md) for usage and examples. + ### MountPath The `MountPath` property contains one or more path patterns on which a sub-app was mounted. diff --git a/docs/api/state.md b/docs/api/state.md index 36be5f67b89..74bfb885b38 100644 --- a/docs/api/state.md +++ b/docs/api/state.md @@ -10,6 +10,122 @@ State management provides a global key–value store for application dependencie When prefork is enabled, each worker process has an independent state store, meaning state is not shared between them. ::: +## SharedState (Prefork-safe) + +For data that must be shared across prefork workers or multiple app processes, use `app.SharedState()` backed by `fiber.Storage`. + +Configure storage in `fiber.Config`: + +```go +app := fiber.New(fiber.Config{ + AppName: "billing-api", + SharedStorage: redisStorage, // any implementation of fiber.Storage + SharedStatePrefix: "billing-shared-", // optional +}) +``` + +If `SharedStatePrefix` is empty, Fiber derives a default namespace and includes `AppName` (when set) to reduce collisions between apps/services. + +:::warning Memory storage caveat +`SharedState` is only cross-worker / cross-process when the configured `SharedStorage` backend is shared. + +If you use an in-memory backend (for example memory storage), data remains process-local. In prefork mode, each worker process has its own independent in-memory store. +::: + +### SharedState Methods + +```go title="Signature" +func (app *App) SharedState() *SharedState + +func (s *SharedState) Set(key string, val []byte, ttl time.Duration) error +func (s *SharedState) SetWithContext(ctx context.Context, key string, val []byte, ttl time.Duration) error + +func (s *SharedState) Get(key string) (val []byte, found bool, err error) +func (s *SharedState) GetWithContext(ctx context.Context, key string) (val []byte, found bool, err error) + +func (s *SharedState) SetJSON(key string, v any, ttl time.Duration) error +func (s *SharedState) SetJSONWithContext(ctx context.Context, key string, v any, ttl time.Duration) error + +func (s *SharedState) GetJSON(key string, out any) (raw []byte, found bool, err error) +func (s *SharedState) GetJSONWithContext(ctx context.Context, key string, out any) (raw []byte, found bool, err error) + +func (s *SharedState) SetMsgPack(key string, v any, ttl time.Duration) error +func (s *SharedState) SetMsgPackWithContext(ctx context.Context, key string, v any, ttl time.Duration) error + +func (s *SharedState) GetMsgPack(key string, out any) (raw []byte, found bool, err error) +func (s *SharedState) GetMsgPackWithContext(ctx context.Context, key string, out any) (raw []byte, found bool, err error) + +func (s *SharedState) SetCBOR(key string, v any, ttl time.Duration) error +func (s *SharedState) SetCBORWithContext(ctx context.Context, key string, v any, ttl time.Duration) error + +func (s *SharedState) GetCBOR(key string, out any) (raw []byte, found bool, err error) +func (s *SharedState) GetCBORWithContext(ctx context.Context, key string, out any) (raw []byte, found bool, err error) + +func (s *SharedState) SetXML(key string, v any, ttl time.Duration) error +func (s *SharedState) SetXMLWithContext(ctx context.Context, key string, v any, ttl time.Duration) error + +func (s *SharedState) GetXML(key string, out any) (raw []byte, found bool, err error) +func (s *SharedState) GetXMLWithContext(ctx context.Context, key string, out any) (raw []byte, found bool, err error) + +func (s *SharedState) Delete(key string) error +func (s *SharedState) DeleteWithContext(ctx context.Context, key string) error + +func (s *SharedState) Has(key string) (bool, error) +func (s *SharedState) HasWithContext(ctx context.Context, key string) (bool, error) +``` + +### SharedState Example + +```go +type SessionSnapshot struct { + UserID string `json:"user_id"` + UpdatedAt time.Time `json:"updated_at"` +} + +app.Post("/sessions/:id", func(c fiber.Ctx) error { + key := "session:" + c.Params("id") + value := SessionSnapshot{ + UserID: c.Params("id"), + UpdatedAt: time.Now().UTC(), + } + + if err := app.SharedState().SetJSON(key, value, 30*time.Minute); err != nil { + return err + } + + return c.SendStatus(fiber.StatusAccepted) +}) + +app.Get("/sessions/:id", func(c fiber.Ctx) error { + key := "session:" + c.Params("id") + var snapshot SessionSnapshot + + _, found, err := app.SharedState().GetJSON(key, &snapshot) + if err != nil { + return err + } + if !found { + return c.SendStatus(fiber.StatusNotFound) + } + + return c.JSON(snapshot) +}) +``` + +### SharedState with Context (timeouts/cancellation) + +```go +ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) +defer cancel() + +err := app.SharedState().SetJSONWithContext(ctx, "job:42", fiber.Map{ + "status": "queued", +}, 2*time.Minute) +if err != nil { + // timeout, cancellation, storage error, or JSON serialization error +} +``` + ## State Type `State` is a key–value store built on top of `sync.Map` to ensure safe concurrent access. It allows storage and retrieval of dependencies and configurations in a Fiber application as well as thread–safe access to runtime data. diff --git a/docs/whats_new.md b/docs/whats_new.md index 20c7781d092..03296712f3d 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -84,6 +84,7 @@ We have made several changes to the Fiber app, including: - **RegisterCustomConstraint**: Allows for the registration of custom constraints. - **NewWithCustomCtx**: Initialize an app with a custom context in one step. - **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details. +- **SharedState**: Introduces storage-backed app state for prefork-safe/multi-process coordination via `Config.SharedStorage`, with optional `Config.SharedStatePrefix` namespacing and JSON/context-aware helpers (`SetJSON`, `GetJSON`, `Has`, `Delete`, and `WithContext` variants). - **NewErrorf**: Allows variadic parameters when creating formatted errors. - **GetBytes / GetString**: Helpers that detach values only when `Immutable` is enabled and the data still references request or response buffers. Access via `c.App().GetString` and `c.App().GetBytes`. - **ReloadViews**: Lets you re-run the configured view engine's `Load()` logic at runtime, including guard rails for missing or nil view engines so development hot-reload hooks can refresh templates safely. diff --git a/shared_state.go b/shared_state.go new file mode 100644 index 00000000000..af07f04b14f --- /dev/null +++ b/shared_state.go @@ -0,0 +1,295 @@ +package fiber + +import ( + "context" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/gofiber/utils/v2" +) + +const defaultSharedStatePrefix = "gofiber-shared-state-" + +var ErrSharedStorageNotConfigured = errors.New("fiber: shared storage is not configured") + +type SharedState struct { + storage Storage + jsonEncoder utils.JSONMarshal + jsonDecoder utils.JSONUnmarshal + msgPackEncoder utils.MsgPackMarshal + msgPackDecoder utils.MsgPackUnmarshal + cborEncoder utils.CBORMarshal + cborDecoder utils.CBORUnmarshal + xmlEncoder utils.XMLMarshal + xmlDecoder utils.XMLUnmarshal + prefix string +} + +func newSharedState( + storage Storage, + prefix string, + jsonEncoder utils.JSONMarshal, + jsonDecoder utils.JSONUnmarshal, + msgPackEncoder utils.MsgPackMarshal, + msgPackDecoder utils.MsgPackUnmarshal, + cborEncoder utils.CBORMarshal, + cborDecoder utils.CBORUnmarshal, + xmlEncoder utils.XMLMarshal, + xmlDecoder utils.XMLUnmarshal, +) *SharedState { + if prefix == "" { + prefix = defaultSharedStatePrefix + } + + if jsonEncoder == nil { + jsonEncoder = json.Marshal + } + if jsonDecoder == nil { + jsonDecoder = json.Unmarshal + } + + return &SharedState{ + storage: storage, + jsonEncoder: jsonEncoder, + jsonDecoder: jsonDecoder, + msgPackEncoder: msgPackEncoder, + msgPackDecoder: msgPackDecoder, + cborEncoder: cborEncoder, + cborDecoder: cborDecoder, + xmlEncoder: xmlEncoder, + xmlDecoder: xmlDecoder, + prefix: prefix, + } +} + +func (s *SharedState) Set(key string, val []byte, ttl time.Duration) error { + return s.SetWithContext(context.Background(), key, val, ttl) +} + +func (s *SharedState) SetWithContext(ctx context.Context, key string, val []byte, ttl time.Duration) error { + if s == nil || s.storage == nil { + return ErrSharedStorageNotConfigured + } + + return s.storage.SetWithContext(ctx, s.key(key), val, ttl) +} + +func (s *SharedState) Get(key string) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + return s.GetWithContext(context.Background(), key) +} + +func (s *SharedState) GetWithContext(ctx context.Context, key string) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + if s == nil || s.storage == nil { + return nil, false, ErrSharedStorageNotConfigured + } + + data, err := s.storage.GetWithContext(ctx, s.key(key)) + if err != nil { + return nil, false, err + } + if data == nil { + return nil, false, nil + } + + return append([]byte(nil), data...), true, nil +} + +func (s *SharedState) SetJSON(key string, v any, ttl time.Duration) error { + return s.SetJSONWithContext(context.Background(), key, v, ttl) +} + +func (s *SharedState) SetJSONWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { + if s == nil || s.storage == nil { + return ErrSharedStorageNotConfigured + } + + encoded, err := s.jsonEncoder(v) + if err != nil { + return fmt.Errorf("fiber: failed to encode shared state value: %w", err) + } + + return s.storage.SetWithContext(ctx, s.key(key), encoded, ttl) +} + +func (s *SharedState) GetJSON(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + return s.GetJSONWithContext(context.Background(), key, out) +} + +func (s *SharedState) GetJSONWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + if s == nil || s.storage == nil { + return nil, false, ErrSharedStorageNotConfigured + } + + data, err := s.storage.GetWithContext(ctx, s.key(key)) + if err != nil { + return nil, false, err + } + if data == nil { + return nil, false, nil + } + + if err := s.jsonDecoder(data, out); err != nil { + return nil, false, fmt.Errorf("fiber: failed to decode shared state value: %w", err) + } + + return append([]byte(nil), data...), true, nil +} + +func (s *SharedState) SetMsgPack(key string, v any, ttl time.Duration) error { + return s.SetMsgPackWithContext(context.Background(), key, v, ttl) +} + +func (s *SharedState) SetMsgPackWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { + if s == nil || s.storage == nil { + return ErrSharedStorageNotConfigured + } + + encoded, err := s.msgPackEncoder(v) + if err != nil { + return fmt.Errorf("fiber: failed to encode shared state msgpack value: %w", err) + } + + return s.storage.SetWithContext(ctx, s.key(key), encoded, ttl) +} + +func (s *SharedState) GetMsgPack(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + return s.GetMsgPackWithContext(context.Background(), key, out) +} + +func (s *SharedState) GetMsgPackWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + if s == nil || s.storage == nil { + return nil, false, ErrSharedStorageNotConfigured + } + + data, err := s.storage.GetWithContext(ctx, s.key(key)) + if err != nil { + return nil, false, err + } + if data == nil { + return nil, false, nil + } + + if err := s.msgPackDecoder(data, out); err != nil { + return nil, false, fmt.Errorf("fiber: failed to decode shared state msgpack value: %w", err) + } + + return append([]byte(nil), data...), true, nil +} + +func (s *SharedState) SetCBOR(key string, v any, ttl time.Duration) error { + return s.SetCBORWithContext(context.Background(), key, v, ttl) +} + +func (s *SharedState) SetCBORWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { + if s == nil || s.storage == nil { + return ErrSharedStorageNotConfigured + } + + encoded, err := s.cborEncoder(v) + if err != nil { + return fmt.Errorf("fiber: failed to encode shared state cbor value: %w", err) + } + + return s.storage.SetWithContext(ctx, s.key(key), encoded, ttl) +} + +func (s *SharedState) GetCBOR(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + return s.GetCBORWithContext(context.Background(), key, out) +} + +func (s *SharedState) GetCBORWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + if s == nil || s.storage == nil { + return nil, false, ErrSharedStorageNotConfigured + } + + data, err := s.storage.GetWithContext(ctx, s.key(key)) + if err != nil { + return nil, false, err + } + if data == nil { + return nil, false, nil + } + + if err := s.cborDecoder(data, out); err != nil { + return nil, false, fmt.Errorf("fiber: failed to decode shared state cbor value: %w", err) + } + + return append([]byte(nil), data...), true, nil +} + +func (s *SharedState) SetXML(key string, v any, ttl time.Duration) error { + return s.SetXMLWithContext(context.Background(), key, v, ttl) +} + +func (s *SharedState) SetXMLWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { + if s == nil || s.storage == nil { + return ErrSharedStorageNotConfigured + } + + encoded, err := s.xmlEncoder(v) + if err != nil { + return fmt.Errorf("fiber: failed to encode shared state xml value: %w", err) + } + + return s.storage.SetWithContext(ctx, s.key(key), encoded, ttl) +} + +func (s *SharedState) GetXML(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + return s.GetXMLWithContext(context.Background(), key, out) +} + +func (s *SharedState) GetXMLWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + if s == nil || s.storage == nil { + return nil, false, ErrSharedStorageNotConfigured + } + + data, err := s.storage.GetWithContext(ctx, s.key(key)) + if err != nil { + return nil, false, err + } + if data == nil { + return nil, false, nil + } + + if err := s.xmlDecoder(data, out); err != nil { + return nil, false, fmt.Errorf("fiber: failed to decode shared state xml value: %w", err) + } + + return append([]byte(nil), data...), true, nil +} + +func (s *SharedState) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (s *SharedState) DeleteWithContext(ctx context.Context, key string) error { + if s == nil || s.storage == nil { + return ErrSharedStorageNotConfigured + } + + return s.storage.DeleteWithContext(ctx, s.key(key)) +} + +func (s *SharedState) Has(key string) (bool, error) { + return s.HasWithContext(context.Background(), key) +} + +func (s *SharedState) HasWithContext(ctx context.Context, key string) (bool, error) { + if s == nil || s.storage == nil { + return false, ErrSharedStorageNotConfigured + } + + data, err := s.storage.GetWithContext(ctx, s.key(key)) + if err != nil { + return false, err + } + + return data != nil, nil +} + +func (s *SharedState) key(key string) string { + return utils.CopyString(s.prefix + hex.EncodeToString(utils.UnsafeBytes(key))) +} diff --git a/shared_state_test.go b/shared_state_test.go new file mode 100644 index 00000000000..47ba12aa655 --- /dev/null +++ b/shared_state_test.go @@ -0,0 +1,651 @@ +package fiber + +import ( + "context" + "encoding/json" + "errors" + "testing" + "time" + + storagememory "github.com/gofiber/fiber/v3/internal/storage/memory" + "github.com/stretchr/testify/require" +) + +func newSharedStateMemoryStorage(t *testing.T) *storagememory.Storage { + t.Helper() + + store := storagememory.New() + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} + +type contextCheckingStorage struct { + base Storage + ctxKey any +} + +type errorStorage struct { + err error +} + +func (s *errorStorage) GetWithContext(context.Context, string) ([]byte, error) { + return nil, s.err +} + +func (s *errorStorage) Get(string) ([]byte, error) { + return nil, s.err +} + +func (s *errorStorage) SetWithContext(context.Context, string, []byte, time.Duration) error { + return s.err +} + +func (s *errorStorage) Set(string, []byte, time.Duration) error { + return s.err +} + +func (s *errorStorage) DeleteWithContext(context.Context, string) error { + return s.err +} + +func (s *errorStorage) Delete(string) error { + return s.err +} + +func (s *errorStorage) ResetWithContext(context.Context) error { + return s.err +} + +func (s *errorStorage) Reset() error { + return s.err +} + +func (*errorStorage) Close() error { + return nil +} + +func (s *contextCheckingStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { + if ctx.Value(s.ctxKey) == nil { + return errors.New("context value not found") + } + return s.base.SetWithContext(ctx, key, val, exp) +} + +func (s *contextCheckingStorage) GetWithContext(ctx context.Context, key string) ([]byte, error) { + if ctx.Value(s.ctxKey) == nil { + return nil, errors.New("context value not found") + } + return s.base.GetWithContext(ctx, key) +} + +func (s *contextCheckingStorage) DeleteWithContext(ctx context.Context, key string) error { + if ctx.Value(s.ctxKey) == nil { + return errors.New("context value not found") + } + return s.base.DeleteWithContext(ctx, key) +} + +func (s *contextCheckingStorage) Get(key string) ([]byte, error) { + return s.GetWithContext(context.Background(), key) +} + +func (s *contextCheckingStorage) Set(key string, val []byte, exp time.Duration) error { + return s.SetWithContext(context.Background(), key, val, exp) +} + +func (s *contextCheckingStorage) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (s *contextCheckingStorage) ResetWithContext(ctx context.Context) error { + return s.base.ResetWithContext(ctx) +} + +func (s *contextCheckingStorage) Reset() error { + return s.base.Reset() +} + +func (s *contextCheckingStorage) Close() error { + return s.base.Close() +} + +func TestSharedState_NotConfigured(t *testing.T) { + t.Parallel() + + app := New() + + err := app.SharedState().Set("raw-key", []byte("raw"), time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + raw, found, err := app.SharedState().Get("raw-key") + require.Nil(t, raw) + require.False(t, found) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + err = app.SharedState().SetJSON("key", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + err = app.SharedState().SetMsgPack("key", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + err = app.SharedState().SetCBOR("key", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + err = app.SharedState().SetXML("key", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + _, found, err = app.SharedState().GetJSON("key", &Map{}) + require.False(t, found) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + _, found, err = app.SharedState().GetMsgPack("key", &Map{}) + require.False(t, found) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + _, found, err = app.SharedState().GetCBOR("key", &Map{}) + require.False(t, found) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + _, found, err = app.SharedState().GetXML("key", &Map{}) + require.False(t, found) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + err = app.SharedState().Delete("key") + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + has, err := app.SharedState().Has("key") + require.False(t, has) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) +} + +func TestSharedState_PreforkSafeWithSharedStorage(t *testing.T) { + t.Parallel() + + store := newSharedStateMemoryStorage(t) + workerA := New(Config{AppName: "prefork-app", SharedStorage: store}) + workerB := New(Config{AppName: "prefork-app", SharedStorage: store}) + + workerA.State().Set("process-only", "from-worker-a") + _, ok := workerB.State().Get("process-only") + require.False(t, ok) + + payload := Map{"worker": "a", "version": 3} + err := workerA.SharedState().SetJSON("cluster-key", payload, time.Minute) + require.NoError(t, err) + + err = workerA.SharedState().Set("raw-cluster-key", []byte("raw-value"), time.Minute) + require.NoError(t, err) + + var out map[string]any + rawJSON, found, err := workerB.SharedState().GetJSON("cluster-key", &out) + require.NoError(t, err) + require.True(t, found) + require.NotNil(t, rawJSON) + require.Equal(t, "a", out["worker"]) + require.EqualValues(t, 3, out["version"]) + + raw, found, err := workerB.SharedState().Get("raw-cluster-key") + require.NoError(t, err) + require.True(t, found) + require.Equal(t, []byte("raw-value"), raw) + + has, err := workerB.SharedState().Has("cluster-key") + require.NoError(t, err) + require.True(t, has) + + err = workerB.SharedState().Delete("cluster-key") + require.NoError(t, err) + + has, err = workerA.SharedState().Has("cluster-key") + require.NoError(t, err) + require.False(t, has) + + require.NoError(t, workerA.SharedState().Delete("raw-cluster-key")) + raw, found, err = workerB.SharedState().Get("raw-cluster-key") + require.NoError(t, err) + require.Nil(t, raw) + require.False(t, found) +} + +func TestSharedState_ExplicitSerializationError(t *testing.T) { + t.Parallel() + + app := New(Config{SharedStorage: newSharedStateMemoryStorage(t)}) + err := app.SharedState().SetJSON("invalid", make(chan int), time.Second) + require.Error(t, err) +} + +func TestSharedState_ContextAwareVariants(t *testing.T) { + t.Parallel() + + type testContextKey string + + ctxKey := testContextKey("tenant") + store := &contextCheckingStorage{ctxKey: ctxKey, base: newSharedStateMemoryStorage(t)} + app := New(Config{SharedStorage: store}) + + t.Run("missing context", func(t *testing.T) { + t.Parallel() + + err := app.SharedState().SetJSONWithContext(context.Background(), "key", Map{"ok": true}, time.Second) + require.Error(t, err) + }) + + t.Run("context propagation", func(t *testing.T) { + t.Parallel() + + ctx := context.WithValue(context.Background(), ctxKey, "value") + err := app.SharedState().SetJSONWithContext(ctx, "key", Map{"ok": true}, time.Second) + require.NoError(t, err) + + var out map[string]bool + _, found, err := app.SharedState().GetJSONWithContext(ctx, "key", &out) + require.NoError(t, err) + require.True(t, found) + require.True(t, out["ok"]) + + has, err := app.SharedState().HasWithContext(ctx, "key") + require.NoError(t, err) + require.True(t, has) + + err = app.SharedState().DeleteWithContext(ctx, "key") + require.NoError(t, err) + }) +} + +func TestSharedState_KeyNamespacing(t *testing.T) { + t.Parallel() + + store := newSharedStateMemoryStorage(t) + appOne := New(Config{AppName: "app-one", SharedStorage: store}) + appTwo := New(Config{AppName: "app-two", SharedStorage: store}) + + err := appOne.SharedState().SetJSON("same-key", Map{"app": 1}, time.Minute) + require.NoError(t, err) + err = appTwo.SharedState().SetJSON("same-key", Map{"app": 2}, time.Minute) + require.NoError(t, err) + + var outOne map[string]int + _, found, err := appOne.SharedState().GetJSON("same-key", &outOne) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, 1, outOne["app"]) + + var outTwo map[string]int + _, found, err = appTwo.SharedState().GetJSON("same-key", &outTwo) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, 2, outTwo["app"]) +} + +func TestSharedState_StorageErrorsArePropagated(t *testing.T) { + t.Parallel() + + expectedErr := errors.New("storage failed") + app := New(Config{ + SharedStorage: &errorStorage{err: expectedErr}, + MsgPackEncoder: func(any) ([]byte, error) { + return []byte("msgpack"), nil + }, + MsgPackDecoder: func([]byte, any) error { + return nil + }, + CBOREncoder: func(any) ([]byte, error) { + return []byte("cbor"), nil + }, + CBORDecoder: func([]byte, any) error { + return nil + }, + XMLEncoder: func(any) ([]byte, error) { + return []byte(""), nil + }, + XMLDecoder: func([]byte, any) error { + return nil + }, + }) + + err := app.SharedState().SetJSON("k", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, expectedErr) + err = app.SharedState().SetMsgPack("k", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, expectedErr) + err = app.SharedState().SetCBOR("k", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, expectedErr) + err = app.SharedState().SetXML("k", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, expectedErr) + + err = app.SharedState().Set("k", []byte("v"), time.Second) + require.ErrorIs(t, err, expectedErr) + + _, _, err = app.SharedState().Get("k") + require.ErrorIs(t, err, expectedErr) + + _, _, err = app.SharedState().GetJSON("k", &Map{}) + require.ErrorIs(t, err, expectedErr) + _, _, err = app.SharedState().GetMsgPack("k", &Map{}) + require.ErrorIs(t, err, expectedErr) + _, _, err = app.SharedState().GetCBOR("k", &Map{}) + require.ErrorIs(t, err, expectedErr) + _, _, err = app.SharedState().GetXML("k", &Map{}) + require.ErrorIs(t, err, expectedErr) + + _, err = app.SharedState().Has("k") + require.ErrorIs(t, err, expectedErr) + + err = app.SharedState().Delete("k") + require.ErrorIs(t, err, expectedErr) +} + +func TestSharedState_NilReceiver(t *testing.T) { + t.Parallel() + + var state *SharedState + + err := state.SetJSON("k", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + err = state.SetMsgPack("k", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + err = state.SetCBOR("k", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + err = state.SetXML("k", Map{"v": 1}, time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + err = state.Set("k", []byte("v"), time.Second) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + _, _, err = state.Get("k") + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + _, _, err = state.GetJSON("k", &Map{}) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + _, _, err = state.GetMsgPack("k", &Map{}) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + _, _, err = state.GetCBOR("k", &Map{}) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + _, _, err = state.GetXML("k", &Map{}) + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + err = state.Delete("k") + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + _, err = state.Has("k") + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) +} + +func TestSharedState_DefaultPrefixFallback(t *testing.T) { + t.Parallel() + + state := newSharedState(newSharedStateMemoryStorage(t), "", nil, nil, nil, nil, nil, nil, nil, nil) + require.Equal(t, defaultSharedStatePrefix, state.prefix) +} + +func TestSharedState_NewAppDefaultPrefixIncludesAppName(t *testing.T) { + t.Parallel() + + app := New(Config{AppName: "my-app", SharedStorage: newSharedStateMemoryStorage(t)}) + require.Equal(t, defaultSharedStatePrefix+"my-app-", app.SharedState().prefix) +} + +func TestSharedState_GetJSON_UnmarshalError(t *testing.T) { + t.Parallel() + + store := newSharedStateMemoryStorage(t) + app := New(Config{SharedStorage: store}) + + require.NoError(t, store.Set(app.SharedState().key("broken"), []byte("{"), 0)) + + var out map[string]any + _, found, err := app.SharedState().GetJSON("broken", &out) + require.False(t, found) + require.Error(t, err) +} + +func TestSharedState_Get_ReturnsCopy(t *testing.T) { + t.Parallel() + + app := New(Config{SharedStorage: newSharedStateMemoryStorage(t)}) + require.NoError(t, app.SharedState().Set("raw", []byte("value"), time.Minute)) + + got, found, err := app.SharedState().Get("raw") + require.NoError(t, err) + require.True(t, found) + require.Equal(t, []byte("value"), got) + + got[0] = 'X' + + gotAgain, found, err := app.SharedState().Get("raw") + require.NoError(t, err) + require.True(t, found) + require.Equal(t, []byte("value"), gotAgain) +} + +func TestSharedState_RawWithContextVariants(t *testing.T) { + t.Parallel() + + type testContextKey string + + ctxKey := testContextKey("tenant") + store := &contextCheckingStorage{ctxKey: ctxKey, base: newSharedStateMemoryStorage(t)} + app := New(Config{SharedStorage: store}) + + err := app.SharedState().SetWithContext(context.Background(), "raw", []byte("x"), time.Second) + require.Error(t, err) + + ctx := context.WithValue(context.Background(), ctxKey, "value") + require.NoError(t, app.SharedState().SetWithContext(ctx, "raw", []byte("x"), time.Second)) + + raw, found, err := app.SharedState().GetWithContext(ctx, "raw") + require.NoError(t, err) + require.True(t, found) + require.Equal(t, []byte("x"), raw) +} + +func TestSharedState_SetGet_RawDataKinds(t *testing.T) { + t.Parallel() + + app := New(Config{SharedStorage: newSharedStateMemoryStorage(t)}) + + testCases := []struct { + key string + value []byte + }{ + {key: "plain", value: []byte("text")}, + {key: "json", value: []byte(`{"id":42}`)}, + {key: "binary", value: []byte{0x00, 0xFF, 0x10, 0x7F}}, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + t.Parallel() + + require.NoError(t, app.SharedState().Set(tc.key, tc.value, time.Minute)) + + got, found, err := app.SharedState().Get(tc.key) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, tc.value, got) + }) + } +} + +func TestSharedState_SetGet_JSONDataKinds(t *testing.T) { + t.Parallel() + + type sample struct { + Name string `json:"name"` + Count int `json:"count"` + } + + app := New(Config{SharedStorage: newSharedStateMemoryStorage(t)}) + + t.Run("map", func(t *testing.T) { + t.Parallel() + + expected := map[string]any{ + "name": "fiber", + "ok": true, + } + require.NoError(t, app.SharedState().SetJSON("map", expected, time.Minute)) + + var out map[string]any + raw, found, err := app.SharedState().GetJSON("map", &out) + require.NoError(t, err) + require.True(t, found) + require.NotEmpty(t, raw) + require.Equal(t, expected["name"], out["name"]) + require.Equal(t, expected["ok"], out["ok"]) + }) + + t.Run("slice", func(t *testing.T) { + t.Parallel() + + expected := []int{1, 2, 3, 4} + require.NoError(t, app.SharedState().SetJSON("slice", expected, time.Minute)) + + var out []int + _, found, err := app.SharedState().GetJSON("slice", &out) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, expected, out) + }) + + t.Run("struct", func(t *testing.T) { + t.Parallel() + + expected := sample{Name: "shared", Count: 7} + require.NoError(t, app.SharedState().SetJSON("struct", expected, time.Minute)) + + var out sample + _, found, err := app.SharedState().GetJSON("struct", &out) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, expected, out) + }) +} + +func TestSharedState_UsesAppJSONCodec(t *testing.T) { + t.Parallel() + + encoderCalled := false + decoderCalled := false + + app := New(Config{ + SharedStorage: newSharedStateMemoryStorage(t), + JSONEncoder: func(_ any) ([]byte, error) { + encoderCalled = true + return json.Marshal(Map{"via": "custom-encoder"}) + }, + JSONDecoder: func(data []byte, out any) error { + decoderCalled = true + return json.Unmarshal(data, out) + }, + }) + + require.NoError(t, app.SharedState().SetJSON("codec", Map{"ignored": true}, time.Minute)) + + var out map[string]string + _, found, err := app.SharedState().GetJSON("codec", &out) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, "custom-encoder", out["via"]) + require.True(t, encoderCalled) + require.True(t, decoderCalled) +} + +func TestSharedState_UsesAppMsgPackCodec(t *testing.T) { + t.Parallel() + + encoderCalled := false + decoderCalled := false + + app := New(Config{ + SharedStorage: newSharedStateMemoryStorage(t), + MsgPackEncoder: func(_ any) ([]byte, error) { + encoderCalled = true + return []byte("msgpack-payload"), nil + }, + MsgPackDecoder: func(data []byte, out any) error { + decoderCalled = true + ptr, ok := out.(*string) + if ok { + *ptr = string(data) + } + return nil + }, + }) + + require.NoError(t, app.SharedState().SetMsgPack("codec", Map{"ignored": true}, time.Minute)) + + var out string + raw, found, err := app.SharedState().GetMsgPack("codec", &out) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, []byte("msgpack-payload"), raw) + require.Equal(t, "msgpack-payload", out) + require.True(t, encoderCalled) + require.True(t, decoderCalled) +} + +func TestSharedState_UsesAppCBORCodec(t *testing.T) { + t.Parallel() + + encoderCalled := false + decoderCalled := false + + app := New(Config{ + SharedStorage: newSharedStateMemoryStorage(t), + CBOREncoder: func(_ any) ([]byte, error) { + encoderCalled = true + return []byte("cbor-payload"), nil + }, + CBORDecoder: func(data []byte, out any) error { + decoderCalled = true + ptr, ok := out.(*string) + if ok { + *ptr = string(data) + } + return nil + }, + }) + + require.NoError(t, app.SharedState().SetCBOR("codec", Map{"ignored": true}, time.Minute)) + + var out string + raw, found, err := app.SharedState().GetCBOR("codec", &out) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, []byte("cbor-payload"), raw) + require.Equal(t, "cbor-payload", out) + require.True(t, encoderCalled) + require.True(t, decoderCalled) +} + +func TestSharedState_UsesAppXMLCodec(t *testing.T) { + t.Parallel() + + encoderCalled := false + decoderCalled := false + + app := New(Config{ + SharedStorage: newSharedStateMemoryStorage(t), + XMLEncoder: func(_ any) ([]byte, error) { + encoderCalled = true + return []byte("xml-payload"), nil + }, + XMLDecoder: func(data []byte, out any) error { + decoderCalled = true + ptr, ok := out.(*string) + if ok { + *ptr = string(data) + } + return nil + }, + }) + + require.NoError(t, app.SharedState().SetXML("codec", Map{"ignored": true}, time.Minute)) + + var out string + raw, found, err := app.SharedState().GetXML("codec", &out) + require.NoError(t, err) + require.True(t, found) + require.Equal(t, []byte("xml-payload"), raw) + require.Equal(t, "xml-payload", out) + require.True(t, encoderCalled) + require.True(t, decoderCalled) +} From 9083f3759d5e0c509695055f5089e81c3d3ac8f4 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Tue, 28 Apr 2026 00:47:18 -0400 Subject: [PATCH 2/9] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20simplify=20SharedSt?= =?UTF-8?q?ate=20key=20strategy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- shared_state.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/shared_state.go b/shared_state.go index af07f04b14f..ae41c72f95c 100644 --- a/shared_state.go +++ b/shared_state.go @@ -2,7 +2,6 @@ package fiber import ( "context" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -291,5 +290,5 @@ func (s *SharedState) HasWithContext(ctx context.Context, key string) (bool, err } func (s *SharedState) key(key string) string { - return utils.CopyString(s.prefix + hex.EncodeToString(utils.UnsafeBytes(key))) + return s.prefix + key } From a72c503f0418ebb605cd63944fed2a5188e03168 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Tue, 28 Apr 2026 09:35:43 -0400 Subject: [PATCH 3/9] =?UTF-8?q?=F0=9F=90=9B=20bug:=20preserve=20empty-key?= =?UTF-8?q?=20behavior=20in=20SharedState=20methods?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- shared_state.go | 16 ++++++++++++++++ shared_state_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/shared_state.go b/shared_state.go index ae41c72f95c..14116c3d915 100644 --- a/shared_state.go +++ b/shared_state.go @@ -104,6 +104,9 @@ func (s *SharedState) SetJSONWithContext(ctx context.Context, key string, v any, if s == nil || s.storage == nil { return ErrSharedStorageNotConfigured } + if key == "" { + return nil + } encoded, err := s.jsonEncoder(v) if err != nil { @@ -145,6 +148,9 @@ func (s *SharedState) SetMsgPackWithContext(ctx context.Context, key string, v a if s == nil || s.storage == nil { return ErrSharedStorageNotConfigured } + if key == "" { + return nil + } encoded, err := s.msgPackEncoder(v) if err != nil { @@ -186,6 +192,9 @@ func (s *SharedState) SetCBORWithContext(ctx context.Context, key string, v any, if s == nil || s.storage == nil { return ErrSharedStorageNotConfigured } + if key == "" { + return nil + } encoded, err := s.cborEncoder(v) if err != nil { @@ -227,6 +236,9 @@ func (s *SharedState) SetXMLWithContext(ctx context.Context, key string, v any, if s == nil || s.storage == nil { return ErrSharedStorageNotConfigured } + if key == "" { + return nil + } encoded, err := s.xmlEncoder(v) if err != nil { @@ -290,5 +302,9 @@ func (s *SharedState) HasWithContext(ctx context.Context, key string) (bool, err } func (s *SharedState) key(key string) string { + if key == "" { + return "" + } + return s.prefix + key } diff --git a/shared_state_test.go b/shared_state_test.go index 47ba12aa655..1862d1d072b 100644 --- a/shared_state_test.go +++ b/shared_state_test.go @@ -649,3 +649,42 @@ func TestSharedState_UsesAppXMLCodec(t *testing.T) { require.True(t, encoderCalled) require.True(t, decoderCalled) } + +func TestSharedState_EmptyKeyBehavior(t *testing.T) { + t.Parallel() + + app := New(Config{SharedStorage: newSharedStateMemoryStorage(t)}) + + require.NoError(t, app.SharedState().Set("", []byte("raw"), time.Minute)) + require.NoError(t, app.SharedState().SetJSON("", Map{"v": 1}, time.Minute)) + require.NoError(t, app.SharedState().SetMsgPack("", Map{"v": 1}, time.Minute)) + require.NoError(t, app.SharedState().SetCBOR("", Map{"v": 1}, time.Minute)) + require.NoError(t, app.SharedState().SetXML("", Map{"v": 1}, time.Minute)) + + raw, found, err := app.SharedState().Get("") + require.NoError(t, err) + require.Nil(t, raw) + require.False(t, found) + + _, found, err = app.SharedState().GetJSON("", &Map{}) + require.NoError(t, err) + require.False(t, found) + + _, found, err = app.SharedState().GetMsgPack("", &Map{}) + require.NoError(t, err) + require.False(t, found) + + _, found, err = app.SharedState().GetCBOR("", &Map{}) + require.NoError(t, err) + require.False(t, found) + + _, found, err = app.SharedState().GetXML("", &Map{}) + require.NoError(t, err) + require.False(t, found) + + require.NoError(t, app.SharedState().Delete("")) + + has, err := app.SharedState().Has("") + require.NoError(t, err) + require.False(t, has) +} From c057a7c13c9f8e4ea928b4eca8d91074321785f7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 30 Apr 2026 04:06:58 +0000 Subject: [PATCH 4/9] =?UTF-8?q?=F0=9F=90=9B=20bug:=20harden=20shared=20sta?= =?UTF-8?q?te=20codec=20and=20key=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/gofiber/fiber/sessions/ae846c95-4e83-4e0a-87ee-1b7b9c3deaa3 Co-authored-by: gaby <835733+gaby@users.noreply.github.com> --- app.go | 20 +-- docs/api/state.md | 6 + docs/whats_new.md | 2 +- shared_state.go | 338 +++++++++++++++++++++++++++---------------- shared_state_test.go | 54 ++++++- 5 files changed, 268 insertions(+), 152 deletions(-) diff --git a/app.go b/app.go index 57e66f815b0..01923004d4d 100644 --- a/app.go +++ b/app.go @@ -654,25 +654,7 @@ func New(config ...Config) *App { app.config.XMLDecoder = xml.Unmarshal } - sharedStatePrefix := app.config.SharedStatePrefix - if sharedStatePrefix == "" { - sharedStatePrefix = defaultSharedStatePrefix - if app.config.AppName != "" { - sharedStatePrefix += app.config.AppName + "-" - } - } - app.sharedState = newSharedState( - app.config.SharedStorage, - sharedStatePrefix, - app.config.JSONEncoder, - app.config.JSONDecoder, - app.config.MsgPackEncoder, - app.config.MsgPackDecoder, - app.config.CBOREncoder, - app.config.CBORDecoder, - app.config.XMLEncoder, - app.config.XMLDecoder, - ) + app.sharedState = newSharedState(app.config) if len(app.config.RequestMethods) == 0 { app.config.RequestMethods = DefaultMethods } diff --git a/docs/api/state.md b/docs/api/state.md index 74bfb885b38..a761f770ca5 100644 --- a/docs/api/state.md +++ b/docs/api/state.md @@ -26,6 +26,8 @@ app := fiber.New(fiber.Config{ If `SharedStatePrefix` is empty, Fiber derives a default namespace and includes `AppName` (when set) to reduce collisions between apps/services. +MsgPack and CBOR helpers require the corresponding `Config` encoders/decoders to be configured. If they are unavailable, the helper methods return an error instead of panicking. + :::warning Memory storage caveat `SharedState` is only cross-worker / cross-process when the configured `SharedStorage` backend is shared. @@ -72,6 +74,10 @@ func (s *SharedState) DeleteWithContext(ctx context.Context, key string) error func (s *SharedState) Has(key string) (bool, error) func (s *SharedState) HasWithContext(ctx context.Context, key string) (bool, error) + +func (s *SharedState) Reset() error +func (s *SharedState) ResetWithContext(ctx context.Context) error +func (s *SharedState) Close() error ``` ### SharedState Example diff --git a/docs/whats_new.md b/docs/whats_new.md index 03296712f3d..c0bcf9cafeb 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -84,7 +84,7 @@ We have made several changes to the Fiber app, including: - **RegisterCustomConstraint**: Allows for the registration of custom constraints. - **NewWithCustomCtx**: Initialize an app with a custom context in one step. - **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details. -- **SharedState**: Introduces storage-backed app state for prefork-safe/multi-process coordination via `Config.SharedStorage`, with optional `Config.SharedStatePrefix` namespacing and JSON/context-aware helpers (`SetJSON`, `GetJSON`, `Has`, `Delete`, and `WithContext` variants). +- **SharedState**: Introduces storage-backed app state for prefork-safe/multi-process coordination via `Config.SharedStorage`, with optional `Config.SharedStatePrefix` namespacing, codec-aware helpers (`SetJSON`, `SetMsgPack`, `SetCBOR`, `SetXML`, matching getters, and `WithContext` variants), empty-key no-op handling, and `Reset`/`Close` passthrough helpers. - **NewErrorf**: Allows variadic parameters when creating formatted errors. - **GetBytes / GetString**: Helpers that detach values only when `Immutable` is enabled and the data still references request or response buffers. Access via `c.App().GetString` and `c.App().GetBytes`. - **ReloadViews**: Lets you re-run the configured view engine's `Load()` logic at runtime, including guard rails for missing or nil view engines so development hot-reload hooks can refresh templates safely. diff --git a/shared_state.go b/shared_state.go index 14116c3d915..974d5849d01 100644 --- a/shared_state.go +++ b/shared_state.go @@ -3,6 +3,7 @@ package fiber import ( "context" "encoding/json" + "encoding/xml" "errors" "fmt" "time" @@ -28,36 +29,41 @@ type SharedState struct { } func newSharedState( - storage Storage, - prefix string, - jsonEncoder utils.JSONMarshal, - jsonDecoder utils.JSONUnmarshal, - msgPackEncoder utils.MsgPackMarshal, - msgPackDecoder utils.MsgPackUnmarshal, - cborEncoder utils.CBORMarshal, - cborDecoder utils.CBORUnmarshal, - xmlEncoder utils.XMLMarshal, - xmlDecoder utils.XMLUnmarshal, + cfg Config, ) *SharedState { + prefix := cfg.SharedStatePrefix if prefix == "" { prefix = defaultSharedStatePrefix + if cfg.AppName != "" { + prefix += cfg.AppName + "-" + } } + jsonEncoder := cfg.JSONEncoder if jsonEncoder == nil { jsonEncoder = json.Marshal } + jsonDecoder := cfg.JSONDecoder if jsonDecoder == nil { jsonDecoder = json.Unmarshal } + xmlEncoder := cfg.XMLEncoder + if xmlEncoder == nil { + xmlEncoder = xml.Marshal + } + xmlDecoder := cfg.XMLDecoder + if xmlDecoder == nil { + xmlDecoder = xml.Unmarshal + } return &SharedState{ - storage: storage, + storage: cfg.SharedStorage, jsonEncoder: jsonEncoder, jsonDecoder: jsonDecoder, - msgPackEncoder: msgPackEncoder, - msgPackDecoder: msgPackDecoder, - cborEncoder: cborEncoder, - cborDecoder: cborDecoder, + msgPackEncoder: cfg.MsgPackEncoder, + msgPackDecoder: cfg.MsgPackDecoder, + cborEncoder: cfg.CBOREncoder, + cborDecoder: cfg.CBORDecoder, xmlEncoder: xmlEncoder, xmlDecoder: xmlDecoder, prefix: prefix, @@ -69,11 +75,16 @@ func (s *SharedState) Set(key string, val []byte, ttl time.Duration) error { } func (s *SharedState) SetWithContext(ctx context.Context, key string, val []byte, ttl time.Duration) error { - if s == nil || s.storage == nil { - return ErrSharedStorageNotConfigured + if err := s.ensureStorage(); err != nil { + return err + } + + storageKey, ok := s.storageKey(key) + if !ok { + return nil } - return s.storage.SetWithContext(ctx, s.key(key), val, ttl) + return s.storage.SetWithContext(ctx, storageKey, val, ttl) } func (s *SharedState) Get(key string) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. @@ -81,11 +92,16 @@ func (s *SharedState) Get(key string) ([]byte, bool, error) { //nolint:gocritic } func (s *SharedState) GetWithContext(ctx context.Context, key string) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. - if s == nil || s.storage == nil { - return nil, false, ErrSharedStorageNotConfigured + if err := s.ensureStorage(); err != nil { + return nil, false, err + } + + storageKey, ok := s.storageKey(key) + if !ok { + return nil, false, nil } - data, err := s.storage.GetWithContext(ctx, s.key(key)) + data, err := s.storage.GetWithContext(ctx, storageKey) if err != nil { return nil, false, err } @@ -101,19 +117,11 @@ func (s *SharedState) SetJSON(key string, v any, ttl time.Duration) error { } func (s *SharedState) SetJSONWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { - if s == nil || s.storage == nil { - return ErrSharedStorageNotConfigured - } - if key == "" { - return nil - } - - encoded, err := s.jsonEncoder(v) - if err != nil { - return fmt.Errorf("fiber: failed to encode shared state value: %w", err) + if err := s.ensureStorage(); err != nil { + return err } - return s.storage.SetWithContext(ctx, s.key(key), encoded, ttl) + return s.setEncodedWithContext(ctx, key, v, ttl, s.jsonEncoder, "json") } func (s *SharedState) GetJSON(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. @@ -121,23 +129,11 @@ func (s *SharedState) GetJSON(key string, out any) ([]byte, bool, error) { //nol } func (s *SharedState) GetJSONWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. - if s == nil || s.storage == nil { - return nil, false, ErrSharedStorageNotConfigured - } - - data, err := s.storage.GetWithContext(ctx, s.key(key)) - if err != nil { + if err := s.ensureStorage(); err != nil { return nil, false, err } - if data == nil { - return nil, false, nil - } - if err := s.jsonDecoder(data, out); err != nil { - return nil, false, fmt.Errorf("fiber: failed to decode shared state value: %w", err) - } - - return append([]byte(nil), data...), true, nil + return s.getEncodedWithContext(ctx, key, out, s.jsonDecoder, "json") } func (s *SharedState) SetMsgPack(key string, v any, ttl time.Duration) error { @@ -145,19 +141,11 @@ func (s *SharedState) SetMsgPack(key string, v any, ttl time.Duration) error { } func (s *SharedState) SetMsgPackWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { - if s == nil || s.storage == nil { - return ErrSharedStorageNotConfigured - } - if key == "" { - return nil + if err := s.ensureStorage(); err != nil { + return err } - encoded, err := s.msgPackEncoder(v) - if err != nil { - return fmt.Errorf("fiber: failed to encode shared state msgpack value: %w", err) - } - - return s.storage.SetWithContext(ctx, s.key(key), encoded, ttl) + return s.setEncodedWithContext(ctx, key, v, ttl, s.msgPackEncoder, "msgpack") } func (s *SharedState) GetMsgPack(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. @@ -165,23 +153,11 @@ func (s *SharedState) GetMsgPack(key string, out any) ([]byte, bool, error) { // } func (s *SharedState) GetMsgPackWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. - if s == nil || s.storage == nil { - return nil, false, ErrSharedStorageNotConfigured - } - - data, err := s.storage.GetWithContext(ctx, s.key(key)) - if err != nil { + if err := s.ensureStorage(); err != nil { return nil, false, err } - if data == nil { - return nil, false, nil - } - if err := s.msgPackDecoder(data, out); err != nil { - return nil, false, fmt.Errorf("fiber: failed to decode shared state msgpack value: %w", err) - } - - return append([]byte(nil), data...), true, nil + return s.getEncodedWithContext(ctx, key, out, s.msgPackDecoder, "msgpack") } func (s *SharedState) SetCBOR(key string, v any, ttl time.Duration) error { @@ -189,19 +165,11 @@ func (s *SharedState) SetCBOR(key string, v any, ttl time.Duration) error { } func (s *SharedState) SetCBORWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { - if s == nil || s.storage == nil { - return ErrSharedStorageNotConfigured - } - if key == "" { - return nil - } - - encoded, err := s.cborEncoder(v) - if err != nil { - return fmt.Errorf("fiber: failed to encode shared state cbor value: %w", err) + if err := s.ensureStorage(); err != nil { + return err } - return s.storage.SetWithContext(ctx, s.key(key), encoded, ttl) + return s.setEncodedWithContext(ctx, key, v, ttl, s.cborEncoder, "cbor") } func (s *SharedState) GetCBOR(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. @@ -209,55 +177,146 @@ func (s *SharedState) GetCBOR(key string, out any) ([]byte, bool, error) { //nol } func (s *SharedState) GetCBORWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. - if s == nil || s.storage == nil { - return nil, false, ErrSharedStorageNotConfigured + if err := s.ensureStorage(); err != nil { + return nil, false, err } - data, err := s.storage.GetWithContext(ctx, s.key(key)) - if err != nil { + return s.getEncodedWithContext(ctx, key, out, s.cborDecoder, "cbor") +} + +func (s *SharedState) SetXML(key string, v any, ttl time.Duration) error { + return s.SetXMLWithContext(context.Background(), key, v, ttl) +} + +func (s *SharedState) SetXMLWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { + if err := s.ensureStorage(); err != nil { + return err + } + + return s.setEncodedWithContext(ctx, key, v, ttl, s.xmlEncoder, "xml") +} + +func (s *SharedState) GetXML(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + return s.GetXMLWithContext(context.Background(), key, out) +} + +func (s *SharedState) GetXMLWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. + if err := s.ensureStorage(); err != nil { return nil, false, err } - if data == nil { - return nil, false, nil + + return s.getEncodedWithContext(ctx, key, out, s.xmlDecoder, "xml") +} + +func (s *SharedState) Delete(key string) error { + return s.DeleteWithContext(context.Background(), key) +} + +func (s *SharedState) DeleteWithContext(ctx context.Context, key string) error { + if err := s.ensureStorage(); err != nil { + return err } - if err := s.cborDecoder(data, out); err != nil { - return nil, false, fmt.Errorf("fiber: failed to decode shared state cbor value: %w", err) + storageKey, ok := s.storageKey(key) + if !ok { + return nil } - return append([]byte(nil), data...), true, nil + return s.storage.DeleteWithContext(ctx, storageKey) } -func (s *SharedState) SetXML(key string, v any, ttl time.Duration) error { - return s.SetXMLWithContext(context.Background(), key, v, ttl) +func (s *SharedState) Has(key string) (bool, error) { + return s.HasWithContext(context.Background(), key) } -func (s *SharedState) SetXMLWithContext(ctx context.Context, key string, v any, ttl time.Duration) error { +func (s *SharedState) HasWithContext(ctx context.Context, key string) (bool, error) { + if err := s.ensureStorage(); err != nil { + return false, err + } + + storageKey, ok := s.storageKey(key) + if !ok { + return false, nil + } + + data, err := s.storage.GetWithContext(ctx, storageKey) + if err != nil { + return false, err + } + + return data != nil, nil +} + +func (s *SharedState) Reset() error { + return s.ResetWithContext(context.Background()) +} + +func (s *SharedState) ResetWithContext(ctx context.Context) error { + if err := s.ensureStorage(); err != nil { + return err + } + + return s.storage.ResetWithContext(ctx) +} + +func (s *SharedState) Close() error { + if err := s.ensureStorage(); err != nil { + return err + } + + return s.storage.Close() +} + +func (s *SharedState) ensureStorage() error { if s == nil || s.storage == nil { return ErrSharedStorageNotConfigured } - if key == "" { + + return nil +} + +func (s *SharedState) setEncodedWithContext( + ctx context.Context, + key string, + v any, + ttl time.Duration, + encoder func(any) ([]byte, error), + format string, +) error { + if err := s.ensureStorage(); err != nil { + return err + } + + storageKey, ok := s.storageKey(key) + if !ok { return nil } - encoded, err := s.xmlEncoder(v) + encoded, err := encodeSharedStateValue(v, encoder, format) if err != nil { - return fmt.Errorf("fiber: failed to encode shared state xml value: %w", err) + return err } - return s.storage.SetWithContext(ctx, s.key(key), encoded, ttl) + return s.storage.SetWithContext(ctx, storageKey, encoded, ttl) } -func (s *SharedState) GetXML(key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. - return s.GetXMLWithContext(context.Background(), key, out) -} +func (s *SharedState) getEncodedWithContext( + ctx context.Context, + key string, + out any, + decoder func([]byte, any) error, + format string, +) ([]byte, bool, error) { + if err := s.ensureStorage(); err != nil { + return nil, false, err + } -func (s *SharedState) GetXMLWithContext(ctx context.Context, key string, out any) ([]byte, bool, error) { //nolint:gocritic // Keep unnamed returns for clarity. - if s == nil || s.storage == nil { - return nil, false, ErrSharedStorageNotConfigured + storageKey, ok := s.storageKey(key) + if !ok { + return nil, false, nil } - data, err := s.storage.GetWithContext(ctx, s.key(key)) + data, err := s.storage.GetWithContext(ctx, storageKey) if err != nil { return nil, false, err } @@ -265,46 +324,71 @@ func (s *SharedState) GetXMLWithContext(ctx context.Context, key string, out any return nil, false, nil } - if err := s.xmlDecoder(data, out); err != nil { - return nil, false, fmt.Errorf("fiber: failed to decode shared state xml value: %w", err) + if err := decodeSharedStateValue(data, out, decoder, format); err != nil { + return nil, false, err } return append([]byte(nil), data...), true, nil } -func (s *SharedState) Delete(key string) error { - return s.DeleteWithContext(context.Background(), key) -} +func encodeSharedStateValue(v any, encoder func(any) ([]byte, error), format string) (encoded []byte, err error) { + if encoder == nil { + return nil, sharedStateCodecNotConfiguredError(format, "encoder") + } -func (s *SharedState) DeleteWithContext(ctx context.Context, key string) error { - if s == nil || s.storage == nil { - return ErrSharedStorageNotConfigured + defer func() { + if recovered := recover(); recovered != nil { + err = sharedStateCodecPanicError("encode", format, recovered) + } + }() + + encoded, err = encoder(v) + if err != nil { + return nil, fmt.Errorf("fiber: failed to encode shared state %s value: %w", format, err) } - return s.storage.DeleteWithContext(ctx, s.key(key)) + return encoded, nil } -func (s *SharedState) Has(key string) (bool, error) { - return s.HasWithContext(context.Background(), key) -} +func decodeSharedStateValue(data []byte, out any, decoder func([]byte, any) error, format string) (err error) { + if decoder == nil { + return sharedStateCodecNotConfiguredError(format, "decoder") + } -func (s *SharedState) HasWithContext(ctx context.Context, key string) (bool, error) { - if s == nil || s.storage == nil { - return false, ErrSharedStorageNotConfigured + defer func() { + if recovered := recover(); recovered != nil { + err = sharedStateCodecPanicError("decode", format, recovered) + } + }() + + if err = decoder(data, out); err != nil { + return fmt.Errorf("fiber: failed to decode shared state %s value: %w", format, err) } - data, err := s.storage.GetWithContext(ctx, s.key(key)) - if err != nil { - return false, err + return nil +} + +func sharedStateCodecNotConfiguredError(format, direction string) error { + return fmt.Errorf("fiber: shared state %s %s is not configured", format, direction) +} + +func sharedStateCodecPanicError(operation, format string, recovered any) error { + if err, ok := recovered.(error); ok { + return fmt.Errorf("fiber: failed to %s shared state %s value: %w", operation, format, err) } - return data != nil, nil + return fmt.Errorf("fiber: failed to %s shared state %s value: %v", operation, format, recovered) } -func (s *SharedState) key(key string) string { +func (s *SharedState) storageKey(key string) (string, bool) { if key == "" { - return "" + return "", false } - return s.prefix + key + return s.prefix + key, true +} + +func (s *SharedState) key(key string) string { + storageKey, _ := s.storageKey(key) + return storageKey } diff --git a/shared_state_test.go b/shared_state_test.go index 1862d1d072b..a449c0ea65f 100644 --- a/shared_state_test.go +++ b/shared_state_test.go @@ -28,7 +28,8 @@ type contextCheckingStorage struct { } type errorStorage struct { - err error + err error + closeErr error } func (s *errorStorage) GetWithContext(context.Context, string) ([]byte, error) { @@ -63,8 +64,8 @@ func (s *errorStorage) Reset() error { return s.err } -func (*errorStorage) Close() error { - return nil +func (s *errorStorage) Close() error { + return s.closeErr } func (s *contextCheckingStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { @@ -153,6 +154,12 @@ func TestSharedState_NotConfigured(t *testing.T) { has, err := app.SharedState().Has("key") require.False(t, has) require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + err = app.SharedState().Reset() + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + err = app.SharedState().Close() + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) } func TestSharedState_PreforkSafeWithSharedStorage(t *testing.T) { @@ -280,7 +287,7 @@ func TestSharedState_StorageErrorsArePropagated(t *testing.T) { expectedErr := errors.New("storage failed") app := New(Config{ - SharedStorage: &errorStorage{err: expectedErr}, + SharedStorage: &errorStorage{err: expectedErr, closeErr: expectedErr}, MsgPackEncoder: func(any) ([]byte, error) { return []byte("msgpack"), nil }, @@ -330,6 +337,12 @@ func TestSharedState_StorageErrorsArePropagated(t *testing.T) { err = app.SharedState().Delete("k") require.ErrorIs(t, err, expectedErr) + + err = app.SharedState().Reset() + require.ErrorIs(t, err, expectedErr) + + err = app.SharedState().Close() + require.ErrorIs(t, err, expectedErr) } func TestSharedState_NilReceiver(t *testing.T) { @@ -366,12 +379,18 @@ func TestSharedState_NilReceiver(t *testing.T) { _, err = state.Has("k") require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + err = state.Reset() + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) + + err = state.Close() + require.ErrorIs(t, err, ErrSharedStorageNotConfigured) } func TestSharedState_DefaultPrefixFallback(t *testing.T) { t.Parallel() - state := newSharedState(newSharedStateMemoryStorage(t), "", nil, nil, nil, nil, nil, nil, nil, nil) + state := newSharedState(Config{SharedStorage: newSharedStateMemoryStorage(t)}) require.Equal(t, defaultSharedStatePrefix, state.prefix) } @@ -582,6 +601,31 @@ func TestSharedState_UsesAppMsgPackCodec(t *testing.T) { require.True(t, decoderCalled) } +func TestSharedState_UnconfiguredCodecsReturnErrorInsteadOfPanic(t *testing.T) { + t.Parallel() + + app := New(Config{SharedStorage: newSharedStateMemoryStorage(t)}) + + err := app.SharedState().SetMsgPack("codec", Map{"ignored": true}, time.Minute) + require.ErrorContains(t, err, "shared state msgpack") + + require.NoError(t, app.SharedState().Set("msgpack-payload", []byte("payload"), time.Minute)) + + var out Map + _, found, err := app.SharedState().GetMsgPack("msgpack-payload", &out) + require.False(t, found) + require.ErrorContains(t, err, "shared state msgpack") + + err = app.SharedState().SetCBOR("codec", Map{"ignored": true}, time.Minute) + require.ErrorContains(t, err, "shared state cbor") + + require.NoError(t, app.SharedState().Set("cbor-payload", []byte("payload"), time.Minute)) + + _, found, err = app.SharedState().GetCBOR("cbor-payload", &out) + require.False(t, found) + require.ErrorContains(t, err, "shared state cbor") +} + func TestSharedState_UsesAppCBORCodec(t *testing.T) { t.Parallel() From a4946e981be7e77220550202d5363f42791b8b88 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 30 Apr 2026 04:14:54 +0000 Subject: [PATCH 5/9] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20polish=20shared=20s?= =?UTF-8?q?tate=20review=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/gofiber/fiber/sessions/ae846c95-4e83-4e0a-87ee-1b7b9c3deaa3 Co-authored-by: gaby <835733+gaby@users.noreply.github.com> --- app.go | 2 +- shared_state.go | 62 ++++++++++++++++++++++++++++++-------------- shared_state_test.go | 11 +++++--- 3 files changed, 50 insertions(+), 25 deletions(-) diff --git a/app.go b/app.go index 01923004d4d..dbb821e95e9 100644 --- a/app.go +++ b/app.go @@ -654,7 +654,7 @@ func New(config ...Config) *App { app.config.XMLDecoder = xml.Unmarshal } - app.sharedState = newSharedState(app.config) + app.sharedState = newSharedState(&app.config) if len(app.config.RequestMethods) == 0 { app.config.RequestMethods = DefaultMethods } diff --git a/shared_state.go b/shared_state.go index 974d5849d01..2f06fb50a64 100644 --- a/shared_state.go +++ b/shared_state.go @@ -28,9 +28,11 @@ type SharedState struct { prefix string } -func newSharedState( - cfg Config, -) *SharedState { +func newSharedState(cfg *Config) *SharedState { + if cfg == nil { + cfg = &Config{} + } + prefix := cfg.SharedStatePrefix if prefix == "" { prefix = defaultSharedStatePrefix @@ -300,6 +302,7 @@ func (s *SharedState) setEncodedWithContext( return s.storage.SetWithContext(ctx, storageKey, encoded, ttl) } +//nolint:gocritic // Keep unnamed returns for clarity. func (s *SharedState) getEncodedWithContext( ctx context.Context, key string, @@ -331,18 +334,30 @@ func (s *SharedState) getEncodedWithContext( return append([]byte(nil), data...), true, nil } -func encodeSharedStateValue(v any, encoder func(any) ([]byte, error), format string) (encoded []byte, err error) { +func encodeSharedStateValue(v any, encoder func(any) ([]byte, error), format string) ([]byte, error) { if encoder == nil { return nil, sharedStateCodecNotConfiguredError(format, "encoder") } - defer func() { - if recovered := recover(); recovered != nil { - err = sharedStateCodecPanicError("encode", format, recovered) - } + var ( + encoded []byte + err error + recovered any + ) + func() { + // App-configured codecs may be nil or may still use Fiber's + // binder.Unimplemented* placeholders, which panic instead of returning an + // error, so recover here and surface a regular error. + defer func() { + recovered = recover() + }() + + encoded, err = encoder(v) }() - encoded, err = encoder(v) + if recovered != nil { + return nil, sharedStateCodecPanicError("encode", format, recovered) + } if err != nil { return nil, fmt.Errorf("fiber: failed to encode shared state %s value: %w", format, err) } @@ -350,18 +365,30 @@ func encodeSharedStateValue(v any, encoder func(any) ([]byte, error), format str return encoded, nil } -func decodeSharedStateValue(data []byte, out any, decoder func([]byte, any) error, format string) (err error) { +func decodeSharedStateValue(data []byte, out any, decoder func([]byte, any) error, format string) error { if decoder == nil { return sharedStateCodecNotConfiguredError(format, "decoder") } - defer func() { - if recovered := recover(); recovered != nil { - err = sharedStateCodecPanicError("decode", format, recovered) - } + var ( + err error + recovered any + ) + func() { + // App-configured codecs may be nil or may still use Fiber's + // binder.Unimplemented* placeholders, which panic instead of returning an + // error, so recover here and surface a regular error. + defer func() { + recovered = recover() + }() + + err = decoder(data, out) }() - if err = decoder(data, out); err != nil { + if recovered != nil { + return sharedStateCodecPanicError("decode", format, recovered) + } + if err != nil { return fmt.Errorf("fiber: failed to decode shared state %s value: %w", format, err) } @@ -387,8 +414,3 @@ func (s *SharedState) storageKey(key string) (string, bool) { return s.prefix + key, true } - -func (s *SharedState) key(key string) string { - storageKey, _ := s.storageKey(key) - return storageKey -} diff --git a/shared_state_test.go b/shared_state_test.go index a449c0ea65f..0fe0437694b 100644 --- a/shared_state_test.go +++ b/shared_state_test.go @@ -286,8 +286,9 @@ func TestSharedState_StorageErrorsArePropagated(t *testing.T) { t.Parallel() expectedErr := errors.New("storage failed") + closeErr := errors.New("close failed") app := New(Config{ - SharedStorage: &errorStorage{err: expectedErr, closeErr: expectedErr}, + SharedStorage: &errorStorage{err: expectedErr, closeErr: closeErr}, MsgPackEncoder: func(any) ([]byte, error) { return []byte("msgpack"), nil }, @@ -342,7 +343,7 @@ func TestSharedState_StorageErrorsArePropagated(t *testing.T) { require.ErrorIs(t, err, expectedErr) err = app.SharedState().Close() - require.ErrorIs(t, err, expectedErr) + require.ErrorIs(t, err, closeErr) } func TestSharedState_NilReceiver(t *testing.T) { @@ -390,7 +391,7 @@ func TestSharedState_NilReceiver(t *testing.T) { func TestSharedState_DefaultPrefixFallback(t *testing.T) { t.Parallel() - state := newSharedState(Config{SharedStorage: newSharedStateMemoryStorage(t)}) + state := newSharedState(&Config{SharedStorage: newSharedStateMemoryStorage(t)}) require.Equal(t, defaultSharedStatePrefix, state.prefix) } @@ -407,7 +408,9 @@ func TestSharedState_GetJSON_UnmarshalError(t *testing.T) { store := newSharedStateMemoryStorage(t) app := New(Config{SharedStorage: store}) - require.NoError(t, store.Set(app.SharedState().key("broken"), []byte("{"), 0)) + storageKey, ok := app.SharedState().storageKey("broken") + require.True(t, ok) + require.NoError(t, store.Set(storageKey, []byte("{"), 0)) var out map[string]any _, found, err := app.SharedState().GetJSON("broken", &out) From 26547f41ca69e9cd550346b08c044ea04e87505c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 11:49:25 +0000 Subject: [PATCH 6/9] =?UTF-8?q?=F0=9F=90=9B=20bug:=20close=20shared=20stat?= =?UTF-8?q?e=20on=20app=20shutdown?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/gofiber/fiber/sessions/7d092650-1040-458d-9f36-30712bdc39a0 Co-authored-by: gaby <835733+gaby@users.noreply.github.com> --- app.go | 5 +++ services_test.go | 93 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) diff --git a/app.go b/app.go index dbb821e95e9..124f83c071e 100644 --- a/app.go +++ b/app.go @@ -1386,6 +1386,11 @@ func (app *App) init() *App { if err := app.shutdownServices(app.servicesShutdownCtx()); err != nil { log.Errorf("failed to shutdown services: %v", err) } + if app.sharedState != nil && app.sharedState.storage != nil { + if err := app.sharedState.Close(); err != nil { + log.Errorf("failed to close sharedState: %v", err) + } + } return nil }) diff --git a/services_test.go b/services_test.go index 5e29063bb95..96bea13303c 100644 --- a/services_test.go +++ b/services_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strings" + "sync/atomic" "testing" "time" @@ -30,6 +31,48 @@ type mockService struct { terminateDelay time.Duration } +type shutdownHookStorage struct { + closeErr error + closeCalled atomic.Bool +} + +func (*shutdownHookStorage) GetWithContext(context.Context, string) ([]byte, error) { + return nil, nil +} + +func (*shutdownHookStorage) Get(string) ([]byte, error) { + return nil, nil +} + +func (*shutdownHookStorage) SetWithContext(context.Context, string, []byte, time.Duration) error { + return nil +} + +func (*shutdownHookStorage) Set(string, []byte, time.Duration) error { + return nil +} + +func (*shutdownHookStorage) DeleteWithContext(context.Context, string) error { + return nil +} + +func (*shutdownHookStorage) Delete(string) error { + return nil +} + +func (*shutdownHookStorage) ResetWithContext(context.Context) error { + return nil +} + +func (*shutdownHookStorage) Reset() error { + return nil +} + +func (s *shutdownHookStorage) Close() error { + s.closeCalled.Store(true) + return s.closeErr +} + func (m *mockService) Start(ctx context.Context) error { select { case <-ctx.Done(): @@ -213,6 +256,56 @@ func Test_InitServices(t *testing.T) { require.Contains(t, buf.String(), "failed to shutdown services: service dep2 terminate: terminate error 2") }) + + t.Run("shutdown-hooks/close-shared-state", func(t *testing.T) { + storage := &shutdownHookStorage{} + app := New(Config{ + Services: []Service{&mockService{name: "dep1"}}, + SharedStorage: storage, + }) + + require.NotPanics(t, app.initServices) + + type stringsLogger struct { + strings.Builder + } + + var buf stringsLogger + log.SetOutput(&buf) + t.Cleanup(func() { log.SetOutput(bytes.NewBuffer(nil)) }) + + app.Hooks().executeOnPostShutdownHooks(nil) + + require.True(t, storage.closeCalled.Load()) + require.NotContains(t, buf.String(), "failed to close sharedState:") + }) + + t.Run("shutdown-hooks/close-shared-state-after-service-error", func(t *testing.T) { + storage := &shutdownHookStorage{closeErr: errors.New("close error")} + app := New(Config{ + Services: []Service{ + &mockService{name: "dep1"}, + &mockService{name: "dep2", terminateError: errors.New(terminateErrorMessage + " 2")}, + }, + SharedStorage: storage, + }) + + require.NotPanics(t, app.initServices) + + type stringsLogger struct { + strings.Builder + } + + var buf stringsLogger + log.SetOutput(&buf) + t.Cleanup(func() { log.SetOutput(bytes.NewBuffer(nil)) }) + + app.Hooks().executeOnPostShutdownHooks(nil) + + require.True(t, storage.closeCalled.Load()) + require.Contains(t, buf.String(), "failed to shutdown services: service dep2 terminate: terminate error 2") + require.Contains(t, buf.String(), "failed to close sharedState: close error") + }) } func Test_StartServices(t *testing.T) { From 34c4bc91e246181db6b66e38208003eccb785e36 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 11:50:51 +0000 Subject: [PATCH 7/9] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20clean=20shutdown=20?= =?UTF-8?q?hook=20test=20helpers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/gofiber/fiber/sessions/7d092650-1040-458d-9f36-30712bdc39a0 Co-authored-by: gaby <835733+gaby@users.noreply.github.com> --- services_test.go | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/services_test.go b/services_test.go index 96bea13303c..c419bbbbd31 100644 --- a/services_test.go +++ b/services_test.go @@ -36,6 +36,10 @@ type shutdownHookStorage struct { closeCalled atomic.Bool } +type stringsLogger struct { + strings.Builder +} + func (*shutdownHookStorage) GetWithContext(context.Context, string) ([]byte, error) { return nil, nil } @@ -221,10 +225,6 @@ func Test_InitServices(t *testing.T) { require.NotPanics(t, app.initServices) - type stringsLogger struct { - strings.Builder - } - var buf stringsLogger log.SetOutput(&buf) @@ -245,10 +245,6 @@ func Test_InitServices(t *testing.T) { require.NotPanics(t, app.initServices) - type stringsLogger struct { - strings.Builder - } - var buf stringsLogger log.SetOutput(&buf) @@ -266,10 +262,6 @@ func Test_InitServices(t *testing.T) { require.NotPanics(t, app.initServices) - type stringsLogger struct { - strings.Builder - } - var buf stringsLogger log.SetOutput(&buf) t.Cleanup(func() { log.SetOutput(bytes.NewBuffer(nil)) }) @@ -292,10 +284,6 @@ func Test_InitServices(t *testing.T) { require.NotPanics(t, app.initServices) - type stringsLogger struct { - strings.Builder - } - var buf stringsLogger log.SetOutput(&buf) t.Cleanup(func() { log.SetOutput(bytes.NewBuffer(nil)) }) From 72fc780bbbb4bc163f977fe64853ebf0c2c1a4fd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 12:03:11 +0000 Subject: [PATCH 8/9] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20isolate=20shutdown?= =?UTF-8?q?=20hook=20logger=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/gofiber/fiber/sessions/e5263c7d-b450-40dd-9aaf-31d81e0efb5d Co-authored-by: gaby <835733+gaby@users.noreply.github.com> --- services_test.go | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/services_test.go b/services_test.go index c419bbbbd31..1f9485d6cd7 100644 --- a/services_test.go +++ b/services_test.go @@ -5,7 +5,9 @@ import ( "context" "errors" "fmt" + stdlog "log" //nolint:depguard // Test needs the concrete stdlib logger type to restore the previous writer. "strings" + "sync" "sync/atomic" "testing" "time" @@ -40,6 +42,8 @@ type stringsLogger struct { strings.Builder } +var servicesTestLogOutputMu sync.Mutex + func (*shutdownHookStorage) GetWithContext(context.Context, string) ([]byte, error) { return nil, nil } @@ -174,6 +178,8 @@ func Test_HasConfiguredServices(t *testing.T) { } func Test_InitServices(t *testing.T) { + t.Parallel() + t.Run("no-services", func(t *testing.T) { app := &App{configured: Config{}} require.NotPanics(t, app.initServices) @@ -254,6 +260,8 @@ func Test_InitServices(t *testing.T) { }) t.Run("shutdown-hooks/close-shared-state", func(t *testing.T) { + t.Parallel() + storage := &shutdownHookStorage{} app := New(Config{ Services: []Service{&mockService{name: "dep1"}}, @@ -263,8 +271,13 @@ func Test_InitServices(t *testing.T) { require.NotPanics(t, app.initServices) var buf stringsLogger + currentOutput := log.DefaultLogger[*stdlog.Logger]().Logger().Writer() + servicesTestLogOutputMu.Lock() log.SetOutput(&buf) - t.Cleanup(func() { log.SetOutput(bytes.NewBuffer(nil)) }) + t.Cleanup(func() { + log.SetOutput(currentOutput) + servicesTestLogOutputMu.Unlock() + }) app.Hooks().executeOnPostShutdownHooks(nil) @@ -273,6 +286,8 @@ func Test_InitServices(t *testing.T) { }) t.Run("shutdown-hooks/close-shared-state-after-service-error", func(t *testing.T) { + t.Parallel() + storage := &shutdownHookStorage{closeErr: errors.New("close error")} app := New(Config{ Services: []Service{ @@ -285,8 +300,13 @@ func Test_InitServices(t *testing.T) { require.NotPanics(t, app.initServices) var buf stringsLogger + currentOutput := log.DefaultLogger[*stdlog.Logger]().Logger().Writer() + servicesTestLogOutputMu.Lock() log.SetOutput(&buf) - t.Cleanup(func() { log.SetOutput(bytes.NewBuffer(nil)) }) + t.Cleanup(func() { + log.SetOutput(currentOutput) + servicesTestLogOutputMu.Unlock() + }) app.Hooks().executeOnPostShutdownHooks(nil) From 6543581ed641d54270588298958aefd969d27306 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 1 May 2026 12:09:32 +0000 Subject: [PATCH 9/9] =?UTF-8?q?=F0=9F=A7=B9=20chore:=20harden=20shutdown?= =?UTF-8?q?=20hook=20logger=20helper?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Agent-Logs-Url: https://github.com/gofiber/fiber/sessions/e5263c7d-b450-40dd-9aaf-31d81e0efb5d Co-authored-by: gaby <835733+gaby@users.noreply.github.com> --- services_test.go | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/services_test.go b/services_test.go index 1f9485d6cd7..f03db3333e0 100644 --- a/services_test.go +++ b/services_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "io" stdlog "log" //nolint:depguard // Test needs the concrete stdlib logger type to restore the previous writer. "strings" "sync" @@ -44,6 +45,27 @@ type stringsLogger struct { var servicesTestLogOutputMu sync.Mutex +func withCapturedLogOutput(t *testing.T, writer io.Writer) { + t.Helper() + + servicesTestLogOutputMu.Lock() + cleanupRegistered := false + defer func() { + if !cleanupRegistered { + servicesTestLogOutputMu.Unlock() + } + }() + + currentOutput := log.DefaultLogger[*stdlog.Logger]().Logger().Writer() + t.Cleanup(func() { + log.SetOutput(currentOutput) + servicesTestLogOutputMu.Unlock() + }) + cleanupRegistered = true + + log.SetOutput(writer) +} + func (*shutdownHookStorage) GetWithContext(context.Context, string) ([]byte, error) { return nil, nil } @@ -271,13 +293,7 @@ func Test_InitServices(t *testing.T) { require.NotPanics(t, app.initServices) var buf stringsLogger - currentOutput := log.DefaultLogger[*stdlog.Logger]().Logger().Writer() - servicesTestLogOutputMu.Lock() - log.SetOutput(&buf) - t.Cleanup(func() { - log.SetOutput(currentOutput) - servicesTestLogOutputMu.Unlock() - }) + withCapturedLogOutput(t, &buf) app.Hooks().executeOnPostShutdownHooks(nil) @@ -300,13 +316,7 @@ func Test_InitServices(t *testing.T) { require.NotPanics(t, app.initServices) var buf stringsLogger - currentOutput := log.DefaultLogger[*stdlog.Logger]().Logger().Writer() - servicesTestLogOutputMu.Lock() - log.SetOutput(&buf) - t.Cleanup(func() { - log.SetOutput(currentOutput) - servicesTestLogOutputMu.Unlock() - }) + withCapturedLogOutput(t, &buf) app.Hooks().executeOnPostShutdownHooks(nil)