From 32555655d9debc094621d209a8da768f66c74823 Mon Sep 17 00:00:00 2001 From: James M Snell Date: Mon, 22 Dec 2025 15:45:51 -0800 Subject: [PATCH] Add state machine utility Separated out from https://github.com/cloudflare/workerd/pull/5670 --- src/workerd/api/BUILD.bazel | 1 + src/workerd/io/BUILD.bazel | 1 + src/workerd/util/BUILD.bazel | 12 + src/workerd/util/state-machine-test.c++ | 871 +++++++++ src/workerd/util/state-machine.h | 2148 +++++++++++++++++++++++ 5 files changed, 3033 insertions(+) create mode 100644 src/workerd/util/state-machine-test.c++ create mode 100644 src/workerd/util/state-machine.h diff --git a/src/workerd/api/BUILD.bazel b/src/workerd/api/BUILD.bazel index 4c51b354e8c..63663697197 100644 --- a/src/workerd/api/BUILD.bazel +++ b/src/workerd/api/BUILD.bazel @@ -248,6 +248,7 @@ wd_cc_library( visibility = ["//visibility:public"], deps = [ "//src/workerd/io", + "//src/workerd/util:state-machine", "@nbytes", ], ) diff --git a/src/workerd/io/BUILD.bazel b/src/workerd/io/BUILD.bazel index 02e46772056..23790a66924 100644 --- a/src/workerd/io/BUILD.bazel +++ b/src/workerd/io/BUILD.bazel @@ -116,6 +116,7 @@ wd_cc_library( "//src/workerd/util:ring-buffer", "//src/workerd/util:small-set", "//src/workerd/util:sqlite", + "//src/workerd/util:state-machine", "//src/workerd/util:strong-bool", "@capnp-cpp//src/capnp:capnp-rpc", "@capnp-cpp//src/capnp/compat:http-over-capnp", diff --git a/src/workerd/util/BUILD.bazel b/src/workerd/util/BUILD.bazel index 49a61219a7e..fc8e1dc8936 100644 --- a/src/workerd/util/BUILD.bazel +++ b/src/workerd/util/BUILD.bazel @@ -387,3 +387,15 @@ kj_test( src = "small-set-test.c++", deps = [":small-set"], ) + +wd_cc_library( + name = "state-machine", + hdrs = ["state-machine.h"], + visibility = ["//visibility:public"], + deps = ["@capnp-cpp//src/kj"], +) + +kj_test( + src = "state-machine-test.c++", + deps = [":state-machine"], +) diff --git a/src/workerd/util/state-machine-test.c++ b/src/workerd/util/state-machine-test.c++ new file 
mode 100644 index 00000000000..24cf32b52fd --- /dev/null +++ b/src/workerd/util/state-machine-test.c++ @@ -0,0 +1,871 @@ +// Copyright (c) 2017-2025 Cloudflare, Inc. +// Licensed under the Apache 2.0 license found in the LICENSE file or at: +// https://opensource.org/licenses/Apache-2.0 + +#include "state-machine.h" + +#include + +// Entire test file was Claude-generated initially. + +namespace workerd { +namespace { + +// ============================================================================= +// Test State Types +// ============================================================================= + +struct Idle { + static constexpr kj::StringPtr NAME KJ_UNUSED = "idle"_kj; + bool initialized = false; +}; + +struct Running { + static constexpr kj::StringPtr NAME KJ_UNUSED = "running"_kj; + kj::String taskName; + int progress = 0; + + Running() = default; + explicit Running(kj::String name): taskName(kj::mv(name)) {} +}; + +struct Completed { + static constexpr kj::StringPtr NAME KJ_UNUSED = "completed"_kj; + int result; + + explicit Completed(int r): result(r) {} +}; + +struct Failed { + static constexpr kj::StringPtr NAME KJ_UNUSED = "failed"_kj; + kj::String error; + + explicit Failed(kj::String err): error(kj::mv(err)) {} +}; + +// ============================================================================= +// Basic StateMachine Tests +// ============================================================================= + +KJ_TEST("StateMachine: basic state checks") { + auto machine = StateMachine::create(); + + // Initialized to Idle via create() + KJ_EXPECT(machine.isInitialized()); + KJ_EXPECT(machine.is()); + KJ_EXPECT(!machine.is()); +} + +KJ_TEST("StateMachine: state data access") { + auto machine = + StateMachine::create(kj::str("my-task")); + + KJ_EXPECT(machine.is()); + auto& running = machine.getUnsafe(); + KJ_EXPECT(running.taskName == "my-task"); + KJ_EXPECT(running.progress == 0); + + // Modify state data + running.progress = 50; + 
KJ_EXPECT(machine.getUnsafe().progress == 50); +} + +KJ_TEST("StateMachine: tryGet returns none for wrong state") { + auto machine = StateMachine::create(); + + // tryGet for correct state + KJ_IF_SOME(idle, machine.tryGetUnsafe()) { + KJ_EXPECT(!idle.initialized); + } else { + KJ_FAIL_EXPECT("Should have gotten Idle state"); + } + + // tryGet for wrong state + KJ_EXPECT(machine.tryGetUnsafe() == kj::none); + KJ_EXPECT(machine.tryGetUnsafe() == kj::none); +} + +KJ_TEST("StateMachine: isAnyOf checks multiple states") { + auto machine = StateMachine::create(42); + + // Use local variables to avoid KJ_EXPECT macro parsing issues with template brackets + bool isCompletedOrFailed = machine.isAnyOf(); + bool isIdleOrRunning = machine.isAnyOf(); + KJ_EXPECT(isCompletedOrFailed); + KJ_EXPECT(!isIdleOrRunning); + + machine.transitionTo(kj::str("error")); + isCompletedOrFailed = machine.isAnyOf(); + isIdleOrRunning = machine.isAnyOf(); + KJ_EXPECT(isCompletedOrFailed); + KJ_EXPECT(!isIdleOrRunning); +} + +KJ_TEST("StateMachine: transitionFromTo with precondition") { + auto machine = StateMachine::create(); + + // Transition from wrong state fails + auto result1 = machine.transitionFromTo(42); + KJ_EXPECT(result1 == kj::none); + KJ_EXPECT(machine.is()); // Still in Idle + + // Transition from correct state succeeds + machine.transitionTo(kj::str("task")); + auto result2 = machine.transitionFromTo(100); + KJ_EXPECT(result2 != kj::none); + KJ_EXPECT(machine.is()); + KJ_EXPECT(machine.getUnsafe().result == 100); +} + +KJ_TEST("StateMachine: factory create") { + auto machine = StateMachine::create(kj::str("task")); + KJ_EXPECT(machine.is()); + KJ_EXPECT(machine.getUnsafe().taskName == "task"); +} + +// Tests for uninitialized state behavior have been removed since the default +// constructor is now private and state machines must be created via create<>(). 
+ +KJ_TEST("StateMachine: works with KJ_SWITCH_ONEOF") { + auto machine = StateMachine::create(kj::str("test")); + + kj::String result; + KJ_SWITCH_ONEOF(machine) { + KJ_CASE_ONEOF(idle, Idle) { + result = kj::str("idle"); + } + KJ_CASE_ONEOF(running, Running) { + result = kj::str("running: ", running.taskName); + } + KJ_CASE_ONEOF(completed, Completed) { + result = kj::str("completed: ", completed.result); + } + KJ_CASE_ONEOF(failed, Failed) { + result = kj::str("failed: ", failed.error); + } + } + + KJ_EXPECT(result == "running: test"); +} + +KJ_TEST("StateMachine: currentStateName introspection") { + auto machine = StateMachine::create(); + + // Each state + KJ_EXPECT(machine.currentStateName() == "idle"_kj); + + machine.transitionTo(kj::str("task")); + KJ_EXPECT(machine.currentStateName() == "running"_kj); + + machine.transitionTo(42); + KJ_EXPECT(machine.currentStateName() == "completed"_kj); + + machine.transitionTo(kj::str("error")); + KJ_EXPECT(machine.currentStateName() == "failed"_kj); +} + +// ============================================================================= +// Memory Safety Tests +// ============================================================================= + +KJ_TEST("StateMachine: whenState provides safe scoped access") { + auto machine = StateMachine::create(kj::str("task")); + + // whenState returns result and locks transitions + auto result = machine.whenState([](Running& r) { return r.taskName.size(); }); + KJ_EXPECT(result != kj::none); + KJ_EXPECT(KJ_ASSERT_NONNULL(result) == 4); + + // Returns none for wrong state + auto result2 = machine.whenState([](Idle& i) { return i.initialized; }); + KJ_EXPECT(result2 == kj::none); +} + +KJ_TEST("StateMachine: whenState blocks transitions during callback") { + auto machine = StateMachine::create(kj::str("task")); + + // Cannot transition while locked + auto tryTransitionInCallback = [&]() { + machine.whenState([&](Running&) { + // Attempting to transition while locked should throw + 
machine.transitionTo(42); + }); + }; + KJ_EXPECT_THROW_MESSAGE("transitions are locked", tryTransitionInCallback()); + + // State should still be Running (transition was blocked) + KJ_EXPECT(machine.is()); +} + +KJ_TEST("StateMachine: transition lock count is tracked") { + auto machine = StateMachine::create(); + + KJ_EXPECT(!machine.isTransitionLocked()); + + { + auto lock1 = machine.acquireTransitionLock(); + KJ_EXPECT(machine.isTransitionLocked()); + + { + auto lock2 = machine.acquireTransitionLock(); + KJ_EXPECT(machine.isTransitionLocked()); + } + + // Still locked after inner lock released + KJ_EXPECT(machine.isTransitionLocked()); + } + + // Fully unlocked + KJ_EXPECT(!machine.isTransitionLocked()); +} + +KJ_TEST("StateMachine: void whenState returns bool") { + auto machine = StateMachine::create(kj::str("task")); + + bool executed = false; + + // void callback returns true when executed + bool result = machine.whenState([&](Running&) { executed = true; }); + KJ_EXPECT(result == true); + KJ_EXPECT(executed); + + // void callback returns false when not in state + executed = false; + bool result2 = machine.whenState([&](Idle&) { executed = true; }); + KJ_EXPECT(result2 == false); + KJ_EXPECT(!executed); +} + +// ============================================================================= +// StateMachine Tests +// ============================================================================= + +// Test state types for resource lifecycle tests (TerminalStates, ErrorState, ActiveState, etc.) 
+struct Active { + static constexpr kj::StringPtr NAME KJ_UNUSED = "active"_kj; + kj::String resourceName; + explicit Active(kj::String name): resourceName(kj::mv(name)) {} +}; + +struct Closed { + static constexpr kj::StringPtr NAME KJ_UNUSED = "closed"_kj; +}; + +struct Errored { + static constexpr kj::StringPtr NAME KJ_UNUSED = "errored"_kj; + kj::String reason; + explicit Errored(kj::String r): reason(kj::mv(r)) {} +}; + +KJ_TEST("StateMachine: basic usage without specs") { + auto machine = StateMachine::create(kj::str("resource")); + + // Basic state operations work + KJ_EXPECT(machine.isInitialized()); + KJ_EXPECT(machine.is()); + KJ_EXPECT(machine.getUnsafe().resourceName == "resource"); + + machine.transitionTo(); + KJ_EXPECT(machine.is()); + + // Can transition back (no terminal enforcement without spec) + machine.transitionTo(kj::str("another")); + KJ_EXPECT(machine.is()); +} + +// Tests for uninitialized state behavior have been removed since the default +// constructor is now private and state machines must be created via create<>(). 
+ +KJ_TEST("StateMachine: with TerminalStates spec") { + auto machine = + StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + KJ_EXPECT(!machine.isTerminal()); + + machine.transitionTo(); + KJ_EXPECT(machine.isTerminal()); + + // Cannot transition from terminal state + auto tryTransition = [&]() { machine.transitionTo(kj::str("another")); }; + KJ_EXPECT_THROW_MESSAGE("Cannot transition from terminal state", tryTransition()); + + // But forceTransitionTo works + machine.forceTransitionTo(kj::str("forced")); + KJ_EXPECT(machine.is()); +} + +KJ_TEST("StateMachine: with ErrorState spec") { + auto machine = StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + KJ_EXPECT(!machine.isErrored()); + KJ_EXPECT(machine.tryGetErrorUnsafe() == kj::none); + + machine.transitionTo(kj::str("something went wrong")); + KJ_EXPECT(machine.isErrored()); + + KJ_IF_SOME(err, machine.tryGetErrorUnsafe()) { + KJ_EXPECT(err.reason == "something went wrong"); + } else { + KJ_FAIL_EXPECT("Should have gotten error"); + } + + KJ_EXPECT(machine.getErrorUnsafe().reason == "something went wrong"); +} + +KJ_TEST("StateMachine: with ActiveState spec") { + auto machine = StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + KJ_EXPECT(machine.isActive()); + KJ_EXPECT(!machine.isInactive()); + + KJ_IF_SOME(active, machine.tryGetActiveUnsafe()) { + KJ_EXPECT(active.resourceName == "resource"); + } else { + KJ_FAIL_EXPECT("Should be active"); + } + + // whenActive executes and returns value + auto result = machine.whenActive([](Active& a) { return a.resourceName.size(); }); + KJ_EXPECT(result != kj::none); + KJ_EXPECT(KJ_ASSERT_NONNULL(result) == 8); // "resource" + + machine.transitionTo(); + KJ_EXPECT(!machine.isActive()); + KJ_EXPECT(machine.isInactive()); + + // whenActive returns none when not active + auto result2 = machine.whenActive([](Active& a) { return a.resourceName.size(); }); + KJ_EXPECT(result2 == kj::none); +} + 
+KJ_TEST("StateMachine: whenActiveOr") { + auto machine = StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + + // whenActiveOr executes when active + auto result = machine.whenActiveOr([](Active& a) { return a.resourceName.size(); }, 0ul); + KJ_EXPECT(result == 8); + + // After close, returns default + machine.transitionTo(); + auto result2 = machine.whenActiveOr([](Active& a) { return a.resourceName.size(); }, 999ul); + KJ_EXPECT(result2 == 999); +} + +KJ_TEST("StateMachine: requireActiveUnsafe") { + auto machine = StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + + // requireActiveUnsafe returns reference when active + auto& active = machine.requireActiveUnsafe(); + KJ_EXPECT(active.resourceName == "resource"); + + // requireActiveUnsafe with custom message works when active + auto& active2 = machine.requireActiveUnsafe("Custom message"); + KJ_EXPECT(active2.resourceName == "resource"); + + machine.transitionTo(); + + // requireActiveUnsafe throws when not active + KJ_EXPECT_THROW_MESSAGE( + "State machine is not in the active state", (void)machine.requireActiveUnsafe()); + + // requireActiveUnsafe throws custom message when not active + KJ_EXPECT_THROW_MESSAGE( + "Stream is closed", (void)machine.requireActiveUnsafe("Stream is closed")); +} + +KJ_TEST("StateMachine: with PendingStates spec") { + auto machine = + StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + + // Start an operation + machine.beginOperation(); + KJ_EXPECT(machine.hasOperationInProgress()); + + // Defer a close + bool immediate = machine.deferTransitionTo(); + KJ_EXPECT(!immediate); // Deferred + KJ_EXPECT(machine.is()); // Still active + KJ_EXPECT(machine.hasPendingState()); + KJ_EXPECT(machine.pendingStateIs()); + KJ_EXPECT(machine.isOrPending()); + + // End operation - pending state applied + bool applied = machine.endOperation(); + KJ_EXPECT(applied); + KJ_EXPECT(machine.is()); +
KJ_EXPECT(!machine.hasPendingState()); +} + +KJ_TEST("StateMachine: with PendingStates scoped operation") { + auto machine = + StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + + { + auto scope = machine.scopedOperation(); + KJ_EXPECT(machine.hasOperationInProgress()); + + auto _ KJ_UNUSED = machine.deferTransitionTo(); + KJ_EXPECT(machine.is()); // Still active in scope + } + + // Scope ended, pending state applied + KJ_EXPECT(machine.is()); +} + +KJ_TEST("StateMachine: full-featured stream-like usage") { + // This demonstrates the common stream pattern with all features + auto machine = StateMachine, ErrorState, + ActiveState, PendingStates, Active, Closed, + Errored>::create(kj::str("http-body")); + KJ_EXPECT(machine.isActive()); + KJ_EXPECT(!machine.isTerminal()); + KJ_EXPECT(!machine.isErrored()); + + // Safe access with whenActive + machine.whenActive([](Active& a) { a.resourceName = kj::str("modified"); }); + KJ_EXPECT(machine.getUnsafe().resourceName == "modified"); + + // Start a read operation + machine.beginOperation(); + + // Close is requested mid-operation - deferred + auto deferred KJ_UNUSED = machine.deferTransitionTo(); + KJ_EXPECT(machine.isActive()); // Still active! 
+ KJ_EXPECT(machine.isOrPending()); + KJ_EXPECT(!machine.isTerminal()); // Not terminal yet + + // End operation - close applied + auto applied KJ_UNUSED = machine.endOperation(); + KJ_EXPECT(machine.is()); + KJ_EXPECT(machine.isTerminal()); + KJ_EXPECT(!machine.isActive()); + KJ_EXPECT(machine.isInactive()); + + // Cannot transition from terminal + auto tryTransition = [&]() { machine.transitionTo(kj::str("x")); }; + KJ_EXPECT_THROW_MESSAGE("Cannot transition from terminal state", tryTransition()); +} + +KJ_TEST("StateMachine: KJ_SWITCH_ONEOF works") { + auto machine = StateMachine::create(kj::str("test")); + + kj::String result; + KJ_SWITCH_ONEOF(machine) { + KJ_CASE_ONEOF(active, Active) { + result = kj::str("active: ", active.resourceName); + } + KJ_CASE_ONEOF(closed, Closed) { + result = kj::str("closed"); + } + KJ_CASE_ONEOF(errored, Errored) { + result = kj::str("errored: ", errored.reason); + } + } + + KJ_EXPECT(result == "active: test"); +} + +KJ_TEST("StateMachine: whenState locks transitions") { + auto machine = StateMachine::create(kj::str("resource")); + + // Cannot transition while locked + auto tryTransitionInCallback = [&]() { + machine.whenState([&](Active&) { machine.transitionTo(); }); + }; + KJ_EXPECT_THROW_MESSAGE("transitions are locked", tryTransitionInCallback()); + + // State unchanged + KJ_EXPECT(machine.is()); +} + +KJ_TEST("StateMachine: currentStateName") { + auto machine = StateMachine::create(kj::str("x")); + KJ_EXPECT(machine.currentStateName() == "active"_kj); + + machine.transitionTo(); + KJ_EXPECT(machine.currentStateName() == "closed"_kj); + + machine.transitionTo(kj::str("err")); + KJ_EXPECT(machine.currentStateName() == "errored"_kj); +} + +KJ_TEST("StateMachine: const whenState works") { + auto machine = StateMachine::create(kj::str("resource")); + + const auto& constMachine = machine; + + // Const whenState works and returns value + auto result = + constMachine.whenState([](const Active& a) { return a.resourceName.size(); }); 
+ KJ_EXPECT(result != kj::none); + KJ_EXPECT(KJ_ASSERT_NONNULL(result) == 8); // "resource" + + // Const whenState returns none for wrong state + auto result2 = constMachine.whenState([](const Closed&) { return 42; }); + KJ_EXPECT(result2 == kj::none); +} + +KJ_TEST("StateMachine: deferTransitionTo respects terminal states") { + auto machine = StateMachine, PendingStates, + Active, Closed, Errored>::create(kj::str("resource")); + + // Close the machine (terminal state) + machine.transitionTo(); + KJ_EXPECT(machine.isTerminal()); + + // deferTransitionTo should also fail from terminal state + auto tryDeferTransition = [&]() { + auto _ KJ_UNUSED = machine.deferTransitionTo(kj::str("error")); + }; + KJ_EXPECT_THROW_MESSAGE("Cannot transition from terminal state", tryDeferTransition()); +} + +// ============================================================================= +// Streams Integration Example +// ============================================================================= +// This demonstrates how StateMachine could replace the separate +// state + readState pattern found in ReadableStreamInternalController. 
+ +namespace stream_integration_example { + +// Simulated stream source (like ReadableStreamSource) +struct MockSource { + bool dataAvailable = true; + + kj::Maybe read() { + if (dataAvailable) { + dataAvailable = false; + return kj::str("data chunk"); + } + return kj::none; + } +}; + +// State types matching the streams pattern +struct Readable { + static constexpr kj::StringPtr NAME KJ_UNUSED = "readable"_kj; + kj::Own source; + + explicit Readable(kj::Own s): source(kj::mv(s)) {} +}; + +struct StreamClosed { + static constexpr kj::StringPtr NAME KJ_UNUSED = "closed"_kj; +}; + +struct StreamErrored { + static constexpr kj::StringPtr NAME KJ_UNUSED = "errored"_kj; + kj::String reason; + + explicit StreamErrored(kj::String r): reason(kj::mv(r)) {} +}; + +// Lock states (separate state machine in the real code) +struct Unlocked { + static constexpr kj::StringPtr NAME KJ_UNUSED = "unlocked"_kj; +}; + +struct Locked { + static constexpr kj::StringPtr NAME KJ_UNUSED = "locked"_kj; +}; + +struct ReaderLocked { + static constexpr kj::StringPtr NAME KJ_UNUSED = "reader_locked"_kj; + uint32_t readerId; + explicit ReaderLocked(uint32_t id): readerId(id) {} +}; + +// The full-featured state machine type for stream data state +using StreamDataState = StateMachine, + ErrorState, + ActiveState, + PendingStates, + Readable, + StreamClosed, + StreamErrored>; + +// Lock state machine (simpler) +using StreamLockState = StateMachine; + +// Simulated controller showing combined usage +class MockReadableStreamController { + public: + MockReadableStreamController() + : dataState(StreamDataState::create(kj::heap())), + lockState(StreamLockState::create()) {} + + explicit MockReadableStreamController(kj::Own source) + : dataState(StreamDataState::create(kj::mv(source))), + lockState(StreamLockState::create()) {} + + bool isReadable() const { + return dataState.isActive(); + } + + bool isClosedOrErrored() const { + return dataState.isTerminal(); + } + + bool isErrored() const { + return 
dataState.isErrored(); + } + + bool isLocked() const { + return !lockState.is(); + } + + kj::Maybe read() { + // Only read if in readable state and not already reading + if (!dataState.isActive()) { + return kj::none; + } + + // Start read operation (defers close/error during read) + auto op = dataState.scopedOperation(); + + // Safe access to source + KJ_IF_SOME(result, dataState.whenActive([](Readable& r) -> kj::Maybe { + return r.source->read(); + })) { + return kj::mv(result); + } + return kj::none; + } + + void close() { + if (dataState.isTerminal()) return; + + // If operation in progress, defer the close + auto _ KJ_UNUSED = dataState.deferTransitionTo(); + } + + void error(kj::String reason) { + if (dataState.isTerminal()) return; + + // Error takes precedence - force even if operation in progress + dataState.forceTransitionTo(kj::mv(reason)); + } + + bool acquireReaderLock(uint32_t readerId) { + if (isLocked()) return false; + lockState.transitionTo(readerId); + return true; + } + + void releaseReaderLock() { + lockState.transitionTo(); + } + + private: + StreamDataState dataState; + StreamLockState lockState; +}; + +} // namespace stream_integration_example + +KJ_TEST("StateMachine: stream integration example - basic flow") { + using namespace stream_integration_example; + + MockReadableStreamController controller(kj::heap()); + + KJ_EXPECT(controller.isReadable()); + KJ_EXPECT(!controller.isClosedOrErrored()); + KJ_EXPECT(!controller.isLocked()); + + // Acquire reader lock + KJ_EXPECT(controller.acquireReaderLock(123)); + KJ_EXPECT(controller.isLocked()); + + // Read data + auto chunk1 = controller.read(); + KJ_EXPECT(chunk1 != kj::none); + KJ_EXPECT(KJ_ASSERT_NONNULL(chunk1) == "data chunk"); + + // Second read returns none (source exhausted) + auto chunk2 = controller.read(); + KJ_EXPECT(chunk2 == kj::none); + + // Close the stream + controller.close(); + KJ_EXPECT(!controller.isReadable()); + KJ_EXPECT(controller.isClosedOrErrored()); + + // Release 
lock + controller.releaseReaderLock(); + KJ_EXPECT(!controller.isLocked()); +} + +KJ_TEST("StateMachine: stream integration example - close during read") { + using namespace stream_integration_example; + + MockReadableStreamController controller(kj::heap()); + + // close() requests the transition through deferTransitionTo(), so a close issued while + // a read operation is in progress is deferred until that operation ends. + // + // The mock read() is synchronous, so a close cannot overlap a read here; this test only + // covers the no-operation-in-progress path, where the close takes effect immediately. + + // Simulate close being called while readable (no operation in progress) + controller.close(); + KJ_EXPECT(controller.isClosedOrErrored()); +} + +KJ_TEST("StateMachine: stream integration example - error handling") { + using namespace stream_integration_example; + + MockReadableStreamController controller(kj::heap()); + + // Error the stream + controller.error(kj::str("Network failure")); + + KJ_EXPECT(!controller.isReadable()); + KJ_EXPECT(controller.isClosedOrErrored()); + KJ_EXPECT(controller.isErrored()); + + // Reads after error return none + auto chunk = controller.read(); + KJ_EXPECT(chunk == kj::none); +} + +// ============================================================================= +// StateMachine Additional API Tests +// ============================================================================= + +KJ_TEST("StateMachine: visit method") { + auto machine = StateMachine::create(kj::str("resource")); + + // Visit with return value - note: visitor must return the same type for all states + size_t result = machine.visit([](auto& s) -> size_t { + using S = std::decay_t; + if constexpr (std::is_same_v) { + return s.resourceName.size(); + } else if constexpr (std::is_same_v) { + return 0; + } else { + return s.reason.size(); + } + }); + KJ_EXPECT(result == 8); // "resource" + + machine.transitionTo(); + result = machine.visit([](auto& s) -> size_t { + using S = std::decay_t; + if constexpr (std::is_same_v) { + return s.resourceName.size(); + } else
if constexpr (std::is_same_v) { + return 0; + } else { + return s.reason.size(); + } + }); + KJ_EXPECT(result == 0); +} + +KJ_TEST("StateMachine: visit const method") { + auto machine = StateMachine::create(kj::str("test")); + + const auto& constMachine = machine; + size_t result = constMachine.visit([](const auto& s) -> size_t { + using S = std::decay_t; + if constexpr (std::is_same_v) { + return 1; + } else if constexpr (std::is_same_v) { + return 2; + } else { + return 3; + } + }); + KJ_EXPECT(result == 1); +} + +KJ_TEST("StateMachine: underlying accessor") { + auto machine = StateMachine::create(kj::str("resource")); + + // Access underlying kj::OneOf + auto& underlying = machine.underlying(); + KJ_EXPECT(underlying.is()); + KJ_EXPECT(underlying.get().resourceName == "resource"_kj); + + // Const access + const auto& constMachine = machine; + const auto& constUnderlying = constMachine.underlying(); + KJ_EXPECT(constUnderlying.is()); +} + +KJ_TEST("StateMachine: applyPendingStateImpl respects terminal") { + // When we force-transition to a terminal state during an operation, + // the pending state should be discarded on endOperation. 
+ auto machine = StateMachine, PendingStates, + Active, Closed, Errored>::create(kj::str("resource")); + + // Start an operation + machine.beginOperation(); + + // Request a deferred close + auto _ KJ_UNUSED = machine.deferTransitionTo(); + KJ_EXPECT(machine.hasPendingState()); + KJ_EXPECT(machine.is()); + + // Force transition to error (terminal state) while operation is in progress + machine.forceTransitionTo(kj::str("forced error")); + KJ_EXPECT(machine.is()); + + // End operation - pending Close should be discarded since we're in terminal state + bool pendingApplied = machine.endOperation(); + KJ_EXPECT(!pendingApplied); // Pending was discarded, not applied + KJ_EXPECT(machine.is()); // Still in errored state + KJ_EXPECT(!machine.hasPendingState()); // Pending was cleared +} + +KJ_TEST("StateMachine: endOperation inside whenState throws") { + // This test verifies that ending an operation (which could apply a pending state) + // inside a whenState() callback throws an error. This prevents UAF where a + // transition invalidates the reference being used in the callback. 
+ auto machine = + StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + + // This pattern would cause UAF without the safety check: + // whenState gets reference to Active + // scopedOperation ends, applies pending state -> Active is destroyed + // callback continues using destroyed Active reference + auto tryUnsafePattern = [&]() { + machine.whenState([&](Active&) { + { + auto op = machine.scopedOperation(); + auto _ KJ_UNUSED = machine.deferTransitionTo(); + } // op destroyed here - endOperation() would apply pending state + }); + }; + + KJ_EXPECT_THROW_MESSAGE("transitions are locked", tryUnsafePattern()); + + // Verify the machine is still in a valid state (transition was blocked) + KJ_EXPECT(machine.is()); +} + +KJ_TEST("StateMachine: endOperation outside whenState works") { + // Verify the correct pattern still works: end operations outside whenState + auto machine = + StateMachine, Active, Closed, Errored>::create( + kj::str("resource")); + + { + auto op = machine.scopedOperation(); + machine.whenState([&](Active& a) { + // Safe to use 'a' here - no operation ending in this scope + KJ_EXPECT(a.resourceName == "resource"); + }); + auto _ KJ_UNUSED = machine.deferTransitionTo(); + } // op ends here, OUTSIDE any whenState callback - safe! + + KJ_EXPECT(machine.is()); +} + +} // namespace +} // namespace workerd diff --git a/src/workerd/util/state-machine.h b/src/workerd/util/state-machine.h new file mode 100644 index 00000000000..7baea57d4d9 --- /dev/null +++ b/src/workerd/util/state-machine.h @@ -0,0 +1,2148 @@ +// Copyright (c) 2017-2025 Cloudflare, Inc. +// Licensed under the Apache 2.0 license found in the LICENSE file or at: +// https://opensource.org/licenses/Apache-2.0 + +#pragma once + +// State Machine Abstraction built on kj::OneOf. +// TODO(later): If this proves useful, consider moving it into kj itself as there +// are no workerd-specific dependencies. +// +// Entire implementation was Claude-generated initially. 
+// +// Most of the detailed doc comments here are intended to be used by agents +// and tooling. Human readers may prefer to just skip to the actual code. +// +// This header provides utilities for building type-safe state machines using kj::OneOf. +// It addresses common patterns found throughout the workerd codebase with improvements +// that provide tangible benefits over raw kj::OneOf usage. +// +// ============================================================================= +// WHY USE THIS INSTEAD OF RAW kj::OneOf? +// ============================================================================= +// +// Throughout workerd, we use kj::OneOf as a state machine to track the lifecycle +// of streams, readers, writers, and other resources. A typical pattern looks like: +// +// kj::OneOf<Readable, Closed, kj::Exception> state; +// +// void read() { +// KJ_SWITCH_ONEOF(state) { +// KJ_CASE_ONEOF(readable, Readable) { +// auto data = readable.source->read(); // Get reference to state +// processData(data); // Call some function... +// readable.source->advance(); // Use reference again - UAF! +// } +// KJ_CASE_ONEOF(closed, Closed) { ... } +// KJ_CASE_ONEOF(err, kj::Exception) { ... } +// } +// } +// +// THE PROBLEM: Use-After-Free (UAF) from unsound state transitions +// +// The `readable` reference points into the kj::OneOf's internal storage. If ANY +// code path between obtaining that reference and using it triggers a state +// transition (even indirectly through callbacks, promise continuations, or +// nested calls), the reference becomes dangling: +// +// KJ_CASE_ONEOF(readable, Readable) { +// readable.source->read(); // This might call back into our code... +// // ...which might call close()... +// // ...which does state.init<Closed>() +// readable.buffer.size(); // UAF! readable is now destroyed +// } +// +// This is particularly insidious because: +// 1. The bug may not manifest in simple tests +// 2. It depends on complex callback chains that are hard to reason about +// 3.
It causes memory corruption that may crash much later +// 4. ASAN/valgrind may not catch it if the memory is quickly reused +// +// HOW StateMachine HELPS: +// +// 1. TRANSITION LOCKING via whenState()/whenActive(): +// +// state.whenState<Readable>([](Readable& r) { +// r.source->read(); // If this tries to transition... +// r.buffer.size(); // ...it throws instead of UAF +// }); +// +// The callback holds a "transition lock" - any attempt to transition the +// state machine while the lock is held will throw an exception instead of +// silently corrupting memory. This turns silent UAF into a loud, debuggable +// failure. +// +// 2. DEFERRED TRANSITIONS for async operations: +// +// When code legitimately needs to transition during an operation (e.g., +// a read discovers EOF and needs to close), use deferred transitions: +// +// { +// auto op = state.scopedOperation(); +// state.whenActive([&](Readable& r) { +// if (r.source->atEof()) { +// state.deferTransitionTo<Closed>(); // Queued, not immediate +// } +// }); +// } // Transition happens here, after callback completes safely +// +// 3. TERMINAL STATE ENFORCEMENT: +// +// Once a stream is Closed or Errored, it should never transition back to +// Readable. Raw kj::OneOf allows this silently: +// +// state.init<Closed>(); +// state.init<Readable>(...); // Oops - zombie stream! +// +// StateMachine with TerminalStates<> will throw if you attempt this, +// catching the bug immediately. +// +// 4. SEMANTIC HELPERS: +// +// Instead of: state.is<Closed>() || state.is<Errored>() +// Write: state.isTerminal() or state.isInactive() +// +// Instead of: KJ_IF_SOME(e, state.tryGetUnsafe<Errored>()) { ... } +// Write: KJ_IF_SOME(e, state.tryGetErrorUnsafe()) { ...
} +// +// WHEN TO USE: +// +// - Simple state tracking: StateMachine is fine +// - Resource lifecycle (streams, handles): Use TerminalStates + PendingStates +// - Migrating existing code: See MIGRATION GUIDE section below +// +// ============================================================================= +// STATE MACHINE +// ============================================================================= +// +// StateMachine supports composable features via spec types: +// +// // Simple (no specs) +// StateMachine<Idle, Running, Done> basic; +// +// // With terminal state enforcement +// StateMachine<TerminalStates<Done>, Idle, Running, Done> withTerminal; +// +// // With error extraction helpers +// StateMachine<ErrorState<Errored>, Active, Closed, Errored> withError; +// +// // With deferred transitions +// StateMachine<PendingStates<Closed, Errored>, Active, Closed, Errored> withDefer; +// +// // Full-featured (combine any specs) +// StateMachine< +// TerminalStates<Closed, Errored>, +// ErrorState<Errored>, +// ActiveState<Active>, +// PendingStates<Closed, Errored>, +// Active, Closed, Errored +// > fullyFeatured; +// +// Available spec types: +// - TerminalStates<Ts...> - States that cannot be transitioned FROM +// Enables: isTerminal() +// - ErrorState<T> - Designates the error state type +// Enables: isErrored(), tryGetErrorUnsafe(), getErrorUnsafe() +// - ActiveState<T> - Designates the active/working state type +// Enables: isActive(), isInactive(), whenActive(), whenActiveOr(), +// tryGetActiveUnsafe(), requireActiveUnsafe() +// - PendingStates<Ts...> - States that can be deferred during operations +// Enables: beginOperation(), endOperation(), deferTransitionTo(), etc. +// +// NAMING CONVENTIONS: +// - isTerminal() = current state is in TerminalStates (enforces no outgoing transitions) +// - isInactive() = current state is NOT the ActiveState (semantic "done" state) +// +// ============================================================================= +// MEMORY SAFETY +// ============================================================================= +// +// THREAD SAFETY: State machines are NOT thread-safe. 
All operations on a +// single state machine instance must be performed from the same thread. +// If you need concurrent access, use external synchronization. +// +// This utility provides protections against common memory safety issues: +// +// 1. TRANSITION LOCKING: The state machine can be locked during callbacks to +// prevent transitions that would invalidate references: +// +// machine.whenState<Active>([](Active& a) { +// // machine.transitionTo<Closed>(); // Would fail - locked! +// a.resource->read(); // Safe - Active cannot be destroyed +// }); +// +// 2. TRANSITION LOCK ENFORCEMENT: The machine tracks active transition locks +// and throws if a transition is attempted while locks are held. +// +// 3. SAFE ACCESS PATTERNS: Prefer whenState() and whenActive() over get() +// to ensure references don't outlive their validity. +// +// UNSAFE PATTERNS TO AVOID: +// +// // DON'T: Store references from getUnsafe() across transitions +// Active& active = machine.getUnsafe<Active>(); +// machine.transitionTo<Closed>(); // active is now dangling! +// +// // DO: Use whenState() for safe scoped access +// machine.whenState<Active>([](Active& a) { +// // a is guaranteed valid for the duration of the callback +// }); +// +// // DON'T: Transition inside a callback (will fail if locked) +// machine.whenState<Active>([&](Active& a) { +// machine.transitionTo<Closed>(); // Fails! 
+// }); +// +// // DO: Return a value and transition after +// auto result = machine.whenState<Active>([](Active& a) { +// return a.computeSomething(); +// }); +// machine.transitionTo<Closed>(); +// +// ============================================================================= +// QUICK START +// ============================================================================= +// +// Define your state types (add NAME for introspection): +// +// struct Readable { +// static constexpr kj::StringPtr NAME = "readable"_kj; +// kj::Own source; +// }; +// struct Closed { static constexpr kj::StringPtr NAME = "closed"_kj; }; +// struct Errored { +// static constexpr kj::StringPtr NAME = "errored"_kj; +// jsg::Value error; +// }; +// +// Basic state machine with safe access: +// +// StateMachine<Readable, Closed, Errored> state; +// state.transitionTo<Readable>(...); +// +// // RECOMMENDED: Use whenState() for safe scoped access +// state.whenState<Readable>([](Readable& r) { +// r.source->read(); // Safe - transitions blocked during callback +// }); +// +// // Or with a return value +// auto size = state.whenState<Readable>([](Readable& r) { +// return r.source->size(); +// }); // Returns kj::Maybe +// +// Stream-like state machine (common pattern in workerd): +// +// StateMachine< +// TerminalStates<Closed, Errored>, +// ErrorState<Errored>, +// ActiveState<Readable>, +// PendingStates<Closed, Errored>, +// Readable, Closed, Errored +// > state; +// +// state.transitionTo<Readable>(...); +// +// // Safe access with whenActive() +// state.whenActive([](Readable& r) { +// r.source->doSomething(); // Transitions blocked +// }); +// +// // Error checking +// if (state.isErrored()) { ... } +// KJ_IF_SOME(err, state.tryGetErrorUnsafe()) { ... 
} +// +// // Deferred transitions during operations +// state.beginOperation(); +// state.deferTransitionTo<Closed>(); // Deferred until operation ends +// state.endOperation(); // Now transitions to Closed +// +// // Terminal enforcement +// state.transitionTo<Closed>(); +// state.transitionTo<Readable>(...); // FAILS - can't leave terminal state +// +// ============================================================================= +// MIGRATION GUIDE: From kj::OneOf to StateMachine +// ============================================================================= +// +// This section describes how to migrate existing kj::OneOf state management +// to use these StateMachine utilities. +// +// STEP 1: Add NAME constants to state types +// ----------------------------------------- +// StateMachine provides currentStateName() for debugging. Add NAME to states: +// +// // Before: +// struct Closed {}; +// +// // After: +// struct Closed { +// static constexpr kj::StringPtr NAME = "Closed"_kj; +// }; +// +// STEP 2: Replace kj::OneOf with appropriate StateMachine +// -------------------------------------------------------- +// +// // Before: +// kj::OneOf<Closed, Errored, Readable> state; +// +// // After (basic): +// StateMachine<Closed, Errored, Readable> state; +// +// // After (with features): +// StateMachine< +// TerminalStates<Closed, Errored>, +// ErrorState<Errored>, +// ActiveState<Readable>, +// Closed, Errored, Readable +// > state; +// +// STEP 3: Update state assignments to use transitionTo() +// ------------------------------------------------------ +// +// // Before: +// state = Closed{}; +// state = Errored{kj::mv(error)}; +// +// // After: +// state.transitionTo<Closed>(); +// state.transitionTo<Errored>(kj::mv(error)); +// +// STEP 4: Update state checks +// --------------------------- +// +// // Before: +// if (state.is<Closed>() || state.is<Errored>()) { ... } +// if (state.is<Errored>()) { ... } +// +// // After (with ActiveState): +// if (state.isInactive()) { ... } // Not in active state +// +// // After (with ErrorState): +// if (state.isErrored()) { ... 
} +// +// STEP 5: Replace unsafe get() with safe access patterns +// ------------------------------------------------------ +// +// // Before (unsafe - reference may dangle if callback transitions): +// KJ_SWITCH_ONEOF(state) { +// KJ_CASE_ONEOF(readable, Readable) { +// readable.source->read(); // May be unsafe +// } +// } +// +// // After (safe - transitions blocked during callback): +// state.whenActive([](Readable& r) { +// r.source->read(); // Safe +// }); +// +// // Or for specific state: +// state.whenState<Readable>([](Readable& r) { +// r.source->read(); +// }); +// +// STEP 6: Replace manual deferred-transition bookkeeping +// ------------------------------------------------------ +// If you have code that tracks pending operations and defers close/error: +// +// // Before: +// bool closing = false; +// int pendingOps = 0; +// +// void startOp() { pendingOps++; } +// void endOp() { +// if (--pendingOps == 0 && closing) doClose(); +// } +// void close() { +// if (pendingOps > 0) { closing = true; return; } +// doClose(); +// } +// +// // After (with PendingStates): +// void startOp() { state.beginOperation(); } +// void endOp() { state.endOperation(); } // Auto-applies pending +// void close() { state.deferTransitionTo<Closed>(); } +// +// // Or with RAII: +// void doWork() { +// auto op = state.scopedOperation(); +// // ... work ... +// } // endOperation() called automatically +// +// STEP 7: Update visitForGc +// ------------------------- +// +// // Before: +// void visitForGc(jsg::GcVisitor& visitor) { +// KJ_SWITCH_ONEOF(state) { +// KJ_CASE_ONEOF(e, Errored) { visitor.visit(e.reason); } +// // ... 
+// } +// } +// +// // After: +// void visitForGc(jsg::GcVisitor& visitor) { +// state.visitForGc(visitor); // Visits all GC-able states automatically +// } +// +// STEP 8: KJ_SWITCH_ONEOF still works +// ----------------------------------- +// If you need to keep KJ_SWITCH_ONEOF for complex logic: +// +// KJ_SWITCH_ONEOF(state.underlying()) { +// KJ_CASE_ONEOF(r, Readable) { ... } +// KJ_CASE_ONEOF(c, Closed) { ... } +// KJ_CASE_ONEOF(e, Errored) { ... } +// } +// +// Or use the visitor pattern: +// +// state.visit([](auto& s) { +// using S = kj::Decay; +// if constexpr (kj::isSameType()) { ... } +// else if constexpr (kj::isSameType()) { ... } +// else { ... } +// }); +// +// ============================================================================= + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace workerd { + +// ============================================================================= +// Type Traits and Helpers +// ============================================================================= + +namespace _ { // private + +// Helper to check if a type is in a parameter pack +template +inline constexpr bool isOneOf = false; + +template +inline constexpr bool isOneOf = + kj::isSameType() || isOneOf; + +// Concept: type has a static NAME member of type kj::StringPtr +template +concept HasStateName = requires { + { T::NAME } -> std::convertible_to; +}; + +// Get state name, using NAME if available, otherwise a placeholder +template +constexpr kj::StringPtr getStateName() { + if constexpr (HasStateName) { + return T::NAME; + } else { + return "(unnamed)"_kj; + } +} + +} // namespace _ + +// ============================================================================= +// Spec Types for Composable Features +// ============================================================================= + +// Marker type to specify terminal states (cannot transition FROM these) +template +struct TerminalStates { + template + 
static constexpr bool contains = _::isOneOf; + + template + static bool isTerminal(const Machine& machine) { + return (machine.template is() || ...); + } +}; + +// Marker type to specify the error state (enables isErrored(), tryGetErrorUnsafe(), etc.) +// Note: Error states are implicitly terminal - you cannot transition out of an error state +// using normal transitions. Use forceTransitionTo() if you need to reset from an error. +template +struct ErrorState { + using Type = T; +}; + +// Marker type to specify the active state (enables isActive(), whenActive(), etc.) +template +struct ActiveState { + using Type = T; +}; + +// Marker type to specify which states can be pending/deferred +template +struct PendingStates { + template + static constexpr bool contains = _::isOneOf; +}; + +// ============================================================================= +// Spec Detection Traits +// ============================================================================= + +namespace _ { // private + +// Helper to detect template instantiations +template class Template> +inline constexpr bool isInstanceOf = false; + +template