diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index abd5428b3d7..2125b019f8e 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -181,6 +181,7 @@ set(ARROW_SRCS util/bitmap_builders.cc util/bitmap_ops.cc util/bpacking.cc + util/cancel.cc util/compression.cc util/cpu_info.cc util/decimal.cc diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc index bbba60c79c1..9fc180137d9 100644 --- a/cpp/src/arrow/csv/reader.cc +++ b/cpp/src/arrow/csv/reader.cc @@ -313,12 +313,13 @@ class ReaderMixin { public: ReaderMixin(MemoryPool* pool, std::shared_ptr input, const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) + const ConvertOptions& convert_options, StopToken stop_token) : pool_(pool), read_options_(read_options), parse_options_(parse_options), convert_options_(convert_options), - input_(std::move(input)) {} + input_(std::move(input)), + stop_token_(std::move(stop_token)) {} protected: // Read header and column names from buffer, create column builders @@ -500,6 +501,7 @@ class ReaderMixin { std::shared_ptr input_; std::shared_ptr task_group_; + StopToken stop_token_; }; ///////////////////////////////////////////////////////////////////////// @@ -697,7 +699,7 @@ class SerialStreamingReader : public BaseStreamingReader { ARROW_ASSIGN_OR_RAISE(auto rh_it, MakeReadaheadIterator(std::move(istream_it), block_queue_size)); buffer_iterator_ = CSVBufferIterator::Make(std::move(rh_it)); - task_group_ = internal::TaskGroup::MakeSerial(); + task_group_ = internal::TaskGroup::MakeSerial(stop_token_); // Read schema from first batch ARROW_ASSIGN_OR_RAISE(pending_batch_, ReadNext()); @@ -710,6 +712,10 @@ class SerialStreamingReader : public BaseStreamingReader { if (eof_) { return nullptr; } + if (stop_token_.IsStopRequested()) { + eof_ = true; + return stop_token_.Poll(); + } if (!block_iterator_) { Status st = SetupReader(); if (!st.ok()) { @@ -790,7 +796,7 @@ class SerialTableReader : public BaseTableReader { } Result> Read() override { - task_group_ = internal::TaskGroup::MakeSerial(); + task_group_ = internal::TaskGroup::MakeSerial(stop_token_); // First block ARROW_ASSIGN_OR_RAISE(auto first_buffer, buffer_iterator_.Next()); @@ -804,6 +810,8 @@ class SerialTableReader : public BaseTableReader { MakeChunker(parse_options_), std::move(first_buffer)); while (true) { + RETURN_NOT_OK(stop_token_.Poll()); + ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator.Next()); if (maybe_block == IterationTraits::End()) { // EOF @@ -833,9 +841,10 @@ class AsyncThreadedTableReader AsyncThreadedTableReader(MemoryPool* pool, std::shared_ptr input, const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options, Executor* cpu_executor, - Executor* io_executor) - : BaseTableReader(pool, input, read_options, parse_options, convert_options), + const ConvertOptions& convert_options, StopToken stop_token, + Executor* cpu_executor, Executor* io_executor) + : BaseTableReader(pool, input, read_options, parse_options, convert_options, + std::move(stop_token)), cpu_executor_(cpu_executor), io_executor_(io_executor) {} @@ -870,7 +879,7 @@ class AsyncThreadedTableReader Result> Read() override { return ReadAsync().result(); } Future> ReadAsync() override { - task_group_ = internal::TaskGroup::MakeThreaded(cpu_executor_); + task_group_ = internal::TaskGroup::MakeThreaded(cpu_executor_, stop_token_); auto self = shared_from_this(); return ProcessFirstBuffer().Then([self](std::shared_ptr first_buffer) { @@ -939,17 +948,30 @@ Result> MakeTableReader( if (read_options.use_threads) { auto cpu_executor = internal::GetCpuThreadPool(); auto io_executor = io_context.executor(); - reader = std::make_shared(pool, input, read_options, - parse_options, convert_options, - cpu_executor, io_executor); + reader = std::make_shared( + pool, input, read_options, parse_options, convert_options, + io_context.stop_token(), cpu_executor, io_executor); } else { - reader = std::make_shared(pool, input, read_options, parse_options, - convert_options); + reader = + std::make_shared(pool, input, read_options, parse_options, + convert_options, io_context.stop_token()); } RETURN_NOT_OK(reader->Init()); return reader; } +Result> MakeStreamingReader( + io::IOContext io_context, std::shared_ptr input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + std::shared_ptr reader; + reader = std::make_shared(io_context.pool(), input, read_options, + parse_options, convert_options, + io_context.stop_token()); + RETURN_NOT_OK(reader->Init()); + return reader; +} + } // namespace ///////////////////////////////////////////////////////////////////////// @@ -975,13 +997,17 @@ Result> StreamingReader::Make( MemoryPool* pool, std::shared_ptr input, const ReadOptions& read_options, const ParseOptions& parse_options, const ConvertOptions& convert_options) { - std::shared_ptr reader; - reader = std::make_shared(pool, input, read_options, - parse_options, convert_options); - RETURN_NOT_OK(reader->Init()); - return reader; + return MakeStreamingReader(io::IOContext(pool), std::move(input), read_options, + parse_options, convert_options); } -} // namespace csv +Result> StreamingReader::Make( + io::IOContext io_context, std::shared_ptr input, + const ReadOptions& read_options, const ParseOptions& parse_options, + const ConvertOptions& convert_options) { + return MakeStreamingReader(io_context, std::move(input), read_options, parse_options, + convert_options); +} +} // namespace csv } // namespace arrow diff --git a/cpp/src/arrow/csv/reader.h b/cpp/src/arrow/csv/reader.h index b18dc04eb65..76a575b3349 100644 --- a/cpp/src/arrow/csv/reader.h +++ b/cpp/src/arrow/csv/reader.h @@ -52,7 +52,7 @@ class ARROW_EXPORT TableReader { const ParseOptions&, const ConvertOptions&); - ARROW_DEPRECATED("Use MemoryPool-less overload (the IOContext holds a pool already)") + ARROW_DEPRECATED("Use MemoryPool-less variant (the IOContext holds a pool already)") static Result> Make( MemoryPool* pool, io::IOContext io_context, std::shared_ptr input, const ReadOptions&, const ParseOptions&, const ConvertOptions&); @@ -67,6 +67,11 @@ class ARROW_EXPORT StreamingReader : public RecordBatchReader { /// /// Currently, the StreamingReader is always single-threaded (parallel /// readahead is not supported). + static Result> Make( + io::IOContext io_context, std::shared_ptr input, + const ReadOptions&, const ParseOptions&, const ConvertOptions&); + + ARROW_DEPRECATED("Use IOContext-based overload") static Result> Make( MemoryPool* pool, std::shared_ptr input, const ReadOptions&, const ParseOptions&, const ConvertOptions&); diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index fbbdfa246d6..0c023b87dcd 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -119,8 +119,9 @@ static inline Result> OpenReader( GetConvertOptions(format, scan_options, *first_block, pool)); } - auto maybe_reader = csv::StreamingReader::Make(pool, std::move(input), reader_options, - parse_options, convert_options); + auto maybe_reader = + csv::StreamingReader::Make(io::IOContext(pool), std::move(input), reader_options, + parse_options, convert_options); if (!maybe_reader.ok()) { return maybe_reader.status().WithMessage("Could not open CSV input source '", source.path(), "': ", maybe_reader.status()); diff --git a/cpp/src/arrow/io/interfaces.cc b/cpp/src/arrow/io/interfaces.cc index 22abbb27bce..dc2112ebddd 100644 --- a/cpp/src/arrow/io/interfaces.cc +++ b/cpp/src/arrow/io/interfaces.cc @@ -48,7 +48,8 @@ namespace io { static IOContext g_default_io_context{}; -IOContext::IOContext(MemoryPool* pool) : IOContext(pool, internal::GetIOThreadPool()) {} +IOContext::IOContext(MemoryPool* pool, StopToken stop_token) + : IOContext(pool, internal::GetIOThreadPool(), std::move(stop_token)) {} const IOContext& default_io_context() { return g_default_io_context; } diff --git a/cpp/src/arrow/io/interfaces.h b/cpp/src/arrow/io/interfaces.h index 07c01324ea1..0afd2f236b8 100644 --- a/cpp/src/arrow/io/interfaces.h +++ b/cpp/src/arrow/io/interfaces.h @@ -24,6 +24,7 @@ #include "arrow/io/type_fwd.h" #include "arrow/type_fwd.h" +#include "arrow/util/cancel.h" #include "arrow/util/macros.h" #include "arrow/util/string_view.h" #include "arrow/util/type_fwd.h" @@ -56,17 +57,28 @@ struct ReadRange { /// multiple sources and must distinguish tasks associated with this IOContext). struct ARROW_EXPORT IOContext { // No specified executor: will use a global IO thread pool - IOContext() : IOContext(default_memory_pool()) {} + IOContext() : IOContext(default_memory_pool(), StopToken::Unstoppable()) {} - // No specified executor: will use a global IO thread pool - explicit IOContext(MemoryPool* pool); + explicit IOContext(StopToken stop_token) + : IOContext(default_memory_pool(), std::move(stop_token)) {} + + explicit IOContext(MemoryPool* pool, StopToken stop_token = StopToken::Unstoppable()); explicit IOContext(MemoryPool* pool, ::arrow::internal::Executor* executor, + StopToken stop_token = StopToken::Unstoppable(), int64_t external_id = -1) - : pool_(pool), executor_(executor), external_id_(external_id) {} + : pool_(pool), + executor_(executor), + external_id_(external_id), + stop_token_(std::move(stop_token)) {} - explicit IOContext(::arrow::internal::Executor* executor, int64_t external_id = -1) - : pool_(default_memory_pool()), executor_(executor), external_id_(external_id) {} + explicit IOContext(::arrow::internal::Executor* executor, + StopToken stop_token = StopToken::Unstoppable(), + int64_t external_id = -1) + : pool_(default_memory_pool()), + executor_(executor), + external_id_(external_id), + stop_token_(std::move(stop_token)) {} MemoryPool* pool() const { return pool_; } @@ -75,10 +87,13 @@ struct ARROW_EXPORT IOContext { // An application-specific ID, forwarded to executor task submissions int64_t external_id() const { return external_id_; } + StopToken stop_token() const { return stop_token_; } + private: MemoryPool* pool_; ::arrow::internal::Executor* executor_; int64_t external_id_; + StopToken stop_token_; }; struct ARROW_DEPRECATED("renamed to IOContext in 4.0.0") AsyncContext : public IOContext { diff --git a/cpp/src/arrow/status.cc b/cpp/src/arrow/status.cc index cfc5eb1e345..0f02cb57a23 100644 --- a/cpp/src/arrow/status.cc +++ b/cpp/src/arrow/status.cc @@ -68,6 +68,9 @@ std::string Status::CodeAsString(StatusCode code) { case StatusCode::Invalid: type = "Invalid"; break; + case StatusCode::Cancelled: + type = "Cancelled"; + break; case StatusCode::IOError: type = "IOError"; break; diff --git a/cpp/src/arrow/status.h b/cpp/src/arrow/status.h index 36f1b9023da..43879e6c6a3 100644 --- a/cpp/src/arrow/status.h +++ b/cpp/src/arrow/status.h @@ -83,6 +83,7 @@ enum class StatusCode : char { IOError = 5, CapacityError = 6, IndexError = 7, + Cancelled = 8, UnknownError = 9, NotImplemented = 10, SerializationError = 11, @@ -204,6 +205,12 @@ class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : public util::EqualityComparable< return Status::FromArgs(StatusCode::Invalid, std::forward(args)...); } + /// Return an error status for cancelled operation + template + static Status Cancelled(Args&&... args) { + return Status::FromArgs(StatusCode::Cancelled, std::forward(args)...); + } + /// Return an error status when an index is out of bounds template static Status IndexError(Args&&... args) { @@ -263,6 +270,8 @@ class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : public util::EqualityComparable< bool IsKeyError() const { return code() == StatusCode::KeyError; } /// Return true iff the status indicates invalid data. bool IsInvalid() const { return code() == StatusCode::Invalid; } + /// Return true iff the status indicates a cancelled operation. + bool IsCancelled() const { return code() == StatusCode::Cancelled; } /// Return true iff the status indicates an IO-related failure. bool IsIOError() const { return code() == StatusCode::IOError; } /// Return true iff the status indicates a container reaching capacity limits. diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index aa9d22dae2f..609572d2f50 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -539,6 +539,24 @@ EnvVarGuard::~EnvVarGuard() { } } +struct SignalHandlerGuard::Impl { + int signum_; + internal::SignalHandler old_handler_; + + Impl(int signum, const internal::SignalHandler& handler) + : signum_(signum), old_handler_(*internal::SetSignalHandler(signum, handler)) {} + + ~Impl() { ARROW_EXPECT_OK(internal::SetSignalHandler(signum_, old_handler_)); } +}; + +SignalHandlerGuard::SignalHandlerGuard(int signum, Callback cb) + : SignalHandlerGuard(signum, internal::SignalHandler(cb)) {} + +SignalHandlerGuard::SignalHandlerGuard(int signum, const internal::SignalHandler& handler) + : impl_(new Impl{signum, handler}) {} + +SignalHandlerGuard::~SignalHandlerGuard() = default; + namespace { // Used to prevent compiler optimizing away side-effect-less statements @@ -576,6 +594,13 @@ void SleepFor(double seconds) { std::chrono::nanoseconds(static_cast(seconds * 1e9))); } +void BusyWait(double seconds, std::function predicate) { + const double period = 0.001; + for (int i = 0; !predicate() && i * period < seconds; ++i) { + SleepFor(period); + } +} + /////////////////////////////////////////////////////////////////////////// // Extension types diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index ff3b751a394..8751cc5131e 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -472,6 +473,10 @@ inline void BitmapFromVector(const std::vector& is_valid, ARROW_TESTING_EXPORT void SleepFor(double seconds); +// Wait until predicate is true or timeout in seconds expires. +ARROW_TESTING_EXPORT +void BusyWait(double seconds, std::function predicate); + template std::vector IteratorToVector(Iterator iterator) { EXPECT_OK_AND_ASSIGN(auto out, iterator.ToVector()); @@ -504,6 +509,23 @@ class ARROW_TESTING_EXPORT EnvVarGuard { bool was_set_; }; +namespace internal { +class SignalHandler; +} + +class ARROW_TESTING_EXPORT SignalHandlerGuard { + public: + typedef void (*Callback)(int); + + SignalHandlerGuard(int signum, Callback cb); + SignalHandlerGuard(int signum, const internal::SignalHandler& handler); + ~SignalHandlerGuard(); + + protected: + struct Impl; + std::unique_ptr impl_; +}; + #ifndef ARROW_LARGE_MEMORY_TESTS #define LARGE_MEMORY_TEST(name) DISABLED_##name #else diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt index 718307deedf..da26547c7c9 100644 --- a/cpp/src/arrow/util/CMakeLists.txt +++ b/cpp/src/arrow/util/CMakeLists.txt @@ -68,9 +68,10 @@ add_arrow_test(utility-test add_arrow_test(threading-utility-test SOURCES - future_test - task_group_test - thread_pool_test) + cancel_test.cc + future_test.cc + task_group_test.cc + thread_pool_test.cc) add_arrow_benchmark(bit_block_counter_benchmark) add_arrow_benchmark(bit_util_benchmark) diff --git a/cpp/src/arrow/util/cancel.cc b/cpp/src/arrow/util/cancel.cc new file mode 100644 index 00000000000..533075a9a64 --- /dev/null +++ b/cpp/src/arrow/util/cancel.cc @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/util/cancel.h" + +#include +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/util/atomic_shared_ptr.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +#if ATOMIC_INT_LOCK_FREE != 2 +#error Lock-free atomic int required for signal safety +#endif + +using internal::ReinstateSignalHandler; +using internal::SetSignalHandler; +using internal::SignalHandler; + +// NOTE: We care mainly about the making the common case (not cancelled) fast. + +struct StopSourceImpl { + std::atomic requested_{0}; // will be -1 or signal number if requested + std::mutex mutex_; + Status cancel_error_; +}; + +StopSource::StopSource() : impl_(new StopSourceImpl) {} + +StopSource::~StopSource() = default; + +void StopSource::RequestStop() { RequestStop(Status::Cancelled("Operation cancelled")); } + +void StopSource::RequestStop(Status st) { + std::lock_guard lock(impl_->mutex_); + DCHECK(!st.ok()); + if (!impl_->requested_) { + impl_->requested_ = -1; + impl_->cancel_error_ = std::move(st); + } +} + +void StopSource::RequestStopFromSignal(int signum) { + // Only async-signal-safe code allowed here + impl_->requested_.store(signum); +} + +void StopSource::Reset() { + std::lock_guard lock(impl_->mutex_); + impl_->cancel_error_ = Status::OK(); + impl_->requested_.store(0); +} + +StopToken StopSource::token() { return StopToken(impl_); } + +bool StopToken::IsStopRequested() { + if (!impl_) { + return false; + } + return impl_->requested_.load() != 0; +} + +Status StopToken::Poll() { + if (!impl_) { + return Status::OK(); + } + if (!impl_->requested_.load()) { + return Status::OK(); + } + + std::lock_guard lock(impl_->mutex_); + if (impl_->cancel_error_.ok()) { + auto signum = impl_->requested_.load(); + DCHECK_GT(signum, 0); + impl_->cancel_error_ = internal::CancelledFromSignal(signum, "Operation cancelled"); + } + return impl_->cancel_error_; +} + +namespace { + +struct SignalStopState { + struct SavedSignalHandler { + int signum; + SignalHandler handler; + }; + + Status RegisterHandlers(const std::vector& signals) { + if (!saved_handlers_.empty()) { + return Status::Invalid("Signal handlers already registered"); + } + for (int signum : signals) { + ARROW_ASSIGN_OR_RAISE(auto handler, + SetSignalHandler(signum, SignalHandler{&HandleSignal})); + saved_handlers_.push_back({signum, handler}); + } + return Status::OK(); + } + + void UnregisterHandlers() { + auto handlers = std::move(saved_handlers_); + for (const auto& h : handlers) { + ARROW_CHECK_OK(SetSignalHandler(h.signum, h.handler).status()); + } + } + + ~SignalStopState() { + UnregisterHandlers(); + Disable(); + } + + StopSource* stop_source() { return stop_source_.get(); } + + bool enabled() { return stop_source_ != nullptr; } + + void Enable() { + // Before creating a new StopSource, delete any lingering reference to + // the previous one in the trash can. See DoHandleSignal() for details. + EmptyTrashCan(); + internal::atomic_store(&stop_source_, std::make_shared()); + } + + void Disable() { internal::atomic_store(&stop_source_, NullSource()); } + + static SignalStopState* instance() { return &instance_; } + + private: + // For readability + std::shared_ptr NullSource() { return nullptr; } + + void EmptyTrashCan() { internal::atomic_store(&trash_can_, NullSource()); } + + static void HandleSignal(int signum) { instance_.DoHandleSignal(signum); } + + void DoHandleSignal(int signum) { + // async-signal-safe code only + auto source = internal::atomic_load(&stop_source_); + if (source) { + source->RequestStopFromSignal(signum); + // Disable() may have been called in the meantime, but we can't + // deallocate a shared_ptr here, so instead move it to a "trash can". + // This minimizes the possibility of running a deallocator here, + // however it doesn't entirely preclude it. + // + // Possible case: + // - a signal handler (A) starts running, fetches the current source + // - Disable() then Enable() are called, emptying the trash can and + // replacing the current source + // - a signal handler (B) starts running, fetches the current source + // - signal handler A resumes, moves its source (the old source) into + // the trash can (the only remaining reference) + // - signal handler B resumes, moves its source (the current source) + // into the trash can. This triggers deallocation of the old source, + // since the trash can had the only remaining reference to it. + // + // This case should be sufficiently unlikely, but we cannot entirely + // rule it out. The problem might be solved properly with a lock-free + // linked list of StopSources. + internal::atomic_store(&trash_can_, std::move(source)); + } + ReinstateSignalHandler(signum, &HandleSignal); + } + + std::shared_ptr stop_source_; + std::shared_ptr trash_can_; + + std::vector saved_handlers_; + + static SignalStopState instance_; +}; + +SignalStopState SignalStopState::instance_{}; + +} // namespace + +Result SetSignalStopSource() { + auto stop_state = SignalStopState::instance(); + if (stop_state->enabled()) { + return Status::Invalid("Signal stop source already set up"); + } + stop_state->Enable(); + return stop_state->stop_source(); +} + +void ResetSignalStopSource() { + auto stop_state = SignalStopState::instance(); + DCHECK(stop_state->enabled()); + stop_state->Disable(); +} + +Status RegisterCancellingSignalHandler(const std::vector& signals) { + auto stop_state = SignalStopState::instance(); + if (!stop_state->enabled()) { + return Status::Invalid("Signal stop source was not set up"); + } + return stop_state->RegisterHandlers(signals); +} + +void UnregisterCancellingSignalHandler() { + auto stop_state = SignalStopState::instance(); + DCHECK(stop_state->enabled()); + stop_state->UnregisterHandlers(); +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/cancel.h b/cpp/src/arrow/util/cancel.h new file mode 100644 index 00000000000..506a7e16e4f --- /dev/null +++ b/cpp/src/arrow/util/cancel.h @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class StopToken; + +struct StopSourceImpl; + +/// EXPERIMENTAL +class ARROW_EXPORT StopSource { + public: + StopSource(); + ~StopSource(); + + // Consumer API (the side that stops) + void RequestStop(); + void RequestStop(Status error); + void RequestStopFromSignal(int signum); + + StopToken token(); + + // For internal use only + void Reset(); + + protected: + std::shared_ptr impl_; +}; + +/// EXPERIMENTAL +class ARROW_EXPORT StopToken { + public: + // Public for Cython + StopToken() {} + + explicit StopToken(std::shared_ptr impl) : impl_(std::move(impl)) {} + + // A trivial token that never propagates any stop request + static StopToken Unstoppable() { return StopToken(); } + + // Producer API (the side that gets asked to stopped) + Status Poll(); + bool IsStopRequested(); + + protected: + std::shared_ptr impl_; +}; + +/// EXPERIMENTAL: Set a global StopSource that can receive signals +/// +/// The only allowed order of calls is the following: +/// - SetSignalStopSource() +/// - any number of pairs of (RegisterCancellingSignalHandler, +/// UnregisterCancellingSignalHandler) calls +/// - ResetSignalStopSource() +/// +/// Beware that these settings are process-wide. Typically, only one +/// thread should call these APIs, even in a multithreaded setting. +ARROW_EXPORT +Result SetSignalStopSource(); + +/// EXPERIMENTAL: Reset the global signal-receiving StopSource +/// +/// This will invalidate the pointer returned by SetSignalStopSource. +ARROW_EXPORT +void ResetSignalStopSource(); + +/// EXPERIMENTAL: Register signal handler triggering the signal-receiving StopSource +ARROW_EXPORT +Status RegisterCancellingSignalHandler(const std::vector& signals); + +/// EXPERIMENTAL: Unregister signal handler set up by RegisterCancellingSignalHandler +ARROW_EXPORT +void UnregisterCancellingSignalHandler(); + +} // namespace arrow diff --git a/cpp/src/arrow/util/cancel_test.cc b/cpp/src/arrow/util/cancel_test.cc new file mode 100644 index 00000000000..b9bf94ba43a --- /dev/null +++ b/cpp/src/arrow/util/cancel_test.cc @@ -0,0 +1,308 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#ifndef _WIN32 +#include // for setitimer() +#endif + +#include "arrow/testing/gtest_util.h" +#include "arrow/util/cancel.h" +#include "arrow/util/future.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/optional.h" + +namespace arrow { + +class CancelTest : public ::testing::Test {}; + +TEST_F(CancelTest, StopBasics) { + { + StopSource source; + StopToken token = source.token(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + + source.RequestStop(); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); + } + { + StopSource source; + StopToken token = source.token(); + source.RequestStop(Status::IOError("Operation cancelled")); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(IOError, token.Poll()); + } +} + +TEST_F(CancelTest, StopTokenCopy) { + StopSource source; + StopToken token = source.token(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + + StopToken token2 = token; + ASSERT_FALSE(token2.IsStopRequested()); + ASSERT_OK(token2.Poll()); + + source.RequestStop(); + StopToken token3 = token; + + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_TRUE(token2.IsStopRequested()); + ASSERT_TRUE(token3.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); + ASSERT_EQ(token2.Poll(), token.Poll()); + ASSERT_EQ(token3.Poll(), token.Poll()); +} + +TEST_F(CancelTest, RequestStopTwice) { + StopSource source; + StopToken token = source.token(); + source.RequestStop(); + // Second RequestStop() call is ignored + source.RequestStop(Status::IOError("Operation cancelled")); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); +} + +TEST_F(CancelTest, Unstoppable) { + StopToken token = StopToken::Unstoppable(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); +} + +TEST_F(CancelTest, SourceVanishes) { + { + util::optional source{StopSource()}; + StopToken token = source->token(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + + source.reset(); + ASSERT_FALSE(token.IsStopRequested()); + ASSERT_OK(token.Poll()); + } + { + util::optional source{StopSource()}; + StopToken token = source->token(); + source->RequestStop(); + + source.reset(); + ASSERT_TRUE(token.IsStopRequested()); + ASSERT_RAISES(Cancelled, token.Poll()); + } +} + +static void noop_signal_handler(int signum) { + internal::ReinstateSignalHandler(signum, &noop_signal_handler); +} + +#ifndef _WIN32 +static util::optional signal_stop_source; + +static void signal_handler(int signum) { + signal_stop_source->RequestStopFromSignal(signum); +} + +// SIGALRM will be received once after the specified wait +static void SetITimer(double seconds) { + const double fractional = std::modf(seconds, &seconds); + struct itimerval it; + it.it_value.tv_sec = seconds; + it.it_value.tv_usec = 1e6 * fractional; + it.it_interval.tv_sec = 0; + it.it_interval.tv_usec = 0; + ASSERT_EQ(0, setitimer(ITIMER_REAL, &it, nullptr)) << "setitimer failed"; +} + +TEST_F(CancelTest, RequestStopFromSignal) { + signal_stop_source = StopSource(); // Start with a fresh StopSource + StopToken signal_token = signal_stop_source->token(); + SignalHandlerGuard guard(SIGALRM, &signal_handler); + + // Timer will be triggered once in 100 usecs + SetITimer(0.0001); + + BusyWait(1.0, [&]() { return signal_token.IsStopRequested(); }); + ASSERT_TRUE(signal_token.IsStopRequested()); + auto st = signal_token.Poll(); + ASSERT_RAISES(Cancelled, st); + ASSERT_EQ(st.message(), "Operation cancelled"); + ASSERT_EQ(internal::SignalFromStatus(st), SIGALRM); +} +#endif + +class SignalCancelTest : public CancelTest { + public: + void SetUp() override { + // Setup a dummy signal handler to avoid crashing when receiving signal + guard_.emplace(expected_signal_, &noop_signal_handler); + ASSERT_OK_AND_ASSIGN(auto stop_source, SetSignalStopSource()); + stop_token_ = stop_source->token(); + } + + void TearDown() override { + UnregisterCancellingSignalHandler(); + ResetSignalStopSource(); + } + + void RegisterHandler() { + ASSERT_OK(RegisterCancellingSignalHandler({expected_signal_})); + } + +#ifdef _WIN32 + void TriggerSignal() { + std::thread([]() { ASSERT_OK(internal::SendSignal(SIGINT)); }).detach(); + } +#else + // On Unix, use setitimer() to exercise signal-async-safety + void TriggerSignal() { SetITimer(0.0001); } +#endif + + void AssertStopNotRequested() { + SleepFor(0.01); + ASSERT_FALSE(stop_token_->IsStopRequested()); + ASSERT_OK(stop_token_->Poll()); + } + + void AssertStopRequested() { + BusyWait(1.0, [&]() { return stop_token_->IsStopRequested(); }); + ASSERT_TRUE(stop_token_->IsStopRequested()); + auto st = stop_token_->Poll(); + ASSERT_RAISES(Cancelled, st); + ASSERT_EQ(st.message(), "Operation cancelled"); + ASSERT_EQ(internal::SignalFromStatus(st), expected_signal_); + } + + protected: +#ifdef _WIN32 + const int expected_signal_ = SIGINT; +#else + const int expected_signal_ = SIGALRM; +#endif + util::optional guard_; + util::optional stop_token_; +}; + +TEST_F(SignalCancelTest, Register) { + RegisterHandler(); + + TriggerSignal(); + AssertStopRequested(); +} + +TEST_F(SignalCancelTest, RegisterUnregister) { + // The signal stop source was set up but no handler was registered, + // so the token shouldn't be signalled. + TriggerSignal(); + AssertStopNotRequested(); + + // Register and then unregister: same + RegisterHandler(); + UnregisterCancellingSignalHandler(); + + TriggerSignal(); + AssertStopNotRequested(); + + // Register again and raise the signal: token will be signalled. + RegisterHandler(); + + TriggerSignal(); + AssertStopRequested(); +} + +TEST_F(CancelTest, ThreadedPollSuccess) { + constexpr int kNumThreads = 10; + + std::vector results(kNumThreads); + std::vector threads; + + StopSource source; + StopToken token = source.token(); + std::atomic terminate_flag{false}; + + const auto worker_func = [&](int thread_num) { + while (token.Poll().ok() && !terminate_flag) { + } + results[thread_num] = token.Poll(); + }; + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back(std::bind(worker_func, i)); + } + + // Let the threads start and hammer on Poll() for a while + SleepFor(1e-2); + // Tell threads to stop + terminate_flag = true; + for (auto& thread : threads) { + thread.join(); + } + + for (const auto& st : results) { + ASSERT_OK(st); + } +} + +TEST_F(CancelTest, ThreadedPollCancel) { + constexpr int kNumThreads = 10; + + std::vector results(kNumThreads); + std::vector threads; + + StopSource source; + StopToken token = source.token(); + std::atomic terminate_flag{false}; + const auto stop_error = Status::IOError("Operation cancelled"); + + const auto worker_func = [&](int thread_num) { + while (token.Poll().ok() && !terminate_flag) { + } + results[thread_num] = token.Poll(); + }; + + for (int i = 0; i < kNumThreads; ++i) { + threads.emplace_back(std::bind(worker_func, i)); + } + // Let the threads start + SleepFor(1e-2); + // Cancel token while threads are hammering on Poll() + source.RequestStop(stop_error); + // Tell threads to stop + terminate_flag = true; + for (auto& thread : threads) { + thread.join(); + } + + for (const auto& st : results) { + ASSERT_EQ(st, stop_error); + } +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/io_util.cc b/cpp/src/arrow/util/io_util.cc index cabbfb350bc..e3460a50b1f 100644 --- a/cpp/src/arrow/util/io_util.cc +++ b/cpp/src/arrow/util/io_util.cc @@ -304,6 +304,26 @@ class WinErrorDetail : public StatusDetail { }; #endif +const char kSignalDetailTypeId[] = "arrow::SignalDetail"; + +class SignalDetail : public StatusDetail { + public: + explicit SignalDetail(int signum) : signum_(signum) {} + + const char* type_id() const override { return kSignalDetailTypeId; } + + std::string ToString() const override { + std::stringstream ss; + ss << "received signal " << signum_; + return ss.str(); + } + + int signum() const { return signum_; } + + protected: + int signum_; +}; + } // namespace std::shared_ptr StatusDetailFromErrno(int errnum) { @@ -316,6 +336,10 @@ std::shared_ptr StatusDetailFromWinError(int errnum) { } #endif +std::shared_ptr StatusDetailFromSignal(int signum) { + return std::make_shared(signum); +} + int ErrnoFromStatus(const Status& status) { const auto detail = status.detail(); if (detail != nullptr && detail->type_id() == kErrnoDetailTypeId) { @@ -334,6 +358,14 @@ int WinErrorFromStatus(const Status& status) { return 0; } +int SignalFromStatus(const Status& status) { + const auto detail = status.detail(); + if (detail != nullptr && detail->type_id() == kSignalDetailTypeId) { + return checked_cast(*detail).signum(); + } + return 0; +} + // // PlatformFilename implementation // @@ -1608,6 +1640,39 @@ Result SetSignalHandler(int signum, const SignalHandler& handler) return Status::OK(); } +void ReinstateSignalHandler(int signum, SignalHandler::Callback handler) { +#if !ARROW_HAVE_SIGACTION + // Cannot report any errors from signal() (but there shouldn't be any) + signal(signum, handler); +#endif +} + +Status SendSignal(int signum) { + if (raise(signum) == 0) { + return Status::OK(); + } + if (errno == EINVAL) { + return Status::Invalid("Invalid signal number ", signum); + } + return IOErrorFromErrno(errno, "Failed to raise signal"); +} + +Status SendSignalToThread(int signum, uint64_t thread_id) { +#ifdef _WIN32 + return Status::NotImplemented("Cannot send signal to specific thread on Windows"); +#else + // Have to use a C-style cast because pthread_t can be a pointer *or* integer type + int r = pthread_kill((pthread_t)thread_id, signum); // NOLINT readability-casting + if (r == 0) { + return Status::OK(); + } + if (r == EINVAL) { + return Status::Invalid("Invalid signal number ", signum); + } + return IOErrorFromErrno(r, "Failed to raise signal"); +#endif +} + namespace { int64_t GetPid() { diff --git a/cpp/src/arrow/util/io_util.h b/cpp/src/arrow/util/io_util.h index 541041a6acd..38bcdd4b41f 100644 --- a/cpp/src/arrow/util/io_util.h +++ b/cpp/src/arrow/util/io_util.h @@ -270,6 +270,8 @@ std::shared_ptr StatusDetailFromErrno(int errnum); ARROW_EXPORT std::shared_ptr StatusDetailFromWinError(int errnum); #endif +ARROW_EXPORT +std::shared_ptr StatusDetailFromSignal(int signum); template Status StatusFromErrno(int errnum, StatusCode code, Args&&... args) { @@ -295,6 +297,17 @@ Status IOErrorFromWinError(int errnum, Args&&... args) { } #endif +template +Status StatusFromSignal(int signum, StatusCode code, Args&&... args) { + return Status::FromDetailAndArgs(code, StatusDetailFromSignal(signum), + std::forward(args)...); +} + +template +Status CancelledFromSignal(int signum, Args&&... args) { + return StatusFromSignal(signum, StatusCode::Cancelled, std::forward(args)...); +} + ARROW_EXPORT int ErrnoFromStatus(const Status&); @@ -302,6 +315,9 @@ int ErrnoFromStatus(const Status&); ARROW_EXPORT int WinErrorFromStatus(const Status&); +ARROW_EXPORT +int SignalFromStatus(const Status&); + class ARROW_EXPORT TemporaryDir { public: ~TemporaryDir(); @@ -354,6 +370,26 @@ Result GetSignalHandler(int signum); ARROW_EXPORT Result SetSignalHandler(int signum, const SignalHandler& handler); +/// \brief Reinstate the signal handler +/// +/// For use in signal handlers. This is needed on platforms without sigaction() +/// such as Windows, as the default signal handler is restored there as +/// soon as a signal is raised. +ARROW_EXPORT +void ReinstateSignalHandler(int signum, SignalHandler::Callback handler); + +/// \brief Send a signal to the current process +/// +/// The thread which will receive the signal is unspecified. +ARROW_EXPORT +Status SendSignal(int signum); + +/// \brief Send a signal to the given thread +/// +/// This function isn't supported on Windows. +ARROW_EXPORT +Status SendSignalToThread(int signum, uint64_t thread_id); + /// \brief Get an unpredictable random seed /// /// This function may be slightly costly, so should only be used to initialize diff --git a/cpp/src/arrow/util/io_util_test.cc b/cpp/src/arrow/util/io_util_test.cc index d84a2b76e39..a423ecd0152 100644 --- a/cpp/src/arrow/util/io_util_test.cc +++ b/cpp/src/arrow/util/io_util_test.cc @@ -17,10 +17,18 @@ // under the License. #include +#include #include #include +#include #include +#include + +#ifndef _WIN32 +#include +#endif + #include #include "arrow/testing/gtest_util.h" @@ -59,6 +67,30 @@ TEST(ErrnoFromStatus, Basics) { ASSERT_EQ(ErrnoFromStatus(st), EPERM); st = IOErrorFromErrno(6789, "foo"); ASSERT_EQ(ErrnoFromStatus(st), 6789); + + st = CancelledFromSignal(SIGINT, "foo"); + ASSERT_EQ(ErrnoFromStatus(st), 0); +} + +TEST(SignalFromStatus, Basics) { + Status st; + st = Status::OK(); + ASSERT_EQ(SignalFromStatus(st), 0); + st = Status::KeyError("foo"); + ASSERT_EQ(SignalFromStatus(st), 0); + st = Status::Cancelled("foo"); + ASSERT_EQ(SignalFromStatus(st), 0); + st = StatusFromSignal(SIGINT, StatusCode::KeyError, "foo"); + ASSERT_EQ(SignalFromStatus(st), SIGINT); + ASSERT_EQ(st.ToString(), + "Key error: foo. Detail: received signal " + std::to_string(SIGINT)); + st = CancelledFromSignal(SIGINT, "bar"); + ASSERT_EQ(SignalFromStatus(st), SIGINT); + ASSERT_EQ(st.ToString(), + "Cancelled: bar. Detail: received signal " + std::to_string(SIGINT)); + + st = IOErrorFromErrno(EINVAL, "foo"); + ASSERT_EQ(SignalFromStatus(st), 0); } TEST(GetPageSize, Basics) { @@ -623,5 +655,46 @@ TEST(FileUtils, LongPaths) { } #endif +static std::atomic signal_received; + +static void handle_signal(int signum) { + ReinstateSignalHandler(signum, &handle_signal); + signal_received.store(signum); +} + +TEST(SendSignal, Generic) { + signal_received.store(0); + SignalHandlerGuard guard(SIGINT, &handle_signal); + + ASSERT_EQ(signal_received.load(), 0); + ASSERT_OK(SendSignal(SIGINT)); + BusyWait(1.0, [&]() { return signal_received.load() != 0; }); + ASSERT_EQ(signal_received.load(), SIGINT); + + // Re-try (exercise ReinstateSignalHandler) + signal_received.store(0); + ASSERT_OK(SendSignal(SIGINT)); + BusyWait(1.0, [&]() { return signal_received.load() != 0; }); + ASSERT_EQ(signal_received.load(), SIGINT); +} + +TEST(SendSignal, ToThread) { +#ifdef _WIN32 + uint64_t dummy_thread_id = 42; + ASSERT_RAISES(NotImplemented, SendSignalToThread(SIGINT, dummy_thread_id)); +#else + // Have to use a C-style cast because pthread_t can be a pointer *or* integer type + uint64_t thread_id = (uint64_t)(pthread_self()); // NOLINT readability-casting + signal_received.store(0); + SignalHandlerGuard guard(SIGINT, &handle_signal); + + ASSERT_EQ(signal_received.load(), 0); + ASSERT_OK(SendSignalToThread(SIGINT, thread_id)); + BusyWait(1.0, [&]() { return signal_received.load() != 0; }); + + ASSERT_EQ(signal_received.load(), SIGINT); +#endif +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/task_group.cc b/cpp/src/arrow/util/task_group.cc index a7b55921d32..7e8ab64b703 100644 --- a/cpp/src/arrow/util/task_group.cc +++ b/cpp/src/arrow/util/task_group.cc @@ -20,7 +20,6 @@ #include #include #include -#include #include #include @@ -31,15 +30,23 @@ namespace arrow { namespace internal { +namespace { + //////////////////////////////////////////////////////////////////////// // Serial TaskGroup implementation class SerialTaskGroup : public TaskGroup { public: - void AppendReal(std::function task) override { + explicit SerialTaskGroup(StopToken stop_token) : stop_token_(std::move(stop_token)) {} + + void AppendReal(FnOnce task) override { DCHECK(!finished_); + if (stop_token_.IsStopRequested()) { + status_ &= stop_token_.Poll(); + return; + } if (status_.ok()) { - status_ &= task(); + status_ &= std::move(task)(); } } @@ -58,6 +65,7 @@ class SerialTaskGroup : public TaskGroup { int parallelism() override { return 1; } + StopToken stop_token_; Status status_; bool finished_ = false; }; @@ -67,8 +75,11 @@ class SerialTaskGroup : public TaskGroup { class ThreadedTaskGroup : public TaskGroup { public: - explicit ThreadedTaskGroup(Executor* executor) - : executor_(executor), nremaining_(0), ok_(true) {} + ThreadedTaskGroup(Executor* executor, StopToken stop_token) + : executor_(executor), + stop_token_(std::move(stop_token)), + nremaining_(0), + ok_(true) {} ~ThreadedTaskGroup() override { // Make sure all pending tasks are finished, so that dangling references @@ -76,25 +87,42 @@ class ThreadedTaskGroup : public TaskGroup { ARROW_UNUSED(Finish()); } - void AppendReal(std::function task) override { + void AppendReal(FnOnce task) override { DCHECK(!finished_); + if (stop_token_.IsStopRequested()) { + UpdateStatus(stop_token_.Poll()); + return; + } + // The hot path is unlocked thanks to atomics // Only if an error occurs is the lock taken if (ok_.load(std::memory_order_acquire)) { nremaining_.fetch_add(1, std::memory_order_acquire); auto self = checked_pointer_cast(shared_from_this()); - Status st = executor_->Spawn(std::bind( - [](const std::shared_ptr& self, - const std::function& task) { - if (self->ok_.load(std::memory_order_acquire)) { + + struct Callable { + void operator()() { + if (self_->ok_.load(std::memory_order_acquire)) { + Status st; + if (stop_token_.IsStopRequested()) { + st = stop_token_.Poll(); + } else { // XXX what about exceptions? - Status st = task(); - self->UpdateStatus(std::move(st)); + st = std::move(task_)(); } - self->OneTaskDone(); - }, - std::move(self), std::move(task))); + self_->UpdateStatus(std::move(st)); + } + self_->OneTaskDone(); + } + + std::shared_ptr self_; + FnOnce task_; + StopToken stop_token_; + }; + + Status st = + executor_->Spawn(Callable{std::move(self), std::move(task), stop_token_}); UpdateStatus(std::move(st)); } } @@ -169,6 +197,7 @@ class ThreadedTaskGroup : public TaskGroup { // These members are usable unlocked Executor* executor_; + StopToken stop_token_; std::atomic nremaining_; std::atomic ok_; @@ -180,12 +209,15 @@ class ThreadedTaskGroup : public TaskGroup { util::optional> completion_future_; }; -std::shared_ptr TaskGroup::MakeSerial() { - return std::shared_ptr(new SerialTaskGroup); +} // namespace + +std::shared_ptr TaskGroup::MakeSerial(StopToken stop_token) { + return std::shared_ptr(new SerialTaskGroup{stop_token}); } -std::shared_ptr TaskGroup::MakeThreaded(Executor* thread_pool) { - return std::shared_ptr(new ThreadedTaskGroup(thread_pool)); +std::shared_ptr TaskGroup::MakeThreaded(Executor* thread_pool, + StopToken stop_token) { + return std::shared_ptr(new ThreadedTaskGroup{thread_pool, stop_token}); } } // namespace internal diff --git a/cpp/src/arrow/util/task_group.h b/cpp/src/arrow/util/task_group.h index a6df43f1131..7a96bada013 100644 --- a/cpp/src/arrow/util/task_group.h +++ b/cpp/src/arrow/util/task_group.h @@ -17,11 +17,12 @@ #pragma once -#include #include #include #include "arrow/status.h" +#include "arrow/util/cancel.h" +#include "arrow/util/functional.h" #include "arrow/util/macros.h" #include "arrow/util/type_fwd.h" #include "arrow/util/visibility.h" @@ -87,8 +88,9 @@ class ARROW_EXPORT TaskGroup : public std::enable_shared_from_this { /// This is only a hint, useful for testing or debugging. virtual int parallelism() = 0; - static std::shared_ptr MakeSerial(); - static std::shared_ptr MakeThreaded(internal::Executor*); + static std::shared_ptr MakeSerial(StopToken = StopToken::Unstoppable()); + static std::shared_ptr MakeThreaded(internal::Executor*, + StopToken = StopToken::Unstoppable()); virtual ~TaskGroup() = default; @@ -96,7 +98,7 @@ class ARROW_EXPORT TaskGroup : public std::enable_shared_from_this { TaskGroup() = default; ARROW_DISALLOW_COPY_AND_ASSIGN(TaskGroup); - virtual void AppendReal(std::function task) = 0; + virtual void AppendReal(FnOnce task) = 0; }; } // namespace internal diff --git a/cpp/src/arrow/util/task_group_test.cc b/cpp/src/arrow/util/task_group_test.cc index 38f4b211820..f4e7974ee45 100644 --- a/cpp/src/arrow/util/task_group_test.cc +++ b/cpp/src/arrow/util/task_group_test.cc @@ -114,6 +114,49 @@ void TestTaskGroupErrors(std::shared_ptr task_group) { ASSERT_RAISES(Invalid, task_group->Finish()); } +void TestTaskGroupCancel(std::shared_ptr task_group, StopSource* stop_source) { + const int NSUCCESSES = 2; + const int NCANCELS = 20; + + std::atomic count(0); + + auto task_group_was_ok = false; + task_group->Append([&]() -> Status { + for (int i = 0; i < NSUCCESSES; ++i) { + task_group->Append([&]() { + count++; + return Status::OK(); + }); + } + task_group_was_ok = task_group->ok(); + for (int i = 0; i < NCANCELS; ++i) { + task_group->Append([&]() { + SleepFor(1e-2); + stop_source->RequestStop(); + count++; + return Status::OK(); + }); + } + + return Status::OK(); + }); + + // Cancellation is propagated + ASSERT_RAISES(Cancelled, task_group->Finish()); + ASSERT_TRUE(task_group_was_ok); + ASSERT_FALSE(task_group->ok()); + if (task_group->parallelism() == 1) { + // Serial: exactly three successes + ASSERT_EQ(count.load(), NSUCCESSES + 1); + } else { + // Parallel: at least three successes + ASSERT_GE(count.load(), NSUCCESSES + 1); + ASSERT_LE(count.load(), NSUCCESSES * task_group->parallelism()); + } + // Finish() is idempotent + ASSERT_RAISES(Cancelled, task_group->Finish()); +} + class CopyCountingTask { public: explicit CopyCountingTask(std::shared_ptr target) @@ -310,6 +353,11 @@ TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); } +TEST(SerialTaskGroup, Cancel) { + StopSource stop_source; + TestTaskGroupCancel(TaskGroup::MakeSerial(stop_source.token()), &stop_source); +} + TEST(SerialTaskGroup, TasksSpawnTasks) { TestTasksSpawnTasks(TaskGroup::MakeSerial()); } TEST(SerialTaskGroup, NoCopyTask) { TestNoCopyTask(TaskGroup::MakeSerial()); } @@ -336,6 +384,15 @@ TEST(ThreadedTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeThreaded(thread_pool.get())); } +TEST(ThreadedTaskGroup, Cancel) { + std::shared_ptr thread_pool; + ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4)); + + StopSource stop_source; + TestTaskGroupCancel(TaskGroup::MakeThreaded(thread_pool.get(), stop_source.token()), + &stop_source); +} + TEST(ThreadedTaskGroup, TasksSpawnTasks) { auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool()); TestTasksSpawnTasks(task_group); diff --git a/cpp/src/arrow/util/thread_pool.cc b/cpp/src/arrow/util/thread_pool.cc index 33eb937ba43..4c644b39cdb 100644 --- a/cpp/src/arrow/util/thread_pool.cc +++ b/cpp/src/arrow/util/thread_pool.cc @@ -34,6 +34,16 @@ namespace internal { Executor::~Executor() = default; +namespace { + +struct Task { + FnOnce callable; + StopToken stop_token; + Executor::StopCallback stop_callback; +}; + +} // namespace + struct ThreadPool::State { State() = default; @@ -47,7 +57,7 @@ struct ThreadPool::State { std::list workers_; // Trashcan for finished threads std::vector finished_workers_; - std::deque> pending_tasks_; + std::deque pending_tasks_; // Desired number of threads int desired_capacity_ = 0; @@ -91,12 +101,20 @@ static void WorkerLoop(std::shared_ptr state, --state->ready_count_; DCHECK_GE(state->ready_count_, 0); { - FnOnce task = std::move(state->pending_tasks_.front()); + Task task = std::move(state->pending_tasks_.front()); state->pending_tasks_.pop_front(); + StopToken* stop_token = &task.stop_token; lock.unlock(); - std::move(task)(); + if (!stop_token->IsStopRequested()) { + std::move(task.callable)(); + } else { + if (task.stop_callback) { + std::move(task.stop_callback)(stop_token->Poll()); + } + } + ARROW_UNUSED(std::move(task)); // release resources before waiting for lock + lock.lock(); } - lock.lock(); ++state->ready_count_; } // Now either the queue is empty *or* a quick shutdown was requested @@ -242,7 +260,8 @@ void ThreadPool::LaunchWorkersUnlocked(int threads) { } } -Status ThreadPool::SpawnReal(TaskHints hints, FnOnce task) { +Status ThreadPool::SpawnReal(TaskHints hints, FnOnce task, StopToken stop_token, + StopCallback&& stop_callback) { { ProtectAgainstFork(); std::lock_guard lock(state_->mutex_); @@ -256,7 +275,8 @@ Status ThreadPool::SpawnReal(TaskHints hints, FnOnce task) { // spawn one more thread. LaunchWorkersUnlocked(/*threads=*/1); } - state_->pending_tasks_.push_back(std::move(task)); + state_->pending_tasks_.push_back( + {std::move(task), std::move(stop_token), std::move(stop_callback)}); } state_->cv_.notify_one(); return Status::OK(); diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index 5db3a9a4722..6334b010c21 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -28,6 +28,8 @@ #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/util/cancel.h" +#include "arrow/util/functional.h" #include "arrow/util/future.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" @@ -73,17 +75,22 @@ struct TaskHints { class ARROW_EXPORT Executor { public: + using StopCallback = internal::FnOnce; + virtual ~Executor(); // Spawn a fire-and-forget task. template - Status Spawn(Function&& func) { - return SpawnReal(TaskHints{}, std::forward(func)); + Status Spawn(Function&& func, StopToken stop_token = StopToken::Unstoppable()) { + return SpawnReal(TaskHints{}, std::forward(func), std::move(stop_token), + StopCallback{}); } template - Status Spawn(TaskHints hints, Function&& func) { - return SpawnReal(hints, std::forward(func)); + Status Spawn(TaskHints hints, Function&& func, + StopToken stop_token = StopToken::Unstoppable()) { + return SpawnReal(hints, std::forward(func), std::move(stop_token), + StopCallback{}); } // Transfers a future to this executor. Any continuations added to the @@ -114,21 +121,51 @@ class ARROW_EXPORT Executor { template > - Result Submit(TaskHints hints, Function&& func, Args&&... args) { - auto future = FutureType::Make(); + Result Submit(TaskHints hints, StopToken stop_token, Function&& func, + Args&&... args) { + using ValueType = typename FutureType::ValueType; + auto future = FutureType::Make(); auto task = std::bind(::arrow::detail::ContinueFuture{}, future, std::forward(func), std::forward(args)...); - ARROW_RETURN_NOT_OK(SpawnReal(hints, std::move(task))); + struct { + WeakFuture weak_fut; + + void operator()(const Status& st) { + auto fut = weak_fut.get(); + if (fut.is_valid()) { + fut.MarkFinished(st); + } + } + } stop_callback{WeakFuture(future)}; + ARROW_RETURN_NOT_OK(SpawnReal(hints, std::move(task), std::move(stop_token), + std::move(stop_callback))); return future; } + template > + Result Submit(StopToken stop_token, Function&& func, Args&&... args) { + return Submit(TaskHints{}, stop_token, std::forward(func), + std::forward(args)...); + } + + template > + Result Submit(TaskHints hints, Function&& func, Args&&... args) { + return Submit(std::move(hints), StopToken::Unstoppable(), + std::forward(func), std::forward(args)...); + } + template > Result Submit(Function&& func, Args&&... args) { - return Submit(TaskHints{}, std::forward(func), std::forward(args)...); + return Submit(TaskHints{}, StopToken::Unstoppable(), std::forward(func), + std::forward(args)...); } // Return the level of parallelism (the number of tasks that may be executed @@ -141,7 +178,8 @@ class ARROW_EXPORT Executor { Executor() = default; // Subclassing API - virtual Status SpawnReal(TaskHints hints, FnOnce task) = 0; + virtual Status SpawnReal(TaskHints hints, FnOnce task, StopToken, + StopCallback&&) = 0; }; // An Executor implementation spawning tasks in FIFO manner on a fixed-size @@ -192,7 +230,8 @@ class ARROW_EXPORT ThreadPool : public Executor { ThreadPool(); - Status SpawnReal(TaskHints hints, FnOnce task) override; + Status SpawnReal(TaskHints hints, FnOnce task, StopToken, + StopCallback&&) override; // Collect finished worker threads, making sure the OS threads have exited void CollectFinishedWorkersUnlocked(); diff --git a/cpp/src/arrow/util/thread_pool_test.cc b/cpp/src/arrow/util/thread_pool_test.cc index fef0acc6395..a2add4cd469 100644 --- a/cpp/src/arrow/util/thread_pool_test.cc +++ b/cpp/src/arrow/util/thread_pool_test.cc @@ -40,23 +40,20 @@ namespace arrow { namespace internal { -static void busy_wait(double seconds, std::function predicate) { - const double period = 0.001; - for (int i = 0; !predicate() && i * period < seconds; ++i) { - SleepFor(period); - } -} - template static void task_add(T x, T y, T* out) { *out = x + y; } template -static void task_slow_add(double seconds, T x, T y, T* out) { - SleepFor(seconds); - *out = x + y; -} +struct task_slow_add { + void operator()(T x, T y, T* out) { + SleepFor(seconds_); + *out = x + y; + } + + const double seconds_; +}; typedef std::function AddTaskFunc; @@ -80,13 +77,14 @@ static T inplace_add(T& x, T y) { class AddTester { public: - explicit AddTester(int nadds) : nadds(nadds), xs(nadds), ys(nadds), outs(nadds, -1) { + explicit AddTester(int nadds, StopToken stop_token = StopToken::Unstoppable()) + : nadds_(nadds), stop_token_(stop_token), xs_(nadds), ys_(nadds), outs_(nadds, -1) { int x = 0, y = 0; - std::generate(xs.begin(), xs.end(), [&] { + std::generate(xs_.begin(), xs_.end(), [&] { ++x; return x; }); - std::generate(ys.begin(), ys.end(), [&] { + std::generate(ys_.begin(), ys_.end(), [&] { y += 10; return y; }); @@ -95,20 +93,20 @@ class AddTester { AddTester(AddTester&&) = default; void SpawnTasks(ThreadPool* pool, AddTaskFunc add_func) { - for (int i = 0; i < nadds; ++i) { - ASSERT_OK(pool->Spawn([=] { add_func(xs[i], ys[i], &outs[i]); })); + for (int i = 0; i < nadds_; ++i) { + ASSERT_OK(pool->Spawn([=] { add_func(xs_[i], ys_[i], &outs_[i]); }, stop_token_)); } } void CheckResults() { - for (int i = 0; i < nadds; ++i) { - ASSERT_EQ(outs[i], (i + 1) * 11); + for (int i = 0; i < nadds_; ++i) { + ASSERT_EQ(outs_[i], (i + 1) * 11); } } void CheckNotAllComputed() { - for (int i = 0; i < nadds; ++i) { - if (outs[i] == -1) { + for (int i = 0; i < nadds_; ++i) { + if (outs_[i] == -1) { return; } } @@ -118,10 +116,11 @@ class AddTester { private: ARROW_DISALLOW_COPY_AND_ASSIGN(AddTester); - int nadds; - std::vector xs; - std::vector ys; - std::vector outs; + int nadds_; + StopToken stop_token_; + std::vector xs_; + std::vector ys_; + std::vector outs_; }; class TestThreadPool : public ::testing::Test { @@ -137,32 +136,72 @@ class TestThreadPool : public ::testing::Test { return *ThreadPool::Make(threads); } - void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func) { - AddTester add_tester(nadds); + void DoSpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func, + StopToken stop_token = StopToken::Unstoppable(), + StopSource* stop_source = nullptr) { + AddTester add_tester(nadds, stop_token); add_tester.SpawnTasks(pool, add_func); + if (stop_source) { + stop_source->RequestStop(); + } ASSERT_OK(pool->Shutdown()); - add_tester.CheckResults(); + if (stop_source) { + add_tester.CheckNotAllComputed(); + } else { + add_tester.CheckResults(); + } } - void SpawnAddsThreaded(ThreadPool* pool, int nthreads, int nadds, - AddTaskFunc add_func) { + void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func, + StopToken stop_token = StopToken::Unstoppable()) { + DoSpawnAdds(pool, nadds, std::move(add_func), std::move(stop_token)); + } + + void SpawnAddsAndCancel(ThreadPool* pool, int nadds, AddTaskFunc add_func, + StopSource* stop_source) { + DoSpawnAdds(pool, nadds, std::move(add_func), stop_source->token(), stop_source); + } + + void DoSpawnAddsThreaded(ThreadPool* pool, int nthreads, int nadds, + AddTaskFunc add_func, + StopToken stop_token = StopToken::Unstoppable(), + StopSource* stop_source = nullptr) { // Same as SpawnAdds, but do the task spawning from multiple threads std::vector add_testers; std::vector threads; for (int i = 0; i < nthreads; ++i) { - add_testers.emplace_back(nadds); + add_testers.emplace_back(nadds, stop_token); } for (auto& add_tester : add_testers) { threads.emplace_back([&] { add_tester.SpawnTasks(pool, add_func); }); } + if (stop_source) { + stop_source->RequestStop(); + } for (auto& thread : threads) { thread.join(); } ASSERT_OK(pool->Shutdown()); for (auto& add_tester : add_testers) { - add_tester.CheckResults(); + if (stop_source) { + add_tester.CheckNotAllComputed(); + } else { + add_tester.CheckResults(); + } } } + + void SpawnAddsThreaded(ThreadPool* pool, int nthreads, int nadds, AddTaskFunc add_func, + StopToken stop_token = StopToken::Unstoppable()) { + DoSpawnAddsThreaded(pool, nthreads, nadds, std::move(add_func), + std::move(stop_token)); + } + + void SpawnAddsThreadedAndCancel(ThreadPool* pool, int nthreads, int nadds, + AddTaskFunc add_func, StopSource* stop_source) { + DoSpawnAddsThreaded(pool, nthreads, nadds, std::move(add_func), stop_source->token(), + stop_source); + } }; TEST_F(TestThreadPool, ConstructDestruct) { @@ -192,32 +231,49 @@ TEST_F(TestThreadPool, StressSpawnThreaded) { TEST_F(TestThreadPool, SpawnSlow) { // This checks that Shutdown() waits for all tasks to finish auto pool = this->MakeThreadPool(2); - SpawnAdds(pool.get(), 7, [](int x, int y, int* out) { - return task_slow_add(0.02 /* seconds */, x, y, out); - }); + SpawnAdds(pool.get(), 7, task_slow_add{/*seconds=*/0.02}); } TEST_F(TestThreadPool, StressSpawnSlow) { auto pool = this->MakeThreadPool(30); - SpawnAdds(pool.get(), 1000, [](int x, int y, int* out) { - return task_slow_add(0.002 /* seconds */, x, y, out); - }); + SpawnAdds(pool.get(), 1000, task_slow_add{/*seconds=*/0.002}); } TEST_F(TestThreadPool, StressSpawnSlowThreaded) { auto pool = this->MakeThreadPool(30); - SpawnAddsThreaded(pool.get(), 20, 100, [](int x, int y, int* out) { - return task_slow_add(0.002 /* seconds */, x, y, out); - }); + SpawnAddsThreaded(pool.get(), 20, 100, task_slow_add{/*seconds=*/0.002}); +} + +TEST_F(TestThreadPool, SpawnWithStopToken) { + StopSource stop_source; + auto pool = this->MakeThreadPool(3); + SpawnAdds(pool.get(), 7, task_add, stop_source.token()); +} + +TEST_F(TestThreadPool, StressSpawnThreadedWithStopToken) { + StopSource stop_source; + auto pool = this->MakeThreadPool(30); + SpawnAddsThreaded(pool.get(), 20, 100, task_add, stop_source.token()); +} + +TEST_F(TestThreadPool, SpawnWithStopTokenCancelled) { + StopSource stop_source; + auto pool = this->MakeThreadPool(3); + SpawnAddsAndCancel(pool.get(), 100, task_slow_add{/*seconds=*/0.02}, &stop_source); +} + +TEST_F(TestThreadPool, StressSpawnThreadedWithStopTokenCancelled) { + StopSource stop_source; + auto pool = this->MakeThreadPool(30); + SpawnAddsThreadedAndCancel(pool.get(), 20, 100, task_slow_add{/*seconds=*/0.02}, + &stop_source); } TEST_F(TestThreadPool, QuickShutdown) { AddTester add_tester(100); { auto pool = this->MakeThreadPool(3); - add_tester.SpawnTasks(pool.get(), [](int x, int y, int* out) { - return task_slow_add(0.02 /* seconds */, x, y, out); - }); + add_tester.SpawnTasks(pool.get(), task_slow_add{/*seconds=*/0.02}); ASSERT_OK(pool->Shutdown(false /* wait */)); add_tester.CheckNotAllComputed(); } @@ -235,12 +291,12 @@ TEST_F(TestThreadPool, SetCapacity) { ASSERT_EQ(pool->GetCapacity(), 3); ASSERT_EQ(pool->GetActualCapacity(), 0); - ASSERT_OK(pool->Spawn(std::bind(SleepFor, 0.1 /* seconds */))); + ASSERT_OK(pool->Spawn(std::bind(SleepFor, /*seconds=*/0.1))); ASSERT_EQ(pool->GetActualCapacity(), 1); // Spawn more tasks than the pool capacity for (int i = 0; i < 6; ++i) { - ASSERT_OK(pool->Spawn(std::bind(SleepFor, 0.1 /* seconds */))); + ASSERT_OK(pool->Spawn(std::bind(SleepFor, /*seconds=*/0.1))); } ASSERT_EQ(pool->GetActualCapacity(), 3); // maxxed out @@ -254,20 +310,20 @@ TEST_F(TestThreadPool, SetCapacity) { ASSERT_OK(pool->SetCapacity(2)); ASSERT_EQ(pool->GetCapacity(), 2); // Wait for workers to wake up and secede - busy_wait(0.5, [&] { return pool->GetActualCapacity() == 2; }); + BusyWait(0.5, [&] { return pool->GetActualCapacity() == 2; }); ASSERT_EQ(pool->GetActualCapacity(), 2); // Downsize while tasks are pending ASSERT_OK(pool->SetCapacity(5)); ASSERT_EQ(pool->GetCapacity(), 5); for (int i = 0; i < 10; ++i) { - ASSERT_OK(pool->Spawn(std::bind(SleepFor, 0.1 /* seconds */))); + ASSERT_OK(pool->Spawn(std::bind(SleepFor, /*seconds=*/0.1))); } ASSERT_EQ(pool->GetActualCapacity(), 5); ASSERT_OK(pool->SetCapacity(2)); ASSERT_EQ(pool->GetCapacity(), 2); - busy_wait(0.5, [&] { return pool->GetActualCapacity() == 2; }); + BusyWait(0.5, [&] { return pool->GetActualCapacity() == 2; }); ASSERT_EQ(pool->GetActualCapacity(), 2); // Ensure nothing got stuck @@ -289,7 +345,7 @@ TEST_F(TestThreadPool, Submit) { ASSERT_OK_AND_EQ("foobar", fut.result()); } { - ASSERT_OK_AND_ASSIGN(auto fut, pool->Submit(slow_add, 0.01 /* seconds */, 4, 5)); + ASSERT_OK_AND_ASSIGN(auto fut, pool->Submit(slow_add, /*seconds=*/0.01, 4, 5)); ASSERT_OK_AND_EQ(9, fut.result()); } { @@ -307,6 +363,48 @@ TEST_F(TestThreadPool, Submit) { } } +TEST_F(TestThreadPool, SubmitWithStopToken) { + auto pool = this->MakeThreadPool(3); + { + StopSource stop_source; + ASSERT_OK_AND_ASSIGN(Future fut, + pool->Submit(stop_source.token(), add, 4, 5)); + Result res = fut.result(); + ASSERT_OK_AND_EQ(9, res); + } +} + +TEST_F(TestThreadPool, SubmitWithStopTokenCancelled) { + auto pool = this->MakeThreadPool(3); + { + const int n_futures = 100; + StopSource stop_source; + StopToken stop_token = stop_source.token(); + std::vector> futures; + for (int i = 0; i < n_futures; ++i) { + ASSERT_OK_AND_ASSIGN( + auto fut, pool->Submit(stop_token, slow_add, 0.01 /*seconds*/, i, 1)); + futures.push_back(std::move(fut)); + } + SleepFor(0.05); // Let some work finish + stop_source.RequestStop(); + int n_success = 0; + int n_cancelled = 0; + for (int i = 0; i < n_futures; ++i) { + Result res = futures[i].result(); + if (res.ok()) { + ASSERT_EQ(i + 1, *res); + ++n_success; + } else { + ASSERT_RAISES(Cancelled, res); + ++n_cancelled; + } + } + ASSERT_GT(n_success, 0); + ASSERT_GT(n_cancelled, 0); + } +} + // Test fork safety on Unix #if !(defined(_WIN32) || defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER) || \ diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 995bed9f195..adfd69c18b3 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -66,7 +66,8 @@ def parse_git(root, **kwargs): from pyarrow.lib import (BuildInfo, RuntimeInfo, VersionInfo, cpp_build_info, cpp_version, cpp_version_info, - runtime_info, cpu_count, set_cpu_count) + runtime_info, cpu_count, set_cpu_count, + enable_signal_handlers) def show_versions(): @@ -177,7 +178,8 @@ def show_versions(): concat_arrays, concat_tables) # Exceptions -from pyarrow.lib import (ArrowCapacityError, +from pyarrow.lib import (ArrowCancelled, + ArrowCapacityError, ArrowException, ArrowKeyError, ArrowIndexError, diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index 2f09a82c20e..cce44d1d8c8 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -34,8 +34,8 @@ from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema, pyarrow_unwrap_batch, pyarrow_unwrap_table, pyarrow_wrap_schema, pyarrow_wrap_table, pyarrow_wrap_data_type, pyarrow_unwrap_data_type, - Table, RecordBatch) -from pyarrow.lib import frombytes, tobytes + Table, RecordBatch, StopToken) +from pyarrow.lib import frombytes, tobytes, SignalStopHandler from pyarrow.util import _stringify_path @@ -649,18 +649,27 @@ cdef class CSVStreamingReader(RecordBatchReader): "use pyarrow.csv.open_csv() instead." .format(self.__class__.__name__)) + # Note about cancellation: we cannot create a SignalStopHandler + # by default here, as several CSVStreamingReader instances may be + # created (including by the same thread). Handling cancellation + # would require having the user pass the SignalStopHandler. + # (in addition to solving ARROW-11853) + cdef _open(self, shared_ptr[CInputStream] stream, CCSVReadOptions c_read_options, CCSVParseOptions c_parse_options, CCSVConvertOptions c_convert_options, - CMemoryPool* c_memory_pool): + MemoryPool memory_pool): cdef: shared_ptr[CSchema] c_schema + CIOContext io_context + + io_context = CIOContext(maybe_unbox_memory_pool(memory_pool)) with nogil: self.reader = GetResultValue( CCSVStreamingReader.Make( - c_memory_pool, stream, + io_context, stream, move(c_read_options), move(c_parse_options), move(c_convert_options))) c_schema = self.reader.get().schema() @@ -701,6 +710,7 @@ def read_csv(input_file, read_options=None, parse_options=None, CCSVReadOptions c_read_options CCSVParseOptions c_parse_options CCSVConvertOptions c_convert_options + CIOContext io_context shared_ptr[CCSVReader] reader shared_ptr[CTable] table @@ -709,12 +719,16 @@ def read_csv(input_file, read_options=None, parse_options=None, _get_parse_options(parse_options, &c_parse_options) _get_convert_options(convert_options, &c_convert_options) - reader = GetResultValue(CCSVReader.Make( - CIOContext(maybe_unbox_memory_pool(memory_pool)), - stream, c_read_options, c_parse_options, c_convert_options)) + with SignalStopHandler() as stop_handler: + io_context = CIOContext( + maybe_unbox_memory_pool(memory_pool), + ( stop_handler.stop_token).stop_token) + reader = GetResultValue(CCSVReader.Make( + io_context, stream, + c_read_options, c_parse_options, c_convert_options)) - with nogil: - table = GetResultValue(reader.get().Read()) + with nogil: + table = GetResultValue(reader.get().Read()) return pyarrow_wrap_table(table) @@ -762,8 +776,7 @@ def open_csv(input_file, read_options=None, parse_options=None, reader = CSVStreamingReader.__new__(CSVStreamingReader) reader._open(stream, move(c_read_options), move(c_parse_options), - move(c_convert_options), - maybe_unbox_memory_pool(memory_pool)) + move(c_convert_options), memory_pool) return reader diff --git a/python/pyarrow/error.pxi b/python/pyarrow/error.pxi index a7cb17c09e4..f9e45f238df 100644 --- a/python/pyarrow/error.pxi +++ b/python/pyarrow/error.pxi @@ -15,9 +15,16 @@ # specific language governing permissions and limitations # under the License. +from cpython.exc cimport PyErr_CheckSignals, PyErr_SetInterrupt + from pyarrow.includes.libarrow cimport CStatus, IsPyError, RestorePyError from pyarrow.includes.common cimport c_string +from contextlib import contextmanager +import os +import signal +import threading + class ArrowException(Exception): pass @@ -57,6 +64,12 @@ class ArrowSerializationError(ArrowException): pass +class ArrowCancelled(ArrowException): + def __init__(self, message, signum=None): + super().__init__(message) + self.signum = signum + + # Compatibility alias ArrowIOError = IOError @@ -111,6 +124,12 @@ cdef int check_status(const CStatus& status) nogil except -1: raise ArrowIndexError(message) elif status.IsSerializationError(): raise ArrowSerializationError(message) + elif status.IsCancelled(): + signum = SignalFromStatus(status) + if signum > 0: + raise ArrowCancelled(message, signum) + else: + raise ArrowCancelled(message) else: message = frombytes(status.ToString(), safe=True) raise ArrowException(message) @@ -120,3 +139,93 @@ cdef int check_status(const CStatus& status) nogil except -1: cdef api int pyarrow_internal_check_status(const CStatus& status) \ nogil except -1: return check_status(status) + + +cdef class StopToken: + cdef void init(self, CStopToken stop_token): + self.stop_token = move(stop_token) + + +cdef c_bool signal_handlers_enabled = True + + +def enable_signal_handlers(c_bool enable): + """ + Enable or disable interruption of long-running operations. + + By default, certain long running operations will detect user + interruptions, such as by pressing Ctrl-C. This detection relies + on setting a signal handler for the duration of the long-running + operation, and may therefore interfere with other frameworks or + libraries (such as an event loop). + + Parameters + ---------- + enable: bool + Whether to enable user interruption by setting a temporary + signal handler. + """ + global signal_handlers_enabled + signal_handlers_enabled = enable + + +# For internal use + +cdef class SignalStopHandler: + cdef: + StopToken _stop_token + vector[int] _signals + c_bool _enabled + + def __cinit__(self): + self._enabled = False + + tid = threading.current_thread().ident + if (signal_handlers_enabled and + threading.current_thread() is threading.main_thread()): + self._signals = [ + sig for sig in (signal.SIGINT, signal.SIGTERM) + if signal.getsignal(sig) not in (signal.SIG_DFL, + signal.SIG_IGN, None)] + + if not self._signals.empty(): + self._stop_token = StopToken() + self._stop_token.init(GetResultValue( + SetSignalStopSource()).token()) + self._enabled = True + + def __enter__(self): + if self._enabled: + check_status(RegisterCancellingSignalHandler(self._signals)) + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + if self._enabled: + UnregisterCancellingSignalHandler() + if isinstance(exc_value, ArrowCancelled): + if exc_value.signum: + # Re-emit the exact same signal. We restored the Python signal + # handler above, so it should receive it. + if os.name == 'nt': + SendSignal(exc_value.signum) + else: + SendSignalToThread(exc_value.signum, threading.get_ident()) + else: + # Simulate Python receiving a SIGINT + # (see https://bugs.python.org/issue43356 for why we can't + # simulate the exact signal number) + PyErr_SetInterrupt() + # Maximize chances of the Python signal handler being executed now. + # Otherwise a potential KeyboardInterrupt might be missed by an + # immediately enclosing try/except block. + PyErr_CheckSignals() + # ArrowCancelled will be re-raised if PyErr_CheckSignals() + # returned successfully. + + def __dealloc__(self): + if self._enabled: + ResetSignalStopSource() + + @property + def stop_token(self): + return self._stop_token diff --git a/python/pyarrow/includes/common.pxd b/python/pyarrow/includes/common.pxd index 3ac7ac594e1..3f67a3256cc 100644 --- a/python/pyarrow/includes/common.pxd +++ b/python/pyarrow/includes/common.pxd @@ -110,6 +110,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: c_bool IsCapacityError() c_bool IsIndexError() c_bool IsSerializationError() + c_bool IsCancelled() cdef cppclass CStatusDetail "arrow::StatusDetail": c_string ToString() diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 046d1989202..9afe4d1e720 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1130,6 +1130,21 @@ ctypedef void CallbackTransform( object, const shared_ptr[CBuffer]& src, shared_ptr[CBuffer]* dest) +cdef extern from "arrow/util/cancel.h" namespace "arrow" nogil: + cdef cppclass CStopToken "arrow::StopToken": + CStatus Poll() + c_bool IsStopRequested() + + cdef cppclass CStopSource "arrow::StopSource": + CStopToken token() + + CResult[CStopSource*] SetSignalStopSource() + void ResetSignalStopSource() + + CStatus RegisterCancellingSignalHandler(vector[int] signals) + void UnregisterCancellingSignalHandler() + + cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil: enum FileMode" arrow::io::FileMode::type": FileMode_READ" arrow::io::FileMode::READ" @@ -1142,7 +1157,9 @@ cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil: cdef cppclass CIOContext" arrow::io::IOContext": CIOContext() + CIOContext(CStopToken) CIOContext(CMemoryPool*) + CIOContext(CMemoryPool*, CStopToken) CIOContext c_default_io_context "arrow::io::default_io_context"() @@ -1640,7 +1657,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil: CRecordBatchReader): @staticmethod CResult[shared_ptr[CCSVStreamingReader]] Make( - CMemoryPool*, shared_ptr[CInputStream], + CIOContext, shared_ptr[CInputStream], CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions) cdef CStatus WriteCSV( @@ -2256,6 +2273,11 @@ cdef extern from 'arrow/util/compression.h' namespace 'arrow' nogil: cdef extern from 'arrow/util/io_util.h' namespace 'arrow::internal' nogil: int ErrnoFromStatus(CStatus status) int WinErrorFromStatus(CStatus status) + int SignalFromStatus(CStatus status) + + CStatus SendSignal(int signum) + CStatus SendSignalToThread(int signum, uint64_t thread_id) + cdef extern from 'arrow/util/iterator.h' namespace 'arrow' nogil: cdef cppclass CIterator" arrow::Iterator"[T]: diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index fb390e1af42..3ce10f3b999 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -504,6 +504,14 @@ cdef class Codec(_Weakrefable): cdef inline CCodec* unwrap(self) nogil +# This class is only used internally for now +cdef class StopToken: + cdef: + CStopToken stop_token + + cdef void init(self, CStopToken stop_token) + + cdef get_input_stream(object source, c_bool use_memory_map, shared_ptr[CInputStream]* reader) cdef get_reader(object source, c_bool use_memory_map, diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py index 5ca31aefebc..be662b20e74 100644 --- a/python/pyarrow/tests/test_csv.py +++ b/python/pyarrow/tests/test_csv.py @@ -25,8 +25,11 @@ import os import pickle import shutil +import signal import string +import sys import tempfile +import threading import time import unittest @@ -891,6 +894,38 @@ def test_stress_convert_options_blowup(self): assert table.num_rows == 0 assert table.column_names == col_names + def test_cancellation(self): + if (threading.current_thread().ident != + threading.main_thread().ident): + pytest.skip("test only works from main Python thread") + + if sys.version_info >= (3, 8): + raise_signal = signal.raise_signal + elif os.name == 'nt': + # On Windows, os.kill() doesn't actually send a signal, + # it just terminates the process with the given exit code. + pytest.skip("test requires Python 3.8+ on Windows") + else: + # On Unix, emulate raise_signal() with os.kill(). + def raise_signal(signum): + os.kill(os.getpid(), signum) + + large_csv = b"a,b,c\n" + b"1,2,3\n" * 30000000 + + def signal_from_thread(): + time.sleep(0.2) + raise_signal(signal.SIGINT) + + t1 = time.time() + with pytest.raises(KeyboardInterrupt) as exc_info: + threading.Thread(target=signal_from_thread).start() + self.read_bytes(large_csv) + dt = time.time() - t1 + assert dt <= 1.0 + e = exc_info.value.__context__ + assert isinstance(e, pa.ArrowCancelled) + assert e.signum == signal.SIGINT + class TestSerialCSVRead(BaseTestCSVRead, unittest.TestCase):