diff --git a/src/node_api.cc b/src/node_api.cc
index 11b9e44233d94c..cd17f7c199dae5 100644
--- a/src/node_api.cc
+++ b/src/node_api.cc
@@ -200,7 +200,7 @@ class BufferFinalizer : private Finalizer {
   ~BufferFinalizer() { env()->Unref(); }
 };

-class ThreadSafeFunction : public node::AsyncResource {
+class ThreadSafeFunction {
  public:
   ThreadSafeFunction(v8::Local<v8::Function> func,
                      v8::Local<v8::Object> resource,
@@ -212,11 +212,12 @@ class ThreadSafeFunction : public node::AsyncResource {
                      void* finalize_data_,
                      napi_finalize finalize_cb_,
                      napi_threadsafe_function_call_js call_js_cb_)
-      : AsyncResource(env_->isolate,
-                      resource,
-                      node::Utf8Value(env_->isolate, name).ToStringView()),
+      : async_resource(std::in_place,
+                       env_->isolate,
+                       resource,
+                       node::Utf8Value(env_->isolate, name).ToStringView()),
         thread_count(thread_count_),
-        is_closing(false),
+        state(kOpen),
         dispatch_state(kDispatchIdle),
         context(context_),
         max_queue_size(max_queue_size_),
@@ -230,76 +231,104 @@ class ThreadSafeFunction : public node::AsyncResource {
     env->Ref();
   }

-  ~ThreadSafeFunction() override {
-    node::RemoveEnvironmentCleanupHook(env->isolate, Cleanup, this);
-    env->Unref();
-  }
+  ~ThreadSafeFunction() { ReleaseResources(); }

   // These methods can be called from any thread.
   napi_status Push(void* data, napi_threadsafe_function_call_mode mode) {
-    node::Mutex::ScopedLock lock(this->mutex);
+    {
+      node::Mutex::ScopedLock lock(this->mutex);

-    while (queue.size() >= max_queue_size && max_queue_size > 0 &&
-           !is_closing) {
-      if (mode == napi_tsfn_nonblocking) {
-        return napi_queue_full;
+      while (queue.size() >= max_queue_size && max_queue_size > 0 &&
+             state == kOpen) {
+        if (mode == napi_tsfn_nonblocking) {
+          return napi_queue_full;
+        }
+        cond->Wait(lock);
       }
-      cond->Wait(lock);
-    }

-    if (is_closing) {
+      if (state == kOpen) {
+        queue.push(data);
+        Send();
+        return napi_ok;
+      }
       if (thread_count == 0) {
         return napi_invalid_arg;
-      } else {
-        thread_count--;
+      }
+      thread_count--;
+      if (!(state == kClosed && thread_count == 0)) {
         return napi_closing;
       }
-    } else {
-      queue.push(data);
-      Send();
-      return napi_ok;
     }
+    // Make sure to release lock before destroying
+    delete this;
+    return napi_closing;
   }

   napi_status Acquire() {
     node::Mutex::ScopedLock lock(this->mutex);

-    if (is_closing) {
-      return napi_closing;
-    }
+    if (state == kOpen) {
+      thread_count++;

-    thread_count++;
+      return napi_ok;
+    }

-    return napi_ok;
+    return napi_closing;
   }

   napi_status Release(napi_threadsafe_function_release_mode mode) {
-    node::Mutex::ScopedLock lock(this->mutex);
+    {
+      node::Mutex::ScopedLock lock(this->mutex);

-    if (thread_count == 0) {
-      return napi_invalid_arg;
-    }
+      if (thread_count == 0) {
+        return napi_invalid_arg;
+      }

-    thread_count--;
+      thread_count--;

-    if (thread_count == 0 || mode == napi_tsfn_abort) {
-      if (!is_closing) {
-        is_closing = (mode == napi_tsfn_abort);
-        if (is_closing && max_queue_size > 0) {
-          cond->Signal(lock);
+      if (thread_count == 0 || mode == napi_tsfn_abort) {
+        if (state == kOpen) {
+          if (mode == napi_tsfn_abort) {
+            state = kClosing;
+          }
+          if (state == kClosing && max_queue_size > 0) {
+            cond->Signal(lock);
+          }
+          Send();
         }
-        Send();
       }
-    }

+      if (!(state == kClosed && thread_count == 0)) {
+        return napi_ok;
+      }
+    }
+    // Make sure to release lock before destroying
+    delete this;
     return napi_ok;
   }

-  void EmptyQueueAndDelete() {
-    for (; !queue.empty(); queue.pop()) {
-      call_js_cb(nullptr, nullptr, context, queue.front());
+  void EmptyQueueAndMaybeDelete() {
+    std::queue<void*> drain_queue;
+    {
+      node::Mutex::ScopedLock lock(this->mutex);
+      queue.swap(drain_queue);
     }
+    for (; !drain_queue.empty(); drain_queue.pop()) {
+      call_js_cb(nullptr, nullptr, context, drain_queue.front());
+    }
+    {
+      node::Mutex::ScopedLock lock(this->mutex);
+      if (thread_count > 0) {
+        // At this point this TSFN is effectively done, but we need to keep
+        // it alive for other threads that still have pointers to it until
+        // they release them.
+        // But we already release all the resources that we can at this point
+        ReleaseResources();
+        return;
+      }
+    }
+    // Make sure to release lock before destroying
     delete this;
   }

@@ -351,6 +380,16 @@ class ThreadSafeFunction : public node::AsyncResource {
   inline void* Context() { return context; }

  protected:
+  void ReleaseResources() {
+    if (state != kClosed) {
+      state = kClosed;
+      ref.Reset();
+      node::RemoveEnvironmentCleanupHook(env->isolate, Cleanup, this);
+      env->Unref();
+      async_resource.reset();
+    }
+  }
+
   void Dispatch() {
     bool has_more = true;

@@ -379,9 +418,7 @@ class ThreadSafeFunction : public node::AsyncResource {
     {
       node::Mutex::ScopedLock lock(this->mutex);

-      if (is_closing) {
-        CloseHandlesAndMaybeDelete();
-      } else {
+      if (state == kOpen) {
         size_t size = queue.size();
         if (size > 0) {
           data = queue.front();
@@ -395,7 +432,7 @@ class ThreadSafeFunction : public node::AsyncResource {

         if (size == 0) {
           if (thread_count == 0) {
-            is_closing = true;
+            state = kClosing;
             if (max_queue_size > 0) {
               cond->Signal(lock);
             }
@@ -404,12 +441,14 @@ class ThreadSafeFunction : public node::AsyncResource {
         } else {
           has_more = true;
         }
+      } else {
+        CloseHandlesAndMaybeDelete();
       }
     }

     if (popped_value) {
       v8::HandleScope scope(env->isolate);
-      CallbackScope cb_scope(this);
+      AsyncResource::CallbackScope cb_scope(&*async_resource);
       napi_value js_callback = nullptr;
       if (!ref.IsEmpty()) {
         v8::Local<v8::Function> js_cb =
@@ -426,17 +465,17 @@ class ThreadSafeFunction : public node::AsyncResource {
   void Finalize() {
     v8::HandleScope scope(env->isolate);
     if (finalize_cb) {
-      CallbackScope cb_scope(this);
+      AsyncResource::CallbackScope cb_scope(&*async_resource);
       env->CallFinalizer(finalize_cb, finalize_data, context);
     }
-    EmptyQueueAndDelete();
+    EmptyQueueAndMaybeDelete();
   }

   void CloseHandlesAndMaybeDelete(bool set_closing = false) {
     v8::HandleScope scope(env->isolate);
     if (set_closing) {
       node::Mutex::ScopedLock lock(this->mutex);
-      is_closing = true;
+      state = kClosing;
       if (max_queue_size > 0) {
         cond->Signal(lock);
       }
@@ -501,19 +540,30 @@ class ThreadSafeFunction : public node::AsyncResource {
   }

  private:
+  // Needed because node::AsyncResource::CallbackScope is protected
+  class AsyncResource : public node::AsyncResource {
+   public:
+    using node::AsyncResource::AsyncResource;
+    using node::AsyncResource::CallbackScope;
+  };
+
+  enum State : unsigned char { kOpen, kClosing, kClosed };
+
   static const unsigned char kDispatchIdle = 0;
   static const unsigned char kDispatchRunning = 1 << 0;
   static const unsigned char kDispatchPending = 1 << 1;

   static const unsigned int kMaxIterationCount = 1000;

+  std::optional<AsyncResource> async_resource;
+
   // These are variables protected by the mutex.
   node::Mutex mutex;
   std::unique_ptr<node::ConditionVariable> cond;
   std::queue<void*> queue;
   uv_async_t async;
   size_t thread_count;
-  bool is_closing;
+  State state;
   std::atomic_uchar dispatch_state;

   // These are variables set once, upon creation, and then never again, which
diff --git a/test/node-api/test_threadsafe_function_shutdown/binding.cc b/test/node-api/test_threadsafe_function_shutdown/binding.cc
new file mode 100644
index 00000000000000..11696b849a0713
--- /dev/null
+++ b/test/node-api/test_threadsafe_function_shutdown/binding.cc
@@ -0,0 +1,83 @@
+#include <node_api.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <chrono>
+#include <cstdio>
+#include <cstdlib>
+#include <thread>  // NOLINT(build/c++11)
+#include <type_traits>
+#include <utility>
+
+template <typename R, auto func, typename... Args>
+inline auto call(const char* name, Args&&... args) -> R {
+  napi_status status;
+  if constexpr (std::is_same_v<R, void>) {
+    status = func(std::forward<Args>(args)...);
+    if (status == napi_ok) {
+      return;
+    }
+  } else {
+    R ret;
+    status = func(std::forward<Args>(args)..., &ret);
+    if (status == napi_ok) {
+      return ret;
+    }
+  }
+  std::fprintf(stderr, "%s: %d\n", name, status);
+  std::abort();
+}
+
+#define NAPI_CALL(ret_type, func, ...) \
+  call<ret_type, func>(#func, ##__VA_ARGS__)
+
+void thread_func(napi_threadsafe_function tsfn) {
+  fprintf(stderr, "thread_func: starting\n");
+  auto status =
+      napi_call_threadsafe_function(tsfn, nullptr, napi_tsfn_blocking);
+  while (status == napi_ok) {
+    std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    status = napi_call_threadsafe_function(tsfn, nullptr, napi_tsfn_blocking);
+  }
+  fprintf(stderr, "thread_func: Got status %d, exiting...\n", status);
+}
+
+void tsfn_callback(napi_env env, napi_value js_cb, void* ctx, void* data) {
+  if (env == nullptr) {
+    fprintf(stderr, "tsfn_callback: env=%p\n", env);
+  }
+}
+
+void tsfn_finalize(napi_env env, void* finalize_data, void* finalize_hint) {
+  fprintf(stderr, "tsfn_finalize: env=%p\n", env);
+}
+
+auto run(napi_env env, napi_callback_info info) -> napi_value {
+  auto global = NAPI_CALL(napi_value, napi_get_global, env);
+  auto undefined = NAPI_CALL(napi_value, napi_get_undefined, env);
+  auto n_threads = 32;
+  auto tsfn = NAPI_CALL(napi_threadsafe_function,
+                        napi_create_threadsafe_function,
+                        env,
+                        nullptr,
+                        global,
+                        undefined,
+                        0,
+                        n_threads,
+                        nullptr,
+                        tsfn_finalize,
+                        nullptr,
+                        tsfn_callback);
+  for (auto i = 0; i < n_threads; ++i) {
+    std::thread([tsfn] { thread_func(tsfn); }).detach();
+  }
+  NAPI_CALL(void, napi_unref_threadsafe_function, env, tsfn);
+  return NAPI_CALL(napi_value, napi_get_undefined, env);
+}
+
+napi_value init(napi_env env, napi_value exports) {
+  return NAPI_CALL(
+      napi_value, napi_create_function, env, nullptr, 0, run, nullptr);
+}
+
+NAPI_MODULE(NODE_GYP_MODULE_NAME, init)
diff --git a/test/node-api/test_threadsafe_function_shutdown/binding.gyp b/test/node-api/test_threadsafe_function_shutdown/binding.gyp
new file mode 100644
index 00000000000000..eb08b447a94a86
--- /dev/null
+++ b/test/node-api/test_threadsafe_function_shutdown/binding.gyp
@@ -0,0 +1,11 @@
+{
+  "targets": [
+    {
+      "target_name": "binding",
+      "sources": ["binding.cc"],
+      "cflags_cc": ["--std=c++20"],
+      'cflags!': [ '-fno-exceptions', '-fno-rtti' ],
+      'cflags_cc!': [ '-fno-exceptions', '-fno-rtti' ],
+    }
+  ]
+}
diff --git a/test/node-api/test_threadsafe_function_shutdown/test.js b/test/node-api/test_threadsafe_function_shutdown/test.js
new file mode 100644
index 00000000000000..9b2587b5bbdfde
--- /dev/null
+++ b/test/node-api/test_threadsafe_function_shutdown/test.js
@@ -0,0 +1,17 @@
+'use strict';
+
+const common = require('../../common');
+const process = require('process');
+const assert = require('assert');
+const { fork } = require('child_process');
+const binding = require(`./build/${common.buildType}/binding`);
+
+if (process.argv[2] === 'child') {
+  binding();
+  setTimeout(() => {}, 100);
+} else {
+  const child = fork(__filename, ['child']);
+  child.on('close', common.mustCall((code) => {
+    assert.strictEqual(code, 0);
+  }));
+}