diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 4403def9949..4bf824ce6e1 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -181,6 +181,7 @@ set(ARROW_SRCS util/bitmap_builders.cc util/bitmap_ops.cc util/bpacking.cc + util/cancel.cc util/compression.cc util/cpu_info.cc util/decimal.cc diff --git a/cpp/src/arrow/status.cc b/cpp/src/arrow/status.cc index cfc5eb1e345..0f02cb57a23 100644 --- a/cpp/src/arrow/status.cc +++ b/cpp/src/arrow/status.cc @@ -68,6 +68,9 @@ std::string Status::CodeAsString(StatusCode code) { case StatusCode::Invalid: type = "Invalid"; break; + case StatusCode::Cancelled: + type = "Cancelled"; + break; case StatusCode::IOError: type = "IOError"; break; diff --git a/cpp/src/arrow/status.h b/cpp/src/arrow/status.h index 36f1b9023da..43879e6c6a3 100644 --- a/cpp/src/arrow/status.h +++ b/cpp/src/arrow/status.h @@ -83,6 +83,7 @@ enum class StatusCode : char { IOError = 5, CapacityError = 6, IndexError = 7, + Cancelled = 8, UnknownError = 9, NotImplemented = 10, SerializationError = 11, @@ -204,6 +205,12 @@ class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : public util::EqualityComparable< return Status::FromArgs(StatusCode::Invalid, std::forward(args)...); } + /// Return an error status for cancelled operation + template + static Status Cancelled(Args&&... args) { + return Status::FromArgs(StatusCode::Cancelled, std::forward(args)...); + } + /// Return an error status when an index is out of bounds template static Status IndexError(Args&&... args) { @@ -263,6 +270,8 @@ class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : public util::EqualityComparable< bool IsKeyError() const { return code() == StatusCode::KeyError; } /// Return true iff the status indicates invalid data. bool IsInvalid() const { return code() == StatusCode::Invalid; } + /// Return true iff the status indicates a cancelled operation. + bool IsCancelled() const { return code() == StatusCode::Cancelled; } /// Return true iff the status indicates an IO-related failure. bool IsIOError() const { return code() == StatusCode::IOError; } /// Return true iff the status indicates a container reaching capacity limits. diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt index f5c658d08f2..3f1e8b97ae1 100644 --- a/cpp/src/arrow/util/CMakeLists.txt +++ b/cpp/src/arrow/util/CMakeLists.txt @@ -67,9 +67,10 @@ add_arrow_test(utility-test add_arrow_test(threading-utility-test SOURCES - future_test - task_group_test - thread_pool_test) + cancel_test.cc + future_test.cc + task_group_test.cc + thread_pool_test.cc) add_arrow_benchmark(bit_block_counter_benchmark) add_arrow_benchmark(bit_util_benchmark) diff --git a/cpp/src/arrow/util/cancel.cc b/cpp/src/arrow/util/cancel.cc new file mode 100644 index 00000000000..292dcfc15d9 --- /dev/null +++ b/cpp/src/arrow/util/cancel.cc @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include "arrow/util/cancel.h" +#include "arrow/util/logging.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +StopCallback::StopCallback(StopToken* token, Callable cb) + : token_(token), cb_(std::move(cb)) { + if (token_ != nullptr) { + DCHECK(cb_); + // May call *this + token_->SetCallback(this); + } +} + +StopCallback::~StopCallback() { + if (token_ != nullptr) { + token_->RemoveCallback(this); + } +} + +StopCallback::StopCallback(StopCallback&& other) { *this = std::move(other); } + +StopCallback& StopCallback::operator=(StopCallback&& other) { + token_ = other.token_; + if (token_ != nullptr) { + other.token_ = nullptr; + token_->RemoveCallback(&other); + } + cb_ = std::move(other.cb_); + if (token_ != nullptr) { + // May call *this + token_->SetCallback(this); + } + return *this; +} + +void StopCallback::Call(const Status& st) { + if (cb_) { + // Forget callable after calling it + Callable local_cb; + cb_.swap(local_cb); + local_cb(st); + } +} + +// NOTE: We care mainly about the making the common case (not cancelled) fast. + +struct StopToken::Impl { + std::atomic requested_{false}; + std::mutex mutex_; + StopCallback* cb_{nullptr}; + Status cancel_error_; +}; + +StopToken::StopToken() : impl_(new Impl()) {} + +StopToken::~StopToken() {} + +Status StopToken::Poll() { + if (impl_->requested_) { + std::lock_guard lock(impl_->mutex_); + return impl_->cancel_error_; + } + return Status::OK(); +} + +bool StopToken::IsStopRequested() { return impl_->requested_; } + +void StopToken::RequestStop() { RequestStop(Status::Cancelled("Operation cancelled")); } + +void StopToken::RequestStop(Status st) { + std::lock_guard lock(impl_->mutex_); + DCHECK(!st.ok()); + if (!impl_->requested_) { + impl_->requested_ = true; + impl_->cancel_error_ = std::move(st); + if (impl_->cb_) { + impl_->cb_->Call(impl_->cancel_error_); + } + } +} + +StopCallback StopToken::SetCallback(StopCallback::Callable cb) { + return StopCallback(this, std::move(cb)); +} + +void StopToken::SetCallback(StopCallback* cb) { + std::lock_guard lock(impl_->mutex_); + DCHECK_EQ(impl_->cb_, nullptr); + impl_->cb_ = cb; + if (impl_->requested_) { + impl_->cb_->Call(impl_->cancel_error_); + } +} + +void StopToken::RemoveCallback(StopCallback* cb) { + std::lock_guard lock(impl_->mutex_); + DCHECK_EQ(impl_->cb_, cb); + impl_->cb_ = nullptr; +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/cancel.h b/cpp/src/arrow/util/cancel.h new file mode 100644 index 00000000000..f6431e48c07 --- /dev/null +++ b/cpp/src/arrow/util/cancel.h @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class StopToken; + +// A RAII wrapper that automatically registers and unregisters +// a callback to a StopToken. +class ARROW_MUST_USE_TYPE ARROW_EXPORT StopCallback { + public: + using Callable = std::function; + StopCallback(StopToken* token, Callable cb); + ~StopCallback(); + + ARROW_DISALLOW_COPY_AND_ASSIGN(StopCallback); + StopCallback(StopCallback&&); + StopCallback& operator=(StopCallback&&); + + void Call(const Status&); + + protected: + StopToken* token_; + Callable cb_; +}; + +class ARROW_EXPORT StopToken { + public: + StopToken(); + ~StopToken(); + ARROW_DISALLOW_COPY_AND_ASSIGN(StopToken); + + // NOTE: all APIs here are non-blocking. For consumers, waiting is done + // at a higher level using e.g. Future. Producers don't have to wait + // on a StopToken. + + // Consumer API (the side that stops) + void RequestStop(); + void RequestStop(Status error); + + // Producer API (the side that gets asked to stopped) + Status Poll(); + bool IsStopRequested(); + + // Register a callback that will be called whenever cancellation happens. + // Note the callback may be called immediately, if cancellation was already + // requested. The callback will be unregistered when the returned object + // is destroyed. + StopCallback SetCallback(StopCallback::Callable cb); + + protected: + struct Impl; + std::unique_ptr impl_; + + void SetCallback(StopCallback* cb); + void RemoveCallback(StopCallback* cb); + + friend class StopCallback; +}; + +} // namespace arrow diff --git a/cpp/src/arrow/util/cancel_test.cc b/cpp/src/arrow/util/cancel_test.cc new file mode 100644 index 00000000000..4ea3bf34a87 --- /dev/null +++ b/cpp/src/arrow/util/cancel_test.cc @@ -0,0 +1,202 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "arrow/testing/gtest_util.h" +#include "arrow/util/cancel.h" +#include "arrow/util/future.h" +#include "arrow/util/logging.h" + +namespace arrow { + +static constexpr double kLongWait = 5; // seconds + +class CancelTest : public ::testing::Test {}; + +TEST_F(CancelTest, TokenBasics) { + { + StopToken token; + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + + token.RequestStop(); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); + } + { + StopToken token; + token.RequestStop(Status::IOError("Operation cancelled")); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(IOError, token.Poll()); + } +} + +TEST_F(CancelTest, RequestStopTwice) { + StopToken token; + token.RequestStop(); + // Second RequestStop() call is ignored + token.RequestStop(Status::IOError("Operation cancelled")); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); +} + +TEST_F(CancelTest, SetCallback) { + std::vector results; + StopToken token; + { + const auto cb = token.SetCallback([&](const Status& st) { results.push_back(1); }); + ASSERT_EQ(results.size(), 0); + } + { + const auto cb = token.SetCallback([&](const Status& st) { results.push_back(1); }); + ASSERT_EQ(results.size(), 0); + token.RequestStop(); + ASSERT_EQ(results, std::vector{1}); + token.RequestStop(); + ASSERT_EQ(results, std::vector{1}); + } + { + const auto cb = token.SetCallback([&](const Status& st) { results.push_back(2); }); + ASSERT_EQ(results, std::vector({1, 2})); + token.RequestStop(); + ASSERT_EQ(results, std::vector({1, 2})); + } +} + +TEST_F(CancelTest, StopCallbackMove) { + std::vector results; + StopToken token; + + StopCallback cb1(&token, [&](const Status& st) { results.push_back(1); }); + const auto cb2 = std::move(cb1); + + ASSERT_EQ(results.size(), 0); + token.RequestStop(); + ASSERT_EQ(results, std::vector{1}); +} + +TEST_F(CancelTest, ThreadedPollSuccess) { + constexpr int kNumThreads = 10; + + std::vector results(kNumThreads); + std::vector threads; + + StopToken token; + std::atomic terminate_flag{false}; + + const auto worker_func = [&](int thread_num) { + while (token.Poll().ok() && !terminate_flag) { + } + results[thread_num] = token.Poll(); + }; + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back(std::bind(worker_func, i)); + } + + // Let the threads start and hammer on Poll() for a while + SleepFor(1e-2); + // Tell threads to stop + terminate_flag = true; + for (auto& thread : threads) { + thread.join(); + } + + for (const auto& st : results) { + ASSERT_OK(st); + } +} + +TEST_F(CancelTest, ThreadedPollCancel) { + constexpr int kNumThreads = 10; + + std::vector results(kNumThreads); + std::vector threads; + + StopToken token; + std::atomic terminate_flag{false}; + const auto stop_error = Status::IOError("Operation cancelled"); + + const auto worker_func = [&](int thread_num) { + while (token.Poll().ok() && !terminate_flag) { + } + results[thread_num] = token.Poll(); + }; + + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back(std::bind(worker_func, i)); + } + // Let the threads start + SleepFor(1e-2); + // Cancel token while threads are hammering on Poll() + token.RequestStop(stop_error); + // Tell threads to stop + terminate_flag = true; + for (auto& thread : threads) { + thread.join(); + } + + for (const auto& st : results) { + ASSERT_EQ(st, stop_error); + } +} + +TEST_F(CancelTest, ThreadedSetCallbackCancel) { + constexpr int kIterations = 100; + constexpr double kMaxWait = 1e-3; + + std::default_random_engine gen(42); + std::uniform_real_distribution wait_dist(0.0, kMaxWait); + + for (int i = 0; i < kIterations; ++i) { + Status result; + + StopToken token; + auto barrier = Future<>::Make(); + const auto stop_error = Status::IOError("Operation cancelled"); + + const auto worker_func = [&]() { + ARROW_CHECK(barrier.Wait(kLongWait)); + token.RequestStop(stop_error); + }; + std::thread thread(worker_func); + + // Unblock thread + barrier.MarkFinished(); + // Use a variable wait time to maximize potential synchronization issues + const auto wait_time = wait_dist(gen); + if (wait_time > kMaxWait * 0.5) { + SleepFor(wait_time); + } + + // Register callback while thread might be cancelling + StopCallback stop_cb(&token, [&](const Status& st) { result = st; }); + thread.join(); + + ASSERT_EQ(result, stop_error); + } +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/future.cc b/cpp/src/arrow/util/future.cc index 3a77f34e68f..b177f4f6bd5 100644 --- a/cpp/src/arrow/util/future.cc +++ b/cpp/src/arrow/util/future.cc @@ -225,9 +225,7 @@ class ConcreteFutureImpl : public FutureImpl { waiter_ = nullptr; } - void DoMarkFinished() { DoMarkFinishedOrFailed(FutureState::SUCCESS); } - - void DoMarkFailed() { DoMarkFinishedOrFailed(FutureState::FAILURE); } + void DoMarkFinished(FutureState state) { DoMarkFinishedOrFailed(state); } void AddCallback(Callback callback) { std::unique_lock lock(mutex_); @@ -304,14 +302,18 @@ ConcreteFutureImpl* GetConcreteFuture(FutureImpl* future) { } // namespace -std::unique_ptr FutureImpl::Make() { - return std::unique_ptr(new ConcreteFutureImpl()); +std::unique_ptr FutureImpl::Make(bool cancellable) { + auto ptr = new ConcreteFutureImpl(); + if (cancellable) { + ptr->stop_token_.emplace(); + } + return std::unique_ptr(ptr); } std::unique_ptr FutureImpl::MakeFinished(FutureState state) { - std::unique_ptr ptr(new ConcreteFutureImpl()); + auto ptr = new ConcreteFutureImpl(); ptr->state_ = state; - return std::move(ptr); + return std::unique_ptr(ptr); } FutureImpl::FutureImpl() : state_(FutureState::PENDING) {} @@ -328,9 +330,9 @@ void FutureImpl::Wait() { GetConcreteFuture(this)->DoWait(); } bool FutureImpl::Wait(double seconds) { return GetConcreteFuture(this)->DoWait(seconds); } -void FutureImpl::MarkFinished() { GetConcreteFuture(this)->DoMarkFinished(); } - -void FutureImpl::MarkFailed() { GetConcreteFuture(this)->DoMarkFailed(); } +void FutureImpl::MarkFinished(FutureState state) { + GetConcreteFuture(this)->DoMarkFinished(state); +} void FutureImpl::AddCallback(Callback callback) { GetConcreteFuture(this)->AddCallback(std::move(callback)); diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index ee053cf3096..82906fb7786 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -27,6 +27,7 @@ #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/util/cancel.h" #include "arrow/util/functional.h" #include "arrow/util/macros.h" #include "arrow/util/optional.h" @@ -130,7 +131,7 @@ struct ContinueFuture::ForReturnImpl> { } // namespace detail /// A Future's execution or completion status -enum class FutureState : int8_t { PENDING, SUCCESS, FAILURE }; +enum class FutureState : int8_t { PENDING, SUCCESS, FAILURE, CANCEL }; inline bool IsFutureFinished(FutureState state) { return state != FutureState::PENDING; } @@ -142,12 +143,11 @@ class ARROW_EXPORT FutureImpl { FutureState state() { return state_.load(); } - static std::unique_ptr Make(); + static std::unique_ptr Make(bool cancellable); static std::unique_ptr MakeFinished(FutureState state); // Future API - void MarkFinished(); - void MarkFailed(); + void MarkFinished(FutureState state); void Wait(); bool Wait(double seconds); @@ -167,6 +167,7 @@ class ARROW_EXPORT FutureImpl { Storage result_{NULLPTR, NULLPTR}; std::vector callbacks_; + util::optional stop_token_; }; // An object that waits on multiple futures at once. Only one waiter @@ -248,6 +249,7 @@ class ARROW_MUST_USE_TYPE Future { // of being able to presize a vector of Futures. Future() = default; + // ----------------------------------------------------------------------- // Consumer API bool is_valid() const { return impl_ != NULLPTR; } @@ -319,6 +321,22 @@ class ARROW_MUST_USE_TYPE Future { return impl_->Wait(seconds); } + template + bool Cancel(CancelArgs&&... args) { + auto& stop_token = impl_->stop_token_; + if (stop_token) { + stop_token->RequestStop(std::forward(args)...); + SetResult(stop_token->Poll()); + impl_->MarkFinished(FutureState::CANCEL); + return true; + } else { + return false; + } + } + + bool is_cancellable() const { return impl_->stop_token_.has_value(); } + + // ----------------------------------------------------------------------- // Producer API /// \brief Producer API: mark Future finished @@ -341,18 +359,22 @@ class ARROW_MUST_USE_TYPE Future { /// to memory leaks (for example, see Loop). static Future Make() { Future fut; - fut.impl_ = FutureImpl::Make(); + fut.impl_ = FutureImpl::Make(/*cancellable=*/false); + return fut; + } + + static Future MakeCancellable() { + Future fut; + fut.impl_ = FutureImpl::Make(/*cancellable=*/true); return fut; } /// \brief Producer API: instantiate a finished Future static Future MakeFinished(Result res) { + const auto state = + ARROW_PREDICT_TRUE(res.ok()) ? FutureState::SUCCESS : FutureState::FAILURE; Future fut; - if (ARROW_PREDICT_TRUE(res.ok())) { - fut.impl_ = FutureImpl::MakeFinished(FutureState::SUCCESS); - } else { - fut.impl_ = FutureImpl::MakeFinished(FutureState::FAILURE); - } + fut.impl_ = FutureImpl::MakeFinished(state); fut.SetResult(std::move(res)); return fut; } @@ -493,6 +515,14 @@ class ARROW_MUST_USE_TYPE Future { }); } + StopToken* stop_token() { + if (impl_->stop_token_.has_value()) { + return &*impl_->stop_token_; + } else { + return NULLPTR; + } + } + protected: template struct Callback { @@ -515,13 +545,15 @@ class ARROW_MUST_USE_TYPE Future { } void DoMarkFinished(Result res) { - SetResult(std::move(res)); - - if (ARROW_PREDICT_TRUE(GetResult()->ok())) { - impl_->MarkFinished(); - } else { - impl_->MarkFailed(); + if (ARROW_PREDICT_FALSE(impl_->state() == FutureState::CANCEL)) { + // The consumer cancelled the future, but the producer didn't notice + // before finishing its task => ignore producer result. + return; } + SetResult(std::move(res)); + const auto state = ARROW_PREDICT_TRUE(GetResult()->ok()) ? FutureState::SUCCESS + : FutureState::FAILURE; + impl_->MarkFinished(state); } void CheckValid() const { diff --git a/cpp/src/arrow/util/future_test.cc b/cpp/src/arrow/util/future_test.cc index 97b643316a7..0ef48c5d482 100644 --- a/cpp/src/arrow/util/future_test.cc +++ b/cpp/src/arrow/util/future_test.cc @@ -146,6 +146,13 @@ void AssertFailed(const Future& fut) { } } +// Assert the future is cancelled *now* +template +void AssertCancelled(const Future& fut) { + ASSERT_EQ(fut.state(), FutureState::CANCEL); + ASSERT_FALSE(fut.status().ok()); +} + template struct IteratorResults { std::vector values; @@ -987,6 +994,35 @@ TEST(FutureCompletionTest, FutureVoid) { } } +TEST(FutureSyncTest, Cancellation) { + { + auto fut = Future<>::Make(); + ASSERT_FALSE(fut.is_cancellable()); + ASSERT_EQ(fut.stop_token(), nullptr); + ASSERT_FALSE(fut.Cancel()); + AssertNotFinished(fut); + fut.MarkFinished(); + AssertSuccessful(fut); + } + { + auto fut = Future<>::MakeCancellable(); + ASSERT_TRUE(fut.is_cancellable()); + StopToken* stop_token = fut.stop_token(); + ASSERT_NE(stop_token, nullptr); + const Status error = Status::Cancelled("charabia"); + ASSERT_TRUE(fut.Cancel(error)); + AssertCancelled(fut); + ASSERT_EQ(fut.status(), error); + // The producer may still try to set another status, it will be ignored + fut.MarkFinished(Status::IOError("xxx")); + AssertCancelled(fut); + ASSERT_EQ(fut.status(), error); + fut.MarkFinished(Status::OK()); + AssertCancelled(fut); + ASSERT_EQ(fut.status(), error); + } +} + TEST(FutureAllTest, Simple) { auto f1 = Future::Make(); auto f2 = Future::Make(); diff --git a/cpp/src/arrow/util/thread_pool.cc b/cpp/src/arrow/util/thread_pool.cc index 33eb937ba43..7454fb2bfef 100644 --- a/cpp/src/arrow/util/thread_pool.cc +++ b/cpp/src/arrow/util/thread_pool.cc @@ -34,6 +34,15 @@ namespace internal { Executor::~Executor() = default; +namespace { + +struct Task { + FnOnce callable; + StopToken* stop_token; +}; + +} // namespace + struct ThreadPool::State { State() = default; @@ -47,7 +56,7 @@ struct ThreadPool::State { std::list workers_; // Trashcan for finished threads std::vector finished_workers_; - std::deque> pending_tasks_; + std::deque pending_tasks_; // Desired number of threads int desired_capacity_ = 0; @@ -91,12 +100,15 @@ static void WorkerLoop(std::shared_ptr state, --state->ready_count_; DCHECK_GE(state->ready_count_, 0); { - FnOnce task = std::move(state->pending_tasks_.front()); + Task task = std::move(state->pending_tasks_.front()); state->pending_tasks_.pop_front(); - lock.unlock(); - std::move(task)(); + StopToken* stop_token = task.stop_token; + if (!stop_token || !stop_token->IsStopRequested()) { + lock.unlock(); + std::move(task.callable)(); + lock.lock(); + } } - lock.lock(); ++state->ready_count_; } // Now either the queue is empty *or* a quick shutdown was requested @@ -242,7 +254,8 @@ void ThreadPool::LaunchWorkersUnlocked(int threads) { } } -Status ThreadPool::SpawnReal(TaskHints hints, FnOnce task) { +Status ThreadPool::SpawnReal(TaskHints hints, FnOnce task, + StopToken* stop_token) { { ProtectAgainstFork(); std::lock_guard lock(state_->mutex_); @@ -256,7 +269,7 @@ Status ThreadPool::SpawnReal(TaskHints hints, FnOnce task) { // spawn one more thread. LaunchWorkersUnlocked(/*threads=*/1); } - state_->pending_tasks_.push_back(std::move(task)); + state_->pending_tasks_.push_back({std::move(task), stop_token}); } state_->cv_.notify_one(); return Status::OK(); diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index 5db3a9a4722..d9b585e73df 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -115,7 +115,7 @@ class ARROW_EXPORT Executor { typename FutureType = typename ::arrow::detail::ContinueFuture::ForSignature< Function && (Args && ...)>> Result Submit(TaskHints hints, Function&& func, Args&&... args) { - auto future = FutureType::Make(); + auto future = FutureType::MakeCancellable(); auto task = std::bind(::arrow::detail::ContinueFuture{}, future, std::forward(func), std::forward(args)...); @@ -141,7 +141,8 @@ class ARROW_EXPORT Executor { Executor() = default; // Subclassing API - virtual Status SpawnReal(TaskHints hints, FnOnce task) = 0; + virtual Status SpawnReal(TaskHints hints, FnOnce task, + StopToken* = NULLPTR) = 0; }; // An Executor implementation spawning tasks in FIFO manner on a fixed-size @@ -192,7 +193,7 @@ class ARROW_EXPORT ThreadPool : public Executor { ThreadPool(); - Status SpawnReal(TaskHints hints, FnOnce task) override; + Status SpawnReal(TaskHints hints, FnOnce task, StopToken* = NULLPTR) override; // Collect finished worker threads, making sure the OS threads have exited void CollectFinishedWorkersUnlocked();