diff --git a/c_glib/arrow-glib/reader.cpp b/c_glib/arrow-glib/reader.cpp index c3082271ca5..17100e76a3c 100644 --- a/c_glib/arrow-glib/reader.cpp +++ b/c_glib/arrow-glib/reader.cpp @@ -1592,6 +1592,7 @@ garrow_csv_reader_new(GArrowInputStream *input, auto arrow_reader = arrow::csv::TableReader::Make(arrow::default_memory_pool(), + arrow::io::AsyncContext(), arrow_input, read_options, parse_options, diff --git a/cpp/examples/minimal_build/example.cc b/cpp/examples/minimal_build/example.cc index 4b6acd2a0dd..8f58de5777a 100644 --- a/cpp/examples/minimal_build/example.cc +++ b/cpp/examples/minimal_build/example.cc @@ -39,6 +39,7 @@ Status RunMain(int argc, char** argv) { ARROW_ASSIGN_OR_RAISE( auto csv_reader, arrow::csv::TableReader::Make(arrow::default_memory_pool(), + arrow::io::AsyncContext(), input_file, arrow::csv::ReadOptions::Defaults(), arrow::csv::ParseOptions::Defaults(), diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 1e93cf9975a..4403def9949 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -189,7 +189,6 @@ set(ARROW_SRCS util/future.cc util/int_util.cc util/io_util.cc - util/iterator.cc util/logging.cc util/key_value_metadata.cc util/memory.cc diff --git a/cpp/src/arrow/csv/CMakeLists.txt b/cpp/src/arrow/csv/CMakeLists.txt index 84b1a103264..2766cfd3bd2 100644 --- a/cpp/src/arrow/csv/CMakeLists.txt +++ b/cpp/src/arrow/csv/CMakeLists.txt @@ -21,7 +21,8 @@ add_arrow_test(csv-test column_builder_test.cc column_decoder_test.cc converter_test.cc - parser_test.cc) + parser_test.cc + reader_test.cc) add_arrow_benchmark(converter_benchmark PREFIX "arrow-csv") add_arrow_benchmark(parser_benchmark PREFIX "arrow-csv") diff --git a/cpp/src/arrow/csv/column_decoder.cc b/cpp/src/arrow/csv/column_decoder.cc index c57477ef59d..1dd13bc9086 100644 --- a/cpp/src/arrow/csv/column_decoder.cc +++ b/cpp/src/arrow/csv/column_decoder.cc @@ -84,7 +84,7 @@ class ConcreteColumnDecoder : public ColumnDecoder { auto 
chunk_index = next_chunk_++; WaitForChunkUnlocked(chunk_index); // Move Future to avoid keeping chunk alive - return std::move(chunks_[chunk_index]).result(); + return chunks_[chunk_index].MoveResult(); } protected: diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc index cf5047aaf16..f0fa1f206d3 100644 --- a/cpp/src/arrow/csv/reader.cc +++ b/cpp/src/arrow/csv/reader.cc @@ -40,6 +40,8 @@ #include "arrow/status.h" #include "arrow/table.h" #include "arrow/type.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/future.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" @@ -60,8 +62,7 @@ class InputStream; namespace csv { -using internal::GetCpuThreadPool; -using internal::ThreadPool; +using internal::Executor; struct ConversionSchema { struct Column { @@ -94,20 +95,24 @@ struct ConversionSchema { // An iterator of Buffers that makes sure there is no straddling CRLF sequence. class CSVBufferIterator { public: - explicit CSVBufferIterator(Iterator> buffer_iterator) - : buffer_iterator_(std::move(buffer_iterator)) {} - static Iterator> Make( Iterator> buffer_iterator) { - CSVBufferIterator it(std::move(buffer_iterator)); - return Iterator>(std::move(it)); + Transformer, std::shared_ptr> fn = + CSVBufferIterator(); + return MakeTransformedIterator(std::move(buffer_iterator), fn); + } + + static AsyncGenerator> MakeAsync( + AsyncGenerator> buffer_iterator) { + Transformer, std::shared_ptr> fn = + CSVBufferIterator(); + return MakeAsyncGenerator(std::move(buffer_iterator), fn); } - Result> Next() { - ARROW_ASSIGN_OR_RAISE(auto buf, buffer_iterator_.Next()); + Result>> operator()(std::shared_ptr buf) { if (buf == nullptr) { // EOF - return nullptr; + return TransformFinish(); } int64_t offset = 0; @@ -127,14 +132,13 @@ class CSVBufferIterator { buf = SliceBuffer(buf, offset); if (buf->size() == 0) { // EOF - return nullptr; + return TransformFinish(); } else { - return buf; + return 
TransformYield(buf); } } protected: - Iterator> buffer_iterator_; bool first_buffer_ = true; // Whether there was a trailing CR at the end of last received buffer bool trailing_cr_ = false; @@ -150,20 +154,36 @@ struct CSVBlock { std::function consume_bytes; }; +} // namespace csv + +template <> +struct IterationTraits { + static csv::CSVBlock End() { return csv::CSVBlock{{}, {}, {}, -1, true, {}}; } +}; + +namespace csv { + +// The == operator must be defined to be used as T in Iterator +bool operator==(const CSVBlock& left, const CSVBlock& right) { + return left.block_index == right.block_index; +} +bool operator!=(const CSVBlock& left, const CSVBlock& right) { + return left.block_index != right.block_index; +} + +// This is a callable that can be used to transform an iterator. The source iterator +// will contain buffers of data and the output iterator will contain delimited CSV +// blocks. The concrete readers return TransformFinish() at EOF; consumers detect the +// end by comparing against the IterationTraits end token (block_index == -1). 
class BlockReader { public: - BlockReader(std::unique_ptr chunker, - Iterator> buffer_iterator, - std::shared_ptr first_buffer) + BlockReader(std::unique_ptr chunker, std::shared_ptr first_buffer) : chunker_(std::move(chunker)), - buffer_iterator_(std::move(buffer_iterator)), partial_(std::make_shared("")), buffer_(std::move(first_buffer)) {} protected: std::unique_ptr chunker_; - Iterator> buffer_iterator_; - std::shared_ptr partial_, buffer_; int64_t block_index_ = 0; // Whether there was a trailing CR at the end of last received buffer @@ -177,14 +197,25 @@ class SerialBlockReader : public BlockReader { public: using BlockReader::BlockReader; - Result> Next() { + static Iterator MakeIterator( + Iterator> buffer_iterator, std::unique_ptr chunker, + std::shared_ptr first_buffer) { + auto block_reader = + std::make_shared(std::move(chunker), first_buffer); + // Wrap shared pointer in callable + Transformer, CSVBlock> block_reader_fn = + [block_reader](std::shared_ptr buf) { + return (*block_reader)(std::move(buf)); + }; + return MakeTransformedIterator(std::move(buffer_iterator), block_reader_fn); + } + + Result> operator()(std::shared_ptr next_buffer) { if (buffer_ == nullptr) { - // EOF - return util::optional(); + return TransformFinish(); } - std::shared_ptr next_buffer, completion; - ARROW_ASSIGN_OR_RAISE(next_buffer, buffer_iterator_.Next()); + std::shared_ptr completion; bool is_final = (next_buffer == nullptr); if (is_final) { @@ -210,8 +241,9 @@ class SerialBlockReader : public BlockReader { return Status::OK(); }; - return CSVBlock{partial_, completion, buffer_, - block_index_++, is_final, std::move(consume_bytes)}; + return TransformYield(CSVBlock{partial_, completion, buffer_, + block_index_++, is_final, + std::move(consume_bytes)}); } }; @@ -220,14 +252,35 @@ class ThreadedBlockReader : public BlockReader { public: using BlockReader::BlockReader; - Result> Next() { + static Iterator MakeIterator( + Iterator> buffer_iterator, std::unique_ptr chunker, + 
std::shared_ptr first_buffer) { + auto block_reader = + std::make_shared(std::move(chunker), first_buffer); + // Wrap shared pointer in callable + Transformer, CSVBlock> block_reader_fn = + [block_reader](std::shared_ptr next) { return (*block_reader)(next); }; + return MakeTransformedIterator(std::move(buffer_iterator), block_reader_fn); + } + + static AsyncGenerator MakeAsyncIterator( + AsyncGenerator> buffer_generator, + std::unique_ptr chunker, std::shared_ptr first_buffer) { + auto block_reader = + std::make_shared(std::move(chunker), first_buffer); + // Wrap shared pointer in callable + Transformer, CSVBlock> block_reader_fn = + [block_reader](std::shared_ptr next) { return (*block_reader)(next); }; + return MakeAsyncGenerator(std::move(buffer_generator), block_reader_fn); + } + + Result> operator()(std::shared_ptr next_buffer) { if (buffer_ == nullptr) { // EOF - return util::optional(); + return TransformFinish(); } - std::shared_ptr next_buffer, whole, completion, next_partial; - ARROW_ASSIGN_OR_RAISE(next_buffer, buffer_iterator_.Next()); + std::shared_ptr whole, completion, next_partial; bool is_final = (next_buffer == nullptr); auto current_partial = std::move(partial_); @@ -252,7 +305,8 @@ class ThreadedBlockReader : public BlockReader { partial_ = std::move(next_partial); buffer_ = std::move(next_buffer); - return CSVBlock{current_partial, completion, whole, block_index_++, is_final, {}}; + return TransformYield( + CSVBlock{current_partial, completion, whole, block_index_++, is_final, {}}); } }; @@ -449,7 +503,6 @@ class ReaderMixin { ConversionSchema conversion_schema_; std::shared_ptr input_; - Iterator> buffer_iterator_; std::shared_ptr task_group_; }; @@ -462,6 +515,10 @@ class BaseTableReader : public ReaderMixin, public csv::TableReader { virtual Status Init() = 0; + Future> ReadAsync() override { + return Future>::MakeFinished(Read()); + } + protected: // Make column builders from conversion schema Status MakeColumnBuilders() { @@ -624,6 +681,7 
@@ class BaseStreamingReader : public ReaderMixin, public csv::StreamingReader { std::vector> column_decoders_; std::shared_ptr schema_; std::shared_ptr pending_batch_; + Iterator> buffer_iterator_; bool eof_ = false; }; @@ -656,7 +714,7 @@ class SerialStreamingReader : public BaseStreamingReader { if (eof_) { return nullptr; } - if (block_reader_ == nullptr) { + if (!block_iterator_) { Status st = SetupReader(); if (!st.ok()) { // Can't setup reader => bail out @@ -670,18 +728,18 @@ class SerialStreamingReader : public BaseStreamingReader { } if (!source_eof_) { - ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_reader_->Next()); - if (maybe_block.has_value()) { - last_block_index_ = maybe_block->block_index; - auto maybe_parsed = ParseAndInsert(maybe_block->partial, maybe_block->completion, - maybe_block->buffer, maybe_block->block_index, - maybe_block->is_final); + ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator_.Next()); + if (maybe_block != IterationTraits::End()) { + last_block_index_ = maybe_block.block_index; + auto maybe_parsed = ParseAndInsert(maybe_block.partial, maybe_block.completion, + maybe_block.buffer, maybe_block.block_index, + maybe_block.is_final); if (!maybe_parsed.ok()) { // Parse error => bail out eof_ = true; return maybe_parsed.status(); } - RETURN_NOT_OK(maybe_block->consume_bytes(*maybe_parsed)); + RETURN_NOT_OK(maybe_block.consume_bytes(*maybe_parsed)); } else { source_eof_ = true; for (auto& decoder : column_decoders_) { @@ -705,15 +763,15 @@ class SerialStreamingReader : public BaseStreamingReader { RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer)); RETURN_NOT_OK(MakeColumnDecoders()); - block_reader_ = std::make_shared(MakeChunker(parse_options_), - std::move(buffer_iterator_), - std::move(first_buffer)); + block_iterator_ = SerialBlockReader::MakeIterator(std::move(buffer_iterator_), + MakeChunker(parse_options_), + std::move(first_buffer)); return Status::OK(); } bool source_eof_ = false; int64_t last_block_index_ = 0; - 
std::shared_ptr block_reader_; + Iterator block_iterator_; }; ///////////////////////////////////////////////////////////////////////// @@ -746,41 +804,46 @@ class SerialTableReader : public BaseTableReader { RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer)); RETURN_NOT_OK(MakeColumnBuilders()); - SerialBlockReader block_reader(MakeChunker(parse_options_), - std::move(buffer_iterator_), std::move(first_buffer)); - + auto block_iterator = SerialBlockReader::MakeIterator(std::move(buffer_iterator_), + MakeChunker(parse_options_), + std::move(first_buffer)); while (true) { - ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_reader.Next()); - if (!maybe_block.has_value()) { + ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator.Next()); + if (maybe_block == IterationTraits::End()) { // EOF break; } - ARROW_ASSIGN_OR_RAISE(int64_t parsed_bytes, - ParseAndInsert(maybe_block->partial, maybe_block->completion, - maybe_block->buffer, maybe_block->block_index, - maybe_block->is_final)); - RETURN_NOT_OK(maybe_block->consume_bytes(parsed_bytes)); + ARROW_ASSIGN_OR_RAISE( + int64_t parsed_bytes, + ParseAndInsert(maybe_block.partial, maybe_block.completion, maybe_block.buffer, + maybe_block.block_index, maybe_block.is_final)); + RETURN_NOT_OK(maybe_block.consume_bytes(parsed_bytes)); } // Finish conversion, create schema and table RETURN_NOT_OK(task_group_->Finish()); return MakeTable(); } -}; -///////////////////////////////////////////////////////////////////////// -// Parallel TableReader implementation + protected: + Iterator> buffer_iterator_; +}; -class ThreadedTableReader : public BaseTableReader { +class AsyncThreadedTableReader + : public BaseTableReader, + public std::enable_shared_from_this { public: using BaseTableReader::BaseTableReader; - ThreadedTableReader(MemoryPool* pool, std::shared_ptr input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options, ThreadPool* thread_pool) + 
AsyncThreadedTableReader(MemoryPool* pool, std::shared_ptr input, + const ReadOptions& read_options, + const ParseOptions& parse_options, + const ConvertOptions& convert_options, Executor* cpu_executor, + Executor* io_executor) : BaseTableReader(pool, input, read_options, parse_options, convert_options), - thread_pool_(thread_pool) {} + cpu_executor_(cpu_executor), + io_executor_(io_executor) {} - ~ThreadedTableReader() override { + ~AsyncThreadedTableReader() override { if (task_group_) { // In case of error, make sure all pending tasks are finished before // we start destroying BaseTableReader members @@ -792,65 +855,98 @@ class ThreadedTableReader : public BaseTableReader { ARROW_ASSIGN_OR_RAISE(auto istream_it, io::MakeInputStreamIterator(input_, read_options_.block_size)); - int32_t block_queue_size = thread_pool_->GetCapacity(); - ARROW_ASSIGN_OR_RAISE(auto rh_it, - MakeReadaheadIterator(std::move(istream_it), block_queue_size)); - buffer_iterator_ = CSVBufferIterator::Make(std::move(rh_it)); + // TODO: use io_executor_ here, see ARROW-11590 + ARROW_ASSIGN_OR_RAISE(auto background_executor, internal::ThreadPool::Make(1)); + ARROW_ASSIGN_OR_RAISE(auto bg_it, MakeBackgroundGenerator(std::move(istream_it), + background_executor.get())); + AsyncGenerator> wrapped_bg_it = + [bg_it, background_executor]() { return bg_it(); }; + + auto transferred_it = + MakeTransferredGenerator(std::move(wrapped_bg_it), cpu_executor_); + + int32_t block_queue_size = cpu_executor_->GetCapacity(); + auto rh_it = MakeReadaheadGenerator(std::move(transferred_it), block_queue_size); + buffer_generator_ = CSVBufferIterator::MakeAsync(std::move(rh_it)); return Status::OK(); } - Result> Read() override { - task_group_ = internal::TaskGroup::MakeThreaded(thread_pool_); + Result> Read() override { return ReadAsync().result(); } + + Future> ReadAsync() override { + task_group_ = internal::TaskGroup::MakeThreaded(cpu_executor_); + + auto self = shared_from_this(); + return 
ProcessFirstBuffer().Then([self](std::shared_ptr first_buffer) { + auto block_generator = ThreadedBlockReader::MakeAsyncIterator( + self->buffer_generator_, MakeChunker(self->parse_options_), + std::move(first_buffer)); + + std::function block_visitor = + [self](CSVBlock maybe_block) -> Status { + // The logic in VisitAsyncGenerator ensures that we will never be + // passed an empty block (visit does not call with the end token) so + // we can be assured maybe_block has a value. + DCHECK_GE(maybe_block.block_index, 0); + DCHECK(!maybe_block.consume_bytes); + + // Launch parse task + self->task_group_->Append([self, maybe_block] { + return self + ->ParseAndInsert(maybe_block.partial, maybe_block.completion, + maybe_block.buffer, maybe_block.block_index, + maybe_block.is_final) + .status(); + }); + return Status::OK(); + }; + + return VisitAsyncGenerator(std::move(block_generator), block_visitor) + .Then([self](...) -> Future<> { + // By this point we've added all top level tasks so it is safe to call + // FinishAsync + return self->task_group_->FinishAsync(); + }) + .Then([self](...) 
-> Result> { + // Finish conversion, create schema and table + return self->MakeTable(); + }); + }); + } + protected: + Future> ProcessFirstBuffer() { // First block - ARROW_ASSIGN_OR_RAISE(auto first_buffer, buffer_iterator_.Next()); - if (first_buffer == nullptr) { - return Status::Invalid("Empty CSV file"); - } - RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer)); - RETURN_NOT_OK(MakeColumnBuilders()); - - ThreadedBlockReader block_reader(MakeChunker(parse_options_), - std::move(buffer_iterator_), - std::move(first_buffer)); - - while (true) { - ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_reader.Next()); - if (!maybe_block.has_value()) { - // EOF - break; + auto first_buffer_future = buffer_generator_(); + return first_buffer_future.Then([this](const std::shared_ptr& first_buffer) + -> Result> { + if (first_buffer == nullptr) { + return Status::Invalid("Empty CSV file"); } - DCHECK(!maybe_block->consume_bytes); - - // Launch parse task - task_group_->Append([this, maybe_block] { - return ParseAndInsert(maybe_block->partial, maybe_block->completion, - maybe_block->buffer, maybe_block->block_index, - maybe_block->is_final) - .status(); - }); - } - - // Finish conversion, create schema and table - RETURN_NOT_OK(task_group_->Finish()); - return MakeTable(); + std::shared_ptr first_buffer_processed; + RETURN_NOT_OK(ProcessHeader(first_buffer, &first_buffer_processed)); + RETURN_NOT_OK(MakeColumnBuilders()); + return first_buffer_processed; + }); } - protected: - ThreadPool* thread_pool_; + Executor* cpu_executor_; + Executor* io_executor_; + AsyncGenerator> buffer_generator_; }; ///////////////////////////////////////////////////////////////////////// // Factory functions Result> TableReader::Make( - MemoryPool* pool, std::shared_ptr input, - const ReadOptions& read_options, const ParseOptions& parse_options, - const ConvertOptions& convert_options) { + MemoryPool* pool, io::AsyncContext async_context, + std::shared_ptr input, const ReadOptions& 
read_options, + const ParseOptions& parse_options, const ConvertOptions& convert_options) { std::shared_ptr reader; if (read_options.use_threads) { - reader = std::make_shared( - pool, input, read_options, parse_options, convert_options, GetCpuThreadPool()); + reader = std::make_shared( + pool, input, read_options, parse_options, convert_options, async_context.executor, + internal::GetCpuThreadPool()); } else { reader = std::make_shared(pool, input, read_options, parse_options, convert_options); @@ -871,4 +967,5 @@ Result> StreamingReader::Make( } } // namespace csv + } // namespace arrow diff --git a/cpp/src/arrow/csv/reader.h b/cpp/src/arrow/csv/reader.h index 652cedc8c74..c361fbddce9 100644 --- a/cpp/src/arrow/csv/reader.h +++ b/cpp/src/arrow/csv/reader.h @@ -20,10 +20,12 @@ #include #include "arrow/csv/options.h" // IWYU pragma: keep +#include "arrow/io/interfaces.h" #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/type.h" #include "arrow/type_fwd.h" +#include "arrow/util/future.h" #include "arrow/util/visibility.h" namespace arrow { @@ -40,9 +42,12 @@ class ARROW_EXPORT TableReader { /// Read the entire CSV file and convert it to a Arrow Table virtual Result> Read() = 0; + /// Asynchronously read the entire CSV file and convert it to an Arrow Table + virtual Future> ReadAsync() = 0; /// Create a TableReader instance static Result> Make(MemoryPool* pool, + io::AsyncContext async_context, std::shared_ptr input, const ReadOptions&, const ParseOptions&, diff --git a/cpp/src/arrow/csv/reader_test.cc b/cpp/src/arrow/csv/reader_test.cc new file mode 100644 index 00000000000..64010ae481a --- /dev/null +++ b/cpp/src/arrow/csv/reader_test.cc @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include + +#include "arrow/csv/options.h" +#include "arrow/csv/reader.h" +#include "arrow/csv/test_common.h" +#include "arrow/io/interfaces.h" +#include "arrow/io/memory.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/future.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { +namespace csv { + +using TableReaderFactory = + std::function>(std::shared_ptr)>; + +void StressTableReader(TableReaderFactory reader_factory) { + const int NTASKS = 100; + const int NROWS = 1000; + ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS)); + + std::vector>> task_futures(NTASKS); + for (int i = 0; i < NTASKS; i++) { + auto input = std::make_shared(table_buffer); + ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input)); + task_futures[i] = reader->ReadAsync(); + } + auto combined_future = All(task_futures); + combined_future.Wait(); + + ASSERT_OK_AND_ASSIGN(std::vector>> results, + combined_future.result()); + for (auto&& result : results) { + ASSERT_OK_AND_ASSIGN(auto table, result); + ASSERT_EQ(NROWS, table->num_rows()); + } +} + +void StressInvalidTableReader(TableReaderFactory reader_factory) { + const int NTASKS = 100; + const int NROWS = 1000; + ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS, false)); + + std::vector>> 
task_futures(NTASKS); + for (int i = 0; i < NTASKS; i++) { + auto input = std::make_shared(table_buffer); + ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input)); + task_futures[i] = reader->ReadAsync(); + } + auto combined_future = All(task_futures); + combined_future.Wait(); + + ASSERT_OK_AND_ASSIGN(std::vector>> results, + combined_future.result()); + for (auto&& result : results) { + ASSERT_RAISES(Invalid, result); + } +} + +void TestNestedParallelism(std::shared_ptr thread_pool, + TableReaderFactory reader_factory) { + const int NROWS = 1000; + ASSERT_OK_AND_ASSIGN(auto table_buffer, MakeSampleCsvBuffer(NROWS)); + auto input = std::make_shared(table_buffer); + ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(input)); + + Future> table_future; + + auto read_task = [&reader, &table_future]() mutable { + table_future = reader->ReadAsync(); + return Status::OK(); + }; + ASSERT_OK_AND_ASSIGN(auto future, thread_pool->Submit(read_task)); + + ASSERT_FINISHES_OK(future); + ASSERT_FINISHES_OK_AND_ASSIGN(auto table, table_future); + ASSERT_EQ(table->num_rows(), NROWS); +} // end of TestNestedParallelism + +TableReaderFactory MakeSerialFactory() { + return [](std::shared_ptr input_stream) { + auto read_options = ReadOptions::Defaults(); + read_options.block_size = 1 << 10; + read_options.use_threads = false; + return TableReader::Make(default_memory_pool(), io::AsyncContext(), input_stream, + read_options, ParseOptions::Defaults(), + ConvertOptions::Defaults()); + }; +} + +TEST(SerialReaderTests, Stress) { StressTableReader(MakeSerialFactory()); } +TEST(SerialReaderTests, StressInvalid) { StressInvalidTableReader(MakeSerialFactory()); } +TEST(SerialReaderTests, NestedParallelism) { + ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1)); + TestNestedParallelism(thread_pool, MakeSerialFactory()); +} + +Result MakeAsyncFactory( + std::shared_ptr thread_pool = nullptr) { + if (!thread_pool) { + ARROW_ASSIGN_OR_RAISE(thread_pool, internal::ThreadPool::Make(1)); + } + 
return [thread_pool](std::shared_ptr input_stream) + -> Result> { + ReadOptions read_options = ReadOptions::Defaults(); + read_options.use_threads = true; + read_options.block_size = 1 << 10; + auto table_reader = TableReader::Make( + default_memory_pool(), io::AsyncContext(thread_pool.get()), input_stream, + read_options, ParseOptions::Defaults(), ConvertOptions::Defaults()); + return table_reader; + }; +} + +TEST(AsyncReaderTests, Stress) { + ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory()); + StressTableReader(table_factory); +} +TEST(AsyncReaderTests, StressInvalid) { + ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory()); + StressInvalidTableReader(table_factory); +} +TEST(AsyncReaderTests, NestedParallelism) { + ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1)); + ASSERT_OK_AND_ASSIGN(auto table_factory, MakeAsyncFactory(thread_pool)); + TestNestedParallelism(thread_pool, table_factory); +} + +} // namespace csv +} // namespace arrow diff --git a/cpp/src/arrow/csv/test_common.cc b/cpp/src/arrow/csv/test_common.cc index 08981a70501..c3d0241aa38 100644 --- a/cpp/src/arrow/csv/test_common.cc +++ b/cpp/src/arrow/csv/test_common.cc @@ -61,5 +61,59 @@ void MakeColumnParser(std::vector items, std::shared_ptrnum_rows(), items.size()); } +namespace { + +const std::vector int64_rows = {"123", "4", "-317005557", "", "N/A", "0"}; +const std::vector float_rows = {"0", "123.456", "-3170.55766", "", "N/A"}; +const std::vector decimal128_rows = {"0", "123.456", "-3170.55766", + "", "N/A", "1233456789.123456789"}; +const std::vector iso8601_rows = {"1917-10-17", "2018-09-13", + "1941-06-22 04:00", "1945-05-09 09:45:38"}; +const std::vector strptime_rows = {"10/17/1917", "9/13/2018", "9/5/1945"}; + +static void WriteHeader(std::ostream& writer) { + writer << "Int64,Float,Decimal128,ISO8601,Strptime" << std::endl; +} + +static std::string GetCell(const std::vector& base_rows, size_t row_index) { + return base_rows[row_index % 
base_rows.size()]; +} + +static void WriteRow(std::ostream& writer, size_t row_index) { + writer << GetCell(int64_rows, row_index); + writer << ','; + writer << GetCell(float_rows, row_index); + writer << ','; + writer << GetCell(decimal128_rows, row_index); + writer << ','; + writer << GetCell(iso8601_rows, row_index); + writer << ','; + writer << GetCell(strptime_rows, row_index); + writer << std::endl; +} + +static void WriteInvalidRow(std::ostream& writer, size_t row_index) { + writer << "\"" << std::endl << "\""; + writer << std::endl; +} +} // namespace + +Result> MakeSampleCsvBuffer(size_t num_rows, bool valid) { + std::stringstream writer; + + WriteHeader(writer); + for (size_t i = 0; i < num_rows; ++i) { + if (i == num_rows / 2 && !valid) { + WriteInvalidRow(writer, i); + } else { + WriteRow(writer, i); + } + } + + auto table_str = writer.str(); + auto table_buffer = std::make_shared(table_str); + return MemoryManager::CopyBuffer(table_buffer, default_cpu_memory_manager()); +} + } // namespace csv } // namespace arrow diff --git a/cpp/src/arrow/csv/test_common.h b/cpp/src/arrow/csv/test_common.h index 119da03a83d..823cf643fa0 100644 --- a/cpp/src/arrow/csv/test_common.h +++ b/cpp/src/arrow/csv/test_common.h @@ -46,5 +46,8 @@ void MakeCSVParser(std::vector lines, std::shared_ptr* ARROW_TESTING_EXPORT void MakeColumnParser(std::vector items, std::shared_ptr* out); +ARROW_TESTING_EXPORT +Result> MakeSampleCsvBuffer(size_t num_rows, bool valid = true); + } // namespace csv } // namespace arrow diff --git a/cpp/src/arrow/json/reader.cc b/cpp/src/arrow/json/reader.cc index dc0d6e04d11..44aa2607d9e 100644 --- a/cpp/src/arrow/json/reader.cc +++ b/cpp/src/arrow/json/reader.cc @@ -29,6 +29,7 @@ #include "arrow/json/parser.h" #include "arrow/record_batch.h" #include "arrow/table.h" +#include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/string_view.h" diff --git a/cpp/src/arrow/result.h 
b/cpp/src/arrow/result.h index 6504d950674..0172a852434 100644 --- a/cpp/src/arrow/result.h +++ b/cpp/src/arrow/result.h @@ -317,7 +317,7 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { return ValueUnsafe(); } const T& operator*() const& { return ValueOrDie(); } - const T* operator->() const& { return &ValueOrDie(); } + const T* operator->() const { return &ValueOrDie(); } /// Gets a mutable reference to the stored `T` value. /// @@ -332,7 +332,7 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { return ValueUnsafe(); } T& operator*() & { return ValueOrDie(); } - T* operator->() & { return &ValueOrDie(); } + T* operator->() { return &ValueOrDie(); } /// Moves and returns the internally-stored `T` value. /// diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index cdb23a92899..fafccc2930d 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -135,15 +135,55 @@ ASSERT_EQ(expected, _actual); \ } while (0) +// This macro should be used on futures that are expected to +// complete pretty quickly. 10 seconds is the max wait here +// (the macros call Wait(10)). Anything longer than that and +// it's a questionable unit test anyways. 
+#define ASSERT_FINISHES_IMPL(fut) \ + do { \ + ASSERT_TRUE(fut.Wait(10)); \ + if (!fut.is_finished()) { \ + FAIL() << "Future did not finish in a timely fashion"; \ + } \ + } while (false) + +#define ASSERT_FINISHES_OK(expr) \ + do { \ + auto&& _fut = (expr); \ + ASSERT_TRUE(_fut.Wait(10)); \ + if (!_fut.is_finished()) { \ + FAIL() << "Future did not finish in a timely fashion"; \ + } \ + auto _st = _fut.status(); \ + if (!_st.ok()) { \ + FAIL() << "'" ARROW_STRINGIFY(expr) "' failed with " << _st.ToString(); \ + } \ + } while (false) + +#define ASSERT_FINISHES_ERR(ENUM, expr) \ + do { \ + auto&& fut = (expr); \ + ASSERT_FINISHES_IMPL(fut); \ + ASSERT_RAISES(ENUM, fut.status()); \ + } while (false) + +#define ASSERT_FINISHES_OK_AND_ASSIGN_IMPL(lhs, rexpr, future_name) \ + auto future_name = (rexpr); \ + ASSERT_FINISHES_IMPL(future_name); \ + ASSERT_OK_AND_ASSIGN(lhs, future_name.result()); + +#define ASSERT_FINISHES_OK_AND_ASSIGN(lhs, rexpr) \ + ASSERT_FINISHES_OK_AND_ASSIGN_IMPL(lhs, rexpr, \ + ARROW_ASSIGN_OR_RAISE_NAME(_fut, __COUNTER__)) + namespace arrow { +// ---------------------------------------------------------------------- +// Useful testing::Types declarations inline void PrintTo(StatusCode code, std::ostream* os) { *os << Status::CodeAsString(code); } -// ---------------------------------------------------------------------- -// Useful testing::Types declarations - using NumericArrowTypes = ::testing::Types; diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h new file mode 100644 index 00000000000..8e88813d611 --- /dev/null +++ b/cpp/src/arrow/util/async_generator.h @@ -0,0 +1,388 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once +#include + +#include "arrow/util/functional.h" +#include "arrow/util/future.h" +#include "arrow/util/iterator.h" +#include "arrow/util/optional.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { + +template +using AsyncGenerator = std::function()>; + +/// Iterates through a generator of futures, visiting the result of each one and +/// returning a future that completes when all have been visited +template +Future<> VisitAsyncGenerator(AsyncGenerator generator, + std::function visitor) { + struct LoopBody { + struct Callback { + Result> operator()(const T& result) { + if (result == IterationTraits::End()) { + return Break(detail::Empty()); + } else { + auto visited = visitor(result); + if (visited.ok()) { + return Continue(); + } else { + return visited; + } + } + } + + std::function visitor; + }; + + Future> operator()() { + Callback callback{visitor}; + auto next = generator(); + return next.Then(std::move(callback)); + } + + AsyncGenerator generator; + std::function visitor; + }; + + return Loop(LoopBody{std::move(generator), std::move(visitor)}); +} + +template +Future> CollectAsyncGenerator(AsyncGenerator generator) { + auto vec = std::make_shared>(); + struct LoopBody { + Future>> operator()() { + auto next = generator(); + auto vec = vec_; + return next.Then([vec](const T& result) -> Result>> { + if (result == IterationTraits::End()) { + return 
Break(*vec); + } else { + vec->push_back(result); + return Continue(); + } + }); + } + AsyncGenerator generator; + std::shared_ptr> vec_; + }; + return Loop(LoopBody{std::move(generator), std::move(vec)}); +} + +template +class TransformingGenerator { + // The transforming generator state will be referenced as an async generator but will + // also be referenced via callback to various futures. If the async generator owner + // moves it around we need the state to be consistent for future callbacks. + struct TransformingGeneratorState + : std::enable_shared_from_this { + TransformingGeneratorState(AsyncGenerator generator, Transformer transformer) + : generator_(std::move(generator)), + transformer_(std::move(transformer)), + last_value_(), + finished_() {} + + Future operator()() { + while (true) { + auto maybe_next_result = Pump(); + if (!maybe_next_result.ok()) { + return Future::MakeFinished(maybe_next_result.status()); + } + auto maybe_next = std::move(maybe_next_result).ValueUnsafe(); + if (maybe_next.has_value()) { + return Future::MakeFinished(*std::move(maybe_next)); + } + + auto next_fut = generator_(); + // If finished already, process results immediately inside the loop to avoid stack + // overflow + if (next_fut.is_finished()) { + auto next_result = next_fut.result(); + if (next_result.ok()) { + last_value_ = *next_result; + } else { + return Future::MakeFinished(next_result.status()); + } + // Otherwise, if not finished immediately, add callback to process results + } else { + auto self = this->shared_from_this(); + return next_fut.Then([self](const Result& next_result) { + if (next_result.ok()) { + self->last_value_ = *next_result; + return (*self)(); + } else { + return Future::MakeFinished(next_result.status()); + } + }); + } + } + } + + // See comment on TransformingIterator::Pump + Result> Pump() { + if (!finished_ && last_value_.has_value()) { + ARROW_ASSIGN_OR_RAISE(TransformFlow next, transformer_(*last_value_)); + if (next.ReadyForNext()) { + 
if (*last_value_ == IterationTraits::End()) { + finished_ = true; + } + last_value_.reset(); + } + if (next.Finished()) { + finished_ = true; + } + if (next.HasValue()) { + return next.Value(); + } + } + if (finished_) { + return IterationTraits::End(); + } + return util::nullopt; + } + + AsyncGenerator generator_; + Transformer transformer_; + util::optional last_value_; + bool finished_; + }; + + public: + explicit TransformingGenerator(AsyncGenerator generator, + Transformer transformer) + : state_(std::make_shared(std::move(generator), + std::move(transformer))) {} + + Future operator()() { return (*state_)(); } + + protected: + std::shared_ptr state_; +}; + +template +class ReadaheadGenerator { + public: + ReadaheadGenerator(AsyncGenerator source_generator, int max_readahead) + : source_generator_(std::move(source_generator)), max_readahead_(max_readahead) { + auto finished = std::make_shared>(); + mark_finished_if_done_ = [finished](const Result& next_result) { + if (!next_result.ok()) { + finished->store(true); + } else { + const auto& next = *next_result; + if (next == IterationTraits::End()) { + *finished = true; + } + } + }; + finished_ = std::move(finished); + } + + Future operator()() { + if (readahead_queue_.empty()) { + // This is the first request, let's pump the underlying queue + for (int i = 0; i < max_readahead_; i++) { + auto next = source_generator_(); + next.AddCallback(mark_finished_if_done_); + readahead_queue_.push(std::move(next)); + } + } + // Pop one and add one + auto result = readahead_queue_.front(); + readahead_queue_.pop(); + if (finished_->load()) { + readahead_queue_.push(Future::MakeFinished(IterationTraits::End())); + } else { + auto back_of_queue = source_generator_(); + back_of_queue.AddCallback(mark_finished_if_done_); + readahead_queue_.push(std::move(back_of_queue)); + } + return result; + } + + private: + AsyncGenerator source_generator_; + int max_readahead_; + std::function&)> mark_finished_if_done_; + // Can't use a 
bool here because finished may be referenced by callbacks that + // outlive this class + std::shared_ptr> finished_; + std::queue> readahead_queue_; +}; + +/// \brief Creates a generator that pulls reentrantly from a source +/// This generator will pull reentrantly from a source, ensuring that max_readahead +/// requests are active at any given time. +/// +/// The source generator must be async-reentrant +/// +/// This generator itself is async-reentrant. +template +AsyncGenerator MakeReadaheadGenerator(AsyncGenerator source_generator, + int max_readahead) { + return ReadaheadGenerator(std::move(source_generator), max_readahead); +} + +/// \brief Transforms an async generator using a transformer function returning a new +/// AsyncGenerator +/// +/// The transform function here behaves exactly the same as the transform function in +/// MakeTransformedIterator and you can safely use the same transform function to +/// transform both synchronous and asynchronous streams. +/// +/// This generator is not async-reentrant +template +AsyncGenerator MakeAsyncGenerator(AsyncGenerator generator, + Transformer transformer) { + return TransformingGenerator(generator, transformer); +} + +/// \brief Transfers execution of the generator onto the given executor +/// +/// This generator is async-reentrant if the source generator is async-reentrant +template +class TransferringGenerator { + public: + explicit TransferringGenerator(AsyncGenerator source, internal::Executor* executor) + : source_(std::move(source)), executor_(executor) {} + + Future operator()() { return executor_->Transfer(source_()); } + + private: + AsyncGenerator source_; + internal::Executor* executor_; +}; + +/// \brief Transfers a future to an underlying executor. +/// +/// Continuations run on the returned future will be run on the given executor +/// if they cannot be run synchronously. 
+/// +/// This is often needed to move computation off I/O threads or other external +/// completion sources and back on to the CPU executor so the I/O thread can +/// stay busy and focused on I/O +/// +/// Keep in mind that continuations called on an already completed future will +/// always be run synchronously and so no transfer will happen in that case. +template +AsyncGenerator MakeTransferredGenerator(AsyncGenerator source, + internal::Executor* executor) { + return TransferringGenerator(std::move(source), executor); +} + +/// \brief Async generator that iterates on an underlying iterator in a +/// separate executor. +/// +/// This generator is async-reentrant +template +class BackgroundGenerator { + public: + explicit BackgroundGenerator(Iterator it, internal::Executor* io_executor) + : io_executor_(io_executor) { + task_ = Task{std::make_shared>(std::move(it)), + std::make_shared>(false)}; + } + + ~BackgroundGenerator() { + // The thread pool will be disposed of automatically. By default it will not wait + // so the background thread may outlive this object. That should be ok. Any task + // objects in the thread pool are copies of task_ and have their own shared_ptr to + // the iterator. + } + + ARROW_DEFAULT_MOVE_AND_ASSIGN(BackgroundGenerator); + ARROW_DISALLOW_COPY_AND_ASSIGN(BackgroundGenerator); + + Future operator()() { + auto submitted_future = io_executor_->Submit(task_); + if (!submitted_future.ok()) { + return Future::MakeFinished(submitted_future.status()); + } + return std::move(*submitted_future); + } + + protected: + struct Task { + Result operator()() { + if (*done_) { + return IterationTraits::End(); + } + auto next = it_->Next(); + if (!next.ok() || *next == IterationTraits::End()) { + *done_ = true; + } + return next; + } + // This task is going to be copied so we need to convert the iterator ptr to + // a shared ptr. 
This should be safe however because the background executor only + // has a single thread so it can't access it_ across multiple threads. + std::shared_ptr> it_; + std::shared_ptr> done_; + }; + + Task task_; + internal::Executor* io_executor_; +}; + +/// \brief Creates an AsyncGenerator by iterating over an Iterator on a background +/// thread +template +static Result> MakeBackgroundGenerator( + Iterator iterator, internal::Executor* io_executor) { + auto background_iterator = std::make_shared>( + std::move(iterator), std::move(io_executor)); + return [background_iterator]() { return (*background_iterator)(); }; +} + +/// \brief Converts an AsyncGenerator to an Iterator by blocking until each future +/// is finished +template +class GeneratorIterator { + public: + explicit GeneratorIterator(AsyncGenerator source) : source_(std::move(source)) {} + + Result Next() { return source_().result(); } + + private: + AsyncGenerator source_; +}; + +template +Result> MakeGeneratorIterator(AsyncGenerator source) { + return Iterator(GeneratorIterator(std::move(source))); +} + +template +Result> MakeReadaheadIterator(Iterator it, int readahead_queue_size) { + ARROW_ASSIGN_OR_RAISE(auto io_executor, internal::ThreadPool::Make(1)); + ARROW_ASSIGN_OR_RAISE(auto background_generator, + MakeBackgroundGenerator(std::move(it), io_executor.get())); + // Capture io_executor to keep it alive as long as owned_bg_generator is still + // referenced + AsyncGenerator owned_bg_generator = [io_executor, background_generator]() { + return background_generator(); + }; + auto readahead_generator = + MakeReadaheadGenerator(std::move(owned_bg_generator), readahead_queue_size); + return MakeGeneratorIterator(std::move(readahead_generator)); +} + +} // namespace arrow diff --git a/cpp/src/arrow/util/future.cc b/cpp/src/arrow/util/future.cc index f8d12ad7611..3a77f34e68f 100644 --- a/cpp/src/arrow/util/future.cc +++ b/cpp/src/arrow/util/future.cc @@ -239,6 +239,16 @@ class ConcreteFutureImpl : public 
FutureImpl { } } + bool TryAddCallback(const std::function& callback_factory) { + std::unique_lock lock(mutex_); + if (IsFutureFinished(state_)) { + return false; + } else { + callbacks_.push_back(callback_factory()); + return true; + } + } + void DoMarkFinishedOrFailed(FutureState state) { { // Lock the hypothetical waiter first, and the future after. @@ -326,4 +336,8 @@ void FutureImpl::AddCallback(Callback callback) { GetConcreteFuture(this)->AddCallback(std::move(callback)); } +bool FutureImpl::TryAddCallback(const std::function& callback_factory) { + return GetConcreteFuture(this)->TryAddCallback(callback_factory); +} + } // namespace arrow diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index 2fc040c2e2f..ee053cf3096 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -29,6 +29,7 @@ #include "arrow/status.h" #include "arrow/util/functional.h" #include "arrow/util/macros.h" +#include "arrow/util/optional.h" #include "arrow/util/type_fwd.h" #include "arrow/util/visibility.h" @@ -152,6 +153,7 @@ class ARROW_EXPORT FutureImpl { using Callback = internal::FnOnce; void AddCallback(Callback callback); + bool TryAddCallback(const std::function& callback_factory); // Waiter API inline FutureState SetWaiter(FutureWaiter* w, int future_num); @@ -273,7 +275,14 @@ class ARROW_MUST_USE_TYPE Future { Wait(); return *GetResult(); } - Result&& result() && { + + /// \brief Returns an rvalue to the result. This method is potentially unsafe + /// + /// The future is not the unique owner of the result, copies of a future will + /// also point to the same result. You must make sure that no other copies + /// of the future exist. Attempts to add callbacks after you move the result + /// will result in undefined behavior. 
+ Result&& MoveResult() { Wait(); return std::move(*GetResult()); } @@ -326,7 +335,10 @@ class ARROW_MUST_USE_TYPE Future { /// \brief Producer API: instantiate a valid Future /// - /// The Future's state is initialized with PENDING. + /// The Future's state is initialized with PENDING. If you are creating a future with + /// this method you must ensure that future is eventually completed (with success or + /// failure). Creating a future, returning it, and never completing the future can lead + /// to memory leaks (for example, see Loop). static Future Make() { Future fut; fut.impl_ = FutureImpl::Make(); @@ -375,22 +387,33 @@ class ARROW_MUST_USE_TYPE Future { /// In this example `fut` falls out of scope but is not destroyed because it holds a /// cyclic reference to itself through the callback. template - void AddCallback(OnComplete&& on_complete) const { - struct Callback { - void operator()() && { - auto self = weak_self.get(); - std::move(on_complete)(*self.GetResult()); - } - - WeakFuture weak_self; - OnComplete on_complete; - }; - + void AddCallback(OnComplete on_complete) const { // We know impl_ will not be dangling when invoking callbacks because at least one // thread will be waiting for MarkFinished to return. Thus it's safe to keep a // weak reference to impl_ here impl_->AddCallback( - Callback{WeakFuture(*this), std::forward(on_complete)}); + Callback{WeakFuture(*this), std::move(on_complete)}); + } + + /// \brief Overload of AddCallback that will return false instead of running + /// synchronously + /// + /// This overload will guarantee the callback is never run synchronously. If the future + /// is already finished then it will simply return false. This can be useful to avoid + /// stack overflow in a situation where you have recursive Futures. 
For an example + /// see the Loop function + /// + /// Takes in a callback factory function to allow moving callbacks (the factory function + /// will only be called if the callback can successfully be added) + /// + /// Returns true if a callback was actually added and false if the callback failed + /// to add because the future was marked complete. + template + bool TryAddCallback(const CallbackFactory& callback_factory) const { + return impl_->TryAddCallback([this, &callback_factory]() { + return Callback>{WeakFuture(*this), + callback_factory()}; + }); } /// \brief Consumer API: Register a continuation to run when this future completes @@ -428,7 +451,7 @@ class ARROW_MUST_USE_TYPE Future { template > - ContinuedFuture Then(OnSuccess&& on_success, OnFailure&& on_failure) const { + ContinuedFuture Then(OnSuccess on_success, OnFailure on_failure) const { static_assert( std::is_same, ContinuedFuture>::value, @@ -471,6 +494,17 @@ class ARROW_MUST_USE_TYPE Future { } protected: + template + struct Callback { + void operator()() && { + auto self = weak_self.get(); + std::move(on_complete)(*self.GetResult()); + } + + WeakFuture weak_self; + OnComplete on_complete; + }; + Result* GetResult() const { return static_cast*>(impl_->result_.get()); } @@ -557,6 +591,38 @@ inline bool WaitForAll(const std::vector*>& futures, return waiter->Wait(seconds); } +/// \brief Create a Future which completes when all of `futures` complete. +/// +/// The future's result is a vector of the results of `futures`. +/// Note that this future will never be marked "failed"; failed results +/// will be stored in the result vector alongside successful results. 
+template +Future>> All(std::vector> futures) { + struct State { + explicit State(std::vector> f) + : futures(std::move(f)), n_remaining(futures.size()) {} + + std::vector> futures; + std::atomic n_remaining; + }; + + auto state = std::make_shared(std::move(futures)); + + auto out = Future>>::Make(); + for (const Future& future : state->futures) { + future.AddCallback([state, out](const Result&) mutable { + if (state->n_remaining.fetch_sub(1) != 1) return; + + std::vector> results(state->futures.size()); + for (size_t i = 0; i < results.size(); ++i) { + results[i] = state->futures[i].result(); + } + out.MarkFinished(std::move(results)); + }); + } + return out; +} + /// \brief Wait for one of the futures to end, or for the given timeout to expire. /// /// The indices of all completed futures are returned. Note that some futures @@ -581,4 +647,79 @@ inline std::vector WaitForAny(const std::vector*>& futures, return waiter->MoveFinishedFutures(); } +struct Continue { + template + operator util::optional() && { // NOLINT explicit + return {}; + } +}; + +template +util::optional Break(T break_value = {}) { + return util::optional{std::move(break_value)}; +} + +template +using ControlFlow = util::optional; + +/// \brief Loop through an asynchronous sequence +/// +/// \param[in] iterate A generator of Future>. On completion of +/// each yielded future the resulting ControlFlow will be examined. A Break will terminate +/// the loop, while a Continue will re-invoke `iterate`. 
\return A future which will +/// complete when a Future returned by iterate completes with a Break +template ::ValueType, + typename BreakValueType = typename Control::value_type> +Future Loop(Iterate iterate) { + auto break_fut = Future::Make(); + + struct Callback { + bool CheckForTermination(const Result& control_res) { + if (!control_res.ok()) { + break_fut.MarkFinished(control_res.status()); + return true; + } + if (control_res->has_value()) { + break_fut.MarkFinished(*std::move(*control_res)); + return true; + } + return false; + } + + void operator()(const Result& maybe_control) && { + if (CheckForTermination(maybe_control)) return; + + auto control_fut = iterate(); + while (true) { + if (control_fut.TryAddCallback([this]() { return *this; })) { + // Adding a callback succeeded; control_fut was not finished + // and we must wait to CheckForTermination. + return; + } + // Adding a callback failed; control_fut was finished and we + // can CheckForTermination immediately. This also avoids recursion and potential + // stack overflow. + if (CheckForTermination(control_fut.result())) return; + + control_fut = iterate(); + } + } + + Iterate iterate; + + // If the future returned by control_fut is never completed then we will be hanging on + // to break_fut forever even if the listener has given up listening on it. Instead we + // rely on the fact that a producer (the caller of Future<>::Make) is always + // responsible for completing the futures they create. 
+ // TODO: Could avoid this kind of situation with "future abandonment" similar to mesos + Future break_fut; + }; + + auto control_fut = iterate(); + control_fut.AddCallback(Callback{std::move(iterate), break_fut}); + + return break_fut; +} + } // namespace arrow diff --git a/cpp/src/arrow/util/future_test.cc b/cpp/src/arrow/util/future_test.cc index 203f05b5446..2a4fc6bb2fd 100644 --- a/cpp/src/arrow/util/future_test.cc +++ b/cpp/src/arrow/util/future_test.cc @@ -20,7 +20,9 @@ #include #include +#include #include +#include #include #include #include @@ -287,6 +289,109 @@ TEST(FutureSyncTest, Int) { } } +TEST(FutureSyncTest, Foo) { + { + auto fut = Future::Make(); + AssertNotFinished(fut); + fut.MarkFinished(Foo(42)); + AssertSuccessful(fut); + auto res = fut.result(); + ASSERT_OK(res); + Foo value = *res; + ASSERT_EQ(value, 42); + ASSERT_OK(fut.status()); + res = std::move(fut).result(); + ASSERT_OK(res); + value = *res; + ASSERT_EQ(value, 42); + } + { + // MarkFinished(Result) + auto fut = Future::Make(); + AssertNotFinished(fut); + fut.MarkFinished(Result(Foo(42))); + AssertSuccessful(fut); + ASSERT_OK_AND_ASSIGN(Foo value, fut.result()); + ASSERT_EQ(value, 42); + } + { + // MarkFinished(failed Result) + auto fut = Future::Make(); + AssertNotFinished(fut); + fut.MarkFinished(Result(Status::IOError("xxx"))); + AssertFailed(fut); + ASSERT_RAISES(IOError, fut.result()); + } +} + +TEST(FutureSyncTest, Empty) { + { + // MarkFinished() + auto fut = Future<>::Make(); + AssertNotFinished(fut); + fut.MarkFinished(); + AssertSuccessful(fut); + } + { + // MakeFinished() + auto fut = Future<>::MakeFinished(); + AssertSuccessful(fut); + auto res = fut.result(); + ASSERT_OK(res); + res = std::move(fut.result()); + ASSERT_OK(res); + } + { + // MarkFinished(Status) + auto fut = Future<>::Make(); + AssertNotFinished(fut); + fut.MarkFinished(Status::OK()); + AssertSuccessful(fut); + } + { + // MakeFinished(Status) + auto fut = Future<>::MakeFinished(Status::OK()); + 
AssertSuccessful(fut); + fut = Future<>::MakeFinished(Status::IOError("xxx")); + AssertFailed(fut); + } + { + // MarkFinished(Status) + auto fut = Future<>::Make(); + AssertNotFinished(fut); + fut.MarkFinished(Status::IOError("xxx")); + AssertFailed(fut); + ASSERT_RAISES(IOError, fut.status()); + } +} + +TEST(FutureSyncTest, GetStatusFuture) { + { + auto fut = Future::Make(); + Future<> status_future(fut); + + AssertNotFinished(fut); + AssertNotFinished(status_future); + + fut.MarkFinished(MoveOnlyDataType(42)); + AssertSuccessful(fut); + AssertSuccessful(status_future); + ASSERT_EQ(&fut.status(), &status_future.status()); + } + { + auto fut = Future::Make(); + Future<> status_future(fut); + + AssertNotFinished(fut); + AssertNotFinished(status_future); + + fut.MarkFinished(Status::IOError("xxx")); + AssertFailed(fut); + AssertFailed(status_future); + ASSERT_EQ(&fut.status(), &status_future.status()); + } +} + TEST(FutureRefTest, ChainRemoved) { // Creating a future chain should not prevent the futures from being deleted if the // entire chain is deleted @@ -359,7 +464,7 @@ TEST(FutureRefTest, HeadRemoved) { ASSERT_TRUE(ref.expired()); } -TEST(FutureTest, StressCallback) { +TEST(FutureStessTest, Callback) { for (unsigned int n = 0; n < 1000; n++) { auto fut = Future<>::Make(); std::atomic count_finished_immediately(0); @@ -404,6 +509,56 @@ TEST(FutureTest, StressCallback) { } } +TEST(FutureStessTest, TryAddCallback) { + for (unsigned int n = 0; n < 1; n++) { + auto fut = Future<>::Make(); + std::atomic callbacks_added(0); + std::atomic finished; + std::mutex mutex; + std::condition_variable cv; + std::thread::id callback_adder_thread_id; + + std::thread callback_adder([&] { + callback_adder_thread_id = std::this_thread::get_id(); + std::function&)> callback = + [&callback_adder_thread_id](const Result&) { + if (std::this_thread::get_id() == callback_adder_thread_id) { + FAIL() << "TryAddCallback allowed a callback to be run synchronously"; + } + }; + 
std::function&)>()> + callback_factory = [&callback]() { return callback; }; + while (true) { + auto callback_added = fut.TryAddCallback(callback_factory); + if (callback_added) { + callbacks_added++; + } else { + break; + } + } + { + std::lock_guard lg(mutex); + finished.store(true); + } + cv.notify_one(); + }); + + while (callbacks_added.load() == 0) { + // Spin until the callback_adder has started running + } + + fut.MarkFinished(); + + std::unique_lock lk(mutex); + cv.wait_for(lk, std::chrono::duration(0.5), + [&finished] { return finished.load(); }); + lk.unlock(); + + ASSERT_TRUE(finished); + callback_adder.join(); + } +} + TEST(FutureCompletionTest, Void) { { // Simple callback @@ -832,142 +987,213 @@ TEST(FutureCompletionTest, FutureVoid) { } } -TEST(FutureSyncTest, Foo) { - { - // MarkFinished(Foo) - auto fut = Future::Make(); - AssertNotFinished(fut); - fut.MarkFinished(Foo(42)); - AssertSuccessful(fut); - auto res = fut.result(); - ASSERT_OK(res); - Foo value = *res; - ASSERT_EQ(value, 42); - ASSERT_OK(fut.status()); - res = std::move(fut).result(); - ASSERT_OK(res); - value = *res; - ASSERT_EQ(value, 42); - } - { - // MarkFinished(Result) - auto fut = Future::Make(); - AssertNotFinished(fut); - fut.MarkFinished(Result(Foo(42))); - AssertSuccessful(fut); - ASSERT_OK_AND_ASSIGN(Foo value, fut.result()); - ASSERT_EQ(value, 42); - } - { - // MarkFinished(failed Result) - auto fut = Future::Make(); - AssertNotFinished(fut); - fut.MarkFinished(Result(Status::IOError("xxx"))); - AssertFailed(fut); - ASSERT_RAISES(IOError, fut.result()); - } +TEST(FutureAllTest, Simple) { + auto f1 = Future::Make(); + auto f2 = Future::Make(); + std::vector> futures = {f1, f2}; + auto combined = arrow::All(futures); + + auto after_assert = combined.Then([](std::vector> results) { + ASSERT_EQ(2, results.size()); + ASSERT_EQ(1, *results[0]); + ASSERT_EQ(2, *results[1]); + }); + + // Finish in reverse order, results should still be delivered in proper order + 
AssertNotFinished(after_assert); + f2.MarkFinished(2); + AssertNotFinished(after_assert); + f1.MarkFinished(1); + AssertSuccessful(after_assert); } -TEST(FutureSyncTest, MoveOnlyDataType) { - { - // MarkFinished(MoveOnlyDataType) - auto fut = Future::Make(); - AssertNotFinished(fut); - fut.MarkFinished(MoveOnlyDataType(42)); - AssertSuccessful(fut); - const auto& res = fut.result(); - ASSERT_TRUE(res.ok()); - ASSERT_EQ(*res, 42); - ASSERT_OK_AND_ASSIGN(MoveOnlyDataType value, std::move(fut).result()); - ASSERT_EQ(value, 42); - } +TEST(FutureAllTest, Failure) { + auto f1 = Future::Make(); + auto f2 = Future::Make(); + auto f3 = Future::Make(); + std::vector> futures = {f1, f2, f3}; + auto combined = arrow::All(futures); + + auto after_assert = combined.Then([](std::vector> results) { + ASSERT_EQ(3, results.size()); + ASSERT_EQ(1, *results[0]); + ASSERT_EQ(Status::IOError("XYZ"), results[1].status()); + ASSERT_EQ(3, *results[2]); + }); + + f1.MarkFinished(1); + f2.MarkFinished(Status::IOError("XYZ")); + f3.MarkFinished(3); + + AssertFinished(after_assert); +} + +TEST(FutureLoopTest, Sync) { + struct { + int i = 0; + Future Get() { return Future::MakeFinished(i++); } + } IntSource; + + bool do_fail = false; + std::vector ints; + auto loop_body = [&] { + return IntSource.Get().Then([&](int i) -> Result> { + if (do_fail && i == 3) { + return Status::IOError("xxx"); + } + + if (i == 5) { + int sum = 0; + for (int i : ints) sum += i; + return Break(sum); + } + + ints.push_back(i); + return Continue(); + }); + }; + { - // MarkFinished(Result) - auto fut = Future::Make(); - AssertNotFinished(fut); - fut.MarkFinished(Result(MoveOnlyDataType(43))); - AssertSuccessful(fut); - ASSERT_OK_AND_ASSIGN(MoveOnlyDataType value, std::move(fut).result()); - ASSERT_EQ(value, 43); + auto sum_fut = Loop(loop_body); + AssertSuccessful(sum_fut); + + ASSERT_OK_AND_ASSIGN(auto sum, sum_fut.result()); + ASSERT_EQ(sum, 0 + 1 + 2 + 3 + 4); } + { - // MarkFinished(failed Result) - auto fut = 
Future::Make(); - AssertNotFinished(fut); - fut.MarkFinished(Result(Status::IOError("xxx"))); - AssertFailed(fut); - ASSERT_RAISES(IOError, fut.status()); - const auto& res = fut.result(); - ASSERT_TRUE(res.status().IsIOError()); - ASSERT_RAISES(IOError, std::move(fut).result()); + do_fail = true; + IntSource.i = 0; + auto sum_fut = Loop(loop_body); + AssertFailed(sum_fut); + ASSERT_RAISES(IOError, sum_fut.result()); } } -TEST(FutureSyncTest, Empty) { - { - // MarkFinished() - auto fut = Future<>::Make(); - AssertNotFinished(fut); - fut.MarkFinished(); - AssertSuccessful(fut); +TEST(FutureLoopTest, EmptyBreakValue) { + Future<> none_fut = + Loop([&] { return Future<>::MakeFinished().Then([&](...) { return Break(); }); }); + AssertSuccessful(none_fut); +} + +TEST(FutureLoopTest, EmptyLoop) { + auto loop_body = []() -> Future> { + return Future>::MakeFinished(Break(0)); + }; + auto loop_fut = Loop(loop_body); + ASSERT_FINISHES_OK_AND_ASSIGN(auto loop_res, loop_fut); + ASSERT_EQ(loop_res, 0); +} + +// TODO - Test provided by Ben but I don't understand how it can pass legitimately. +// Any future result will be passed by reference to the callbacks (as there can be +// multiple callbacks). In the Loop construct it takes the break and forwards it +// on to the outer future. Since there is no way to move a reference this can only +// be done by copying. +// +// In theory it should be safe since Loop is guaranteed to be the last callback added +// to the control future and so the value can be safely moved at that point. However, +// I'm unable to reproduce whatever trick you had in ControlFlow to make this work. +// If we want to formalize this "last callback can steal" concept then we could add +// a "last callback" to Future which gets called with an rvalue instead of an lvalue +// reference but that seems overly complicated. +// +// Ben, can you recreate whatever trick you had in place before that allowed this to +// pass? Perhaps some kind of cast. 
Worst case, I can move back to using +// ControlFlow instead of std::optional +// +// TEST(FutureLoopTest, MoveOnlyBreakValue) { +// Future one_fut = Loop([&] { +// return Future::MakeFinished(1).Then( +// [&](int i) { return Break(MoveOnlyDataType(i)); }); +// }); +// AssertSuccessful(one_fut); +// ASSERT_OK_AND_ASSIGN(auto one, std::move(one_fut).result()); +// ASSERT_EQ(one, 1); +// } + +TEST(FutureLoopTest, StackOverflow) { + // Looping over futures is normally a rather recursive task. If the futures complete + // synchronously (because they are already finished) it could lead to a stack overflow + // if care is not taken. + int counter = 0; + auto loop_body = [&counter]() -> Future> { + while (counter < 1000000) { + counter++; + return Future>::MakeFinished(Continue()); + } + return Future>::MakeFinished(Break(-1)); + }; + auto loop_fut = Loop(loop_body); + ASSERT_TRUE(loop_fut.Wait(0.1)); +} + +TEST(FutureLoopTest, AllowsBreakFutToBeDiscarded) { + int counter = 0; + auto loop_body = [&counter]() -> Future> { + while (counter < 10) { + counter++; + return Future>::MakeFinished(Continue()); + } + return Future>::MakeFinished(Break(-1)); + }; + auto loop_fut = Loop(loop_body).Then([](...) 
{ return Status::OK(); }); + ASSERT_TRUE(loop_fut.Wait(0.1)); +} + +class MoveTrackingCallable { + public: + MoveTrackingCallable() { + // std::cout << "CONSTRUCT" << std::endl; } - { - // MakeFinished() - auto fut = Future<>::MakeFinished(); - AssertSuccessful(fut); - auto res = fut.result(); - ASSERT_OK(res); - res = std::move(fut.result()); - ASSERT_OK(res); + ~MoveTrackingCallable() { + valid_ = false; + // std::cout << "DESTRUCT" << std::endl; } - { - // MarkFinished(Status) - auto fut = Future<>::Make(); - AssertNotFinished(fut); - fut.MarkFinished(Status::OK()); - AssertSuccessful(fut); + MoveTrackingCallable(const MoveTrackingCallable& other) { + // std::cout << "COPY CONSTRUCT" << std::endl; } - { - // MakeFinished(Status) - auto fut = Future<>::MakeFinished(Status::OK()); - AssertSuccessful(fut); - fut = Future<>::MakeFinished(Status::IOError("xxx")); - AssertFailed(fut); + MoveTrackingCallable(MoveTrackingCallable&& other) { + other.valid_ = false; + // std::cout << "MOVE CONSTRUCT" << std::endl; } - { - // MarkFinished(Status) - auto fut = Future<>::Make(); - AssertNotFinished(fut); - fut.MarkFinished(Status::IOError("xxx")); - AssertFailed(fut); - ASSERT_RAISES(IOError, fut.status()); + MoveTrackingCallable& operator=(const MoveTrackingCallable& other) { + // std::cout << "COPY ASSIGN" << std::endl; + return *this; + } + MoveTrackingCallable& operator=(MoveTrackingCallable&& other) { + other.valid_ = false; + // std::cout << "MOVE ASSIGN" << std::endl; + return *this; } -} -TEST(FutureSyncTest, GetStatusFuture) { - { - auto fut = Future::Make(); - Future<> status_future(fut); + Status operator()(...) 
{ + // std::cout << "TRIGGER" << std::endl; + if (valid_) { + return Status::OK(); + } else { + return Status::Invalid("Invalid callback triggered"); + } + } - AssertNotFinished(fut); - AssertNotFinished(status_future); + private: + bool valid_ = true; +}; - fut.MarkFinished(MoveOnlyDataType(42)); - AssertSuccessful(fut); - AssertSuccessful(status_future); - ASSERT_EQ(&fut.status(), &status_future.status()); - } +TEST(FutureCompletionTest, ReuseCallback) { + auto fut = Future<>::Make(); + + Future<> continuation; { - auto fut = Future::Make(); - Future<> status_future(fut); + MoveTrackingCallable callback; + continuation = fut.Then(callback); + } - AssertNotFinished(fut); - AssertNotFinished(status_future); + fut.MarkFinished(Status::OK()); - fut.MarkFinished(Status::IOError("xxx")); - AssertFailed(fut); - AssertFailed(status_future); - ASSERT_EQ(&fut.status(), &status_future.status()); + ASSERT_TRUE(continuation.is_finished()); + if (continuation.is_finished()) { + ASSERT_OK(continuation.status()); } } @@ -1287,34 +1513,34 @@ class FutureTestBase : public ::testing::Test { }; template -class FutureTest : public FutureTestBase {}; +class FutureWaitTest : public FutureTestBase {}; -using FutureTestTypes = ::testing::Types; +using FutureWaitTestTypes = ::testing::Types; -TYPED_TEST_SUITE(FutureTest, FutureTestTypes); +TYPED_TEST_SUITE(FutureWaitTest, FutureWaitTestTypes); -TYPED_TEST(FutureTest, BasicWait) { this->TestBasicWait(); } +TYPED_TEST(FutureWaitTest, BasicWait) { this->TestBasicWait(); } -TYPED_TEST(FutureTest, TimedWait) { this->TestTimedWait(); } +TYPED_TEST(FutureWaitTest, TimedWait) { this->TestTimedWait(); } -TYPED_TEST(FutureTest, StressWait) { this->TestStressWait(); } +TYPED_TEST(FutureWaitTest, StressWait) { this->TestStressWait(); } -TYPED_TEST(FutureTest, BasicWaitForAny) { this->TestBasicWaitForAny(); } +TYPED_TEST(FutureWaitTest, BasicWaitForAny) { this->TestBasicWaitForAny(); } -TYPED_TEST(FutureTest, TimedWaitForAny) { 
this->TestTimedWaitForAny(); } +TYPED_TEST(FutureWaitTest, TimedWaitForAny) { this->TestTimedWaitForAny(); } -TYPED_TEST(FutureTest, StressWaitForAny) { this->TestStressWaitForAny(); } +TYPED_TEST(FutureWaitTest, StressWaitForAny) { this->TestStressWaitForAny(); } -TYPED_TEST(FutureTest, BasicWaitForAll) { this->TestBasicWaitForAll(); } +TYPED_TEST(FutureWaitTest, BasicWaitForAll) { this->TestBasicWaitForAll(); } -TYPED_TEST(FutureTest, TimedWaitForAll) { this->TestTimedWaitForAll(); } +TYPED_TEST(FutureWaitTest, TimedWaitForAll) { this->TestTimedWaitForAll(); } -TYPED_TEST(FutureTest, StressWaitForAll) { this->TestStressWaitForAll(); } +TYPED_TEST(FutureWaitTest, StressWaitForAll) { this->TestStressWaitForAll(); } template class FutureIteratorTest : public FutureTestBase {}; -using FutureIteratorTestTypes = ::testing::Types; +using FutureIteratorTestTypes = ::testing::Types; TYPED_TEST_SUITE(FutureIteratorTest, FutureIteratorTestTypes); diff --git a/cpp/src/arrow/util/iterator.cc b/cpp/src/arrow/util/iterator.cc deleted file mode 100644 index 0c71bbaabd0..00000000000 --- a/cpp/src/arrow/util/iterator.cc +++ /dev/null @@ -1,175 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#include "arrow/util/iterator.h" - -#include -#include -#include -#include -#include - -#include "arrow/util/logging.h" - -namespace arrow { -namespace detail { - -ReadaheadPromise::~ReadaheadPromise() {} - -class ReadaheadQueue::Impl : public std::enable_shared_from_this { - public: - explicit Impl(int64_t readahead_queue_size) : max_readahead_(readahead_queue_size) {} - - ~Impl() { EnsureShutdownOrDie(false); } - - void Start() { - // Cannot do this in constructor as shared_from_this() would throw - DCHECK(!thread_.joinable()); - auto self = shared_from_this(); - thread_ = std::thread([self]() { self->DoWork(); }); - DCHECK(thread_.joinable()); - } - - void EnsureShutdownOrDie(bool wait = true) { - std::unique_lock lock(mutex_); - if (!please_shutdown_) { - ARROW_CHECK_OK(ShutdownUnlocked(std::move(lock), wait)); - } - DCHECK(!thread_.joinable()); - } - - Status Append(std::unique_ptr promise) { - std::unique_lock lock(mutex_); - if (please_shutdown_) { - return Status::Invalid("Shutdown requested"); - } - todo_.push_back(std::move(promise)); - if (static_cast(todo_.size()) == 1) { - // Signal there's more work to do - lock.unlock(); - worker_wakeup_.notify_one(); - } - return Status::OK(); - } - - Status PopDone(std::unique_ptr* out) { - std::unique_lock lock(mutex_); - if (please_shutdown_) { - return Status::Invalid("Shutdown requested"); - } - work_done_.wait(lock, [this]() { return done_.size() > 0; }); - *out = std::move(done_.front()); - done_.pop_front(); - if (static_cast(done_.size()) < max_readahead_) { - // Signal there's more work to do - lock.unlock(); - worker_wakeup_.notify_one(); - } - return Status::OK(); - } - - Status Pump(std::function()> factory) { - std::unique_lock lock(mutex_); - if (please_shutdown_) { - return Status::Invalid("Shutdown requested"); - } - while (static_cast(done_.size() + todo_.size()) < max_readahead_) { - todo_.push_back(factory()); - } - // Signal there's more work to do - lock.unlock(); - 
worker_wakeup_.notify_one(); - return Status::OK(); - } - - Status Shutdown(bool wait = true) { - return ShutdownUnlocked(std::unique_lock(mutex_), wait); - } - - Status ShutdownUnlocked(std::unique_lock lock, bool wait = true) { - if (please_shutdown_) { - return Status::Invalid("Shutdown already requested"); - } - DCHECK(thread_.joinable()); - please_shutdown_ = true; - lock.unlock(); - worker_wakeup_.notify_one(); - if (wait) { - thread_.join(); - } else { - thread_.detach(); - } - return Status::OK(); - } - - void DoWork() { - std::unique_lock lock(mutex_); - while (!please_shutdown_) { - while (static_cast(done_.size()) < max_readahead_ && todo_.size() > 0) { - auto promise = std::move(todo_.front()); - todo_.pop_front(); - lock.unlock(); - promise->Call(); - lock.lock(); - done_.push_back(std::move(promise)); - work_done_.notify_one(); - // Exit eagerly - if (please_shutdown_) { - return; - } - } - // Wait for more work to do - worker_wakeup_.wait(lock); - } - } - - std::deque> todo_; - std::deque> done_; - int64_t max_readahead_; - bool please_shutdown_ = false; - - std::thread thread_; - std::mutex mutex_; - std::condition_variable worker_wakeup_; - std::condition_variable work_done_; -}; - -ReadaheadQueue::ReadaheadQueue(int readahead_queue_size) - : impl_(new Impl(readahead_queue_size)) { - impl_->Start(); -} - -ReadaheadQueue::~ReadaheadQueue() {} - -Status ReadaheadQueue::Append(std::unique_ptr promise) { - return impl_->Append(std::move(promise)); -} - -Status ReadaheadQueue::PopDone(std::unique_ptr* out) { - return impl_->PopDone(out); -} - -Status ReadaheadQueue::Pump(std::function()> factory) { - return impl_->Pump(std::move(factory)); -} - -Status ReadaheadQueue::Shutdown() { return impl_->Shutdown(); } - -void ReadaheadQueue::EnsureShutdownOrDie() { return impl_->EnsureShutdownOrDie(); } - -} // namespace detail -} // namespace arrow diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h index 58dda5df2a7..75ccf283aa5 100644 
--- a/cpp/src/arrow/util/iterator.h +++ b/cpp/src/arrow/util/iterator.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -186,6 +187,127 @@ class Iterator : public util::EqualityComparable> { Result (*next_)(void*) = NULLPTR; }; +template +struct TransformFlow { + using YieldValueType = T; + + TransformFlow(YieldValueType value, bool ready_for_next) + : finished_(false), + ready_for_next_(ready_for_next), + yield_value_(std::move(value)) {} + TransformFlow(bool finished, bool ready_for_next) + : finished_(finished), ready_for_next_(ready_for_next), yield_value_() {} + + bool HasValue() const { return yield_value_.has_value(); } + bool Finished() const { return finished_; } + bool ReadyForNext() const { return ready_for_next_; } + T Value() const { return *yield_value_; } + + bool finished_ = false; + bool ready_for_next_ = false; + util::optional yield_value_; +}; + +struct TransformFinish { + template + operator TransformFlow() && { // NOLINT explicit + return TransformFlow(true, true); + } +}; + +struct TransformSkip { + template + operator TransformFlow() && { // NOLINT explicit + return TransformFlow(false, true); + } +}; + +template +TransformFlow TransformYield(T value = {}, bool ready_for_next = true) { + return TransformFlow(std::move(value), ready_for_next); +} + +template +using Transformer = std::function>(T)>; + +template +class TransformIterator { + public: + explicit TransformIterator(Iterator it, Transformer transformer) + : it_(std::move(it)), + transformer_(std::move(transformer)), + last_value_(), + finished_() {} + + Result Next() { + while (!finished_) { + ARROW_ASSIGN_OR_RAISE(util::optional next, Pump()); + if (next.has_value()) { + return std::move(*next); + } + ARROW_ASSIGN_OR_RAISE(last_value_, it_.Next()); + } + return IterationTraits::End(); + } + + private: + // Calls the transform function on the current value. Can return in several ways + // * If the next value is requested (e.g. 
skip) it will return an empty optional + // * If an invalid status is encountered that will be returned + // * If finished it will return IterationTraits::End() + // * If a value is returned by the transformer that will be returned + Result> Pump() { + if (!finished_ && last_value_.has_value()) { + auto next_res = transformer_(*last_value_); + if (!next_res.ok()) { + finished_ = true; + return next_res.status(); + } + auto next = *next_res; + if (next.ReadyForNext()) { + if (*last_value_ == IterationTraits::End()) { + finished_ = true; + } + last_value_.reset(); + } + if (next.Finished()) { + finished_ = true; + } + if (next.HasValue()) { + return next.Value(); + } + } + if (finished_) { + return IterationTraits::End(); + } + return util::nullopt; + } + + Iterator it_; + Transformer transformer_; + util::optional last_value_; + bool finished_ = false; +}; + +/// \brief Transforms an iterator according to a transformer, returning a new Iterator. +/// +/// The transformer will be called on each element of the source iterator and for each +/// call it can yield a value, skip, or finish the iteration. When yielding a value the +/// transformer can choose to consume the source item (the default, ready_for_next = true) +/// or to keep it and it will be called again on the same value. +/// +/// This is essentially a more generic form of the map operation that can return 0, 1, or +/// many values for each of the source items. +/// +/// The transformer will be exposed to the end of the source sequence +/// (IterationTraits::End) in case it needs to return some penultimate item(s). +/// +/// Any invalid status returned by the transformer will be returned immediately. 
+template +Iterator MakeTransformedIterator(Iterator it, Transformer op) { + return Iterator(TransformIterator(std::move(it), std::move(op))); +} + template struct IterationTraits> { // The end condition for an Iterator of Iterators is a default constructed (null) @@ -414,117 +536,4 @@ Iterator MakeFlattenIterator(Iterator> it) { return Iterator(FlattenIterator(std::move(it))); } -namespace detail { - -// A type-erased promise object for ReadaheadQueue. -struct ARROW_EXPORT ReadaheadPromise { - virtual ~ReadaheadPromise(); - virtual void Call() = 0; -}; - -template -struct ReadaheadIteratorPromise : ReadaheadPromise { - ~ReadaheadIteratorPromise() override {} - - explicit ReadaheadIteratorPromise(Iterator* it) : it_(it) {} - - void Call() override { - assert(!called_); - out_ = it_->Next(); - called_ = true; - } - - Iterator* it_; - Result out_ = IterationTraits::End(); - bool called_ = false; -}; - -class ARROW_EXPORT ReadaheadQueue { - public: - explicit ReadaheadQueue(int readahead_queue_size); - ~ReadaheadQueue(); - - Status Append(std::unique_ptr); - Status PopDone(std::unique_ptr*); - Status Pump(std::function()> factory); - Status Shutdown(); - void EnsureShutdownOrDie(); - - protected: - class Impl; - std::shared_ptr impl_; -}; - -} // namespace detail - -/// \brief Readahead iterator that iterates on the underlying iterator in a -/// separate thread, getting up to N values in advance. -template -class ReadaheadIterator { - using PromiseType = typename detail::ReadaheadIteratorPromise; - - public: - // Public default constructor creates an empty iterator - ReadaheadIterator() : done_(true) {} - - ~ReadaheadIterator() { - if (queue_) { - // Make sure the queue doesn't call any promises after this object - // is destroyed. 
- queue_->EnsureShutdownOrDie(); - } - } - - ARROW_DEFAULT_MOVE_AND_ASSIGN(ReadaheadIterator); - ARROW_DISALLOW_COPY_AND_ASSIGN(ReadaheadIterator); - - Result Next() { - if (done_) { - return IterationTraits::End(); - } - - std::unique_ptr promise; - ARROW_RETURN_NOT_OK(queue_->PopDone(&promise)); - auto it_promise = static_cast(promise.get()); - - ARROW_RETURN_NOT_OK(queue_->Append(MakePromise())); - - ARROW_ASSIGN_OR_RAISE(auto out, it_promise->out_); - if (out == IterationTraits::End()) { - done_ = true; - } - return out; - } - - static Result> Make(Iterator it, int readahead_queue_size) { - ReadaheadIterator rh(std::move(it), readahead_queue_size); - ARROW_RETURN_NOT_OK(rh.Pump()); - return Iterator(std::move(rh)); - } - - private: - explicit ReadaheadIterator(Iterator it, int readahead_queue_size) - : it_(new Iterator(std::move(it))), - queue_(new detail::ReadaheadQueue(readahead_queue_size)) {} - - Status Pump() { - return queue_->Pump([this]() { return MakePromise(); }); - } - - std::unique_ptr MakePromise() { - return std::unique_ptr(new PromiseType{it_.get()}); - } - - // The underlying iterator is referenced by pointer in ReadaheadPromise, - // so make sure it doesn't move. - std::unique_ptr> it_; - std::unique_ptr queue_; - bool done_ = false; -}; - -template -Result> MakeReadaheadIterator(Iterator it, int readahead_queue_size) { - return ReadaheadIterator::Make(std::move(it), readahead_queue_size); -} - } // namespace arrow diff --git a/cpp/src/arrow/util/iterator_test.cc b/cpp/src/arrow/util/iterator_test.cc index 7295627b7c8..322611bb3ee 100644 --- a/cpp/src/arrow/util/iterator_test.cc +++ b/cpp/src/arrow/util/iterator_test.cc @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-#include "arrow/util/iterator.h" - #include #include #include @@ -28,6 +26,8 @@ #include #include "arrow/testing/gtest_util.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/iterator.h" namespace arrow { @@ -49,6 +49,32 @@ struct IterationTraits { static TestInt End() { return TestInt(); } }; +struct TestStr { + TestStr() : value("") {} + TestStr(const std::string& s) : value(s) {} // NOLINT runtime/explicit + TestStr(const char* s) : value(s) {} // NOLINT runtime/explicit + explicit TestStr(const TestInt& test_int) { + if (test_int == IterationTraits::End()) { + value = ""; + } else { + value = std::to_string(test_int.value); + } + } + std::string value; + + bool operator==(const TestStr& other) const { return value == other.value; } + + friend std::ostream& operator<<(std::ostream& os, const TestStr& v) { + os << "{\"" << v.value << "\"}"; + return os; + } +}; + +template <> +struct IterationTraits { + static TestStr End() { return TestStr(); } +}; + template class TracingIterator { public: @@ -129,11 +155,45 @@ template inline Iterator EmptyIt() { return MakeEmptyIterator(); } - inline Iterator VectorIt(std::vector v) { return MakeVectorIterator(std::move(v)); } +AsyncGenerator AsyncVectorIt(std::vector v) { + size_t index = 0; + return [index, v]() mutable -> Future { + if (index >= v.size()) { + return Future::MakeFinished(IterationTraits::End()); + } + return Future::MakeFinished(v[index++]); + }; +} + +constexpr auto kYieldDuration = std::chrono::microseconds(50); + +// Yields items with a small pause between each one from a background thread +std::function()> BackgroundAsyncVectorIt(std::vector v) { + auto pool = internal::GetCpuThreadPool(); + auto iterator = VectorIt(v); + auto slow_iterator = MakeTransformedIterator( + std::move(iterator), [](TestInt item) -> Result> { + std::this_thread::sleep_for(kYieldDuration); + return TransformYield(item); + }); + EXPECT_OK_AND_ASSIGN(auto background, + 
MakeBackgroundGenerator(std::move(slow_iterator), + internal::GetCpuThreadPool())); + return MakeTransferredGenerator(background, pool); +} + +std::vector RangeVector(unsigned int max) { + std::vector range(max); + for (unsigned int i = 0; i < max; i++) { + range[i] = i; + } + return range; +} + template inline Iterator VectorIt(std::vector v) { return MakeVectorIterator(std::move(v)); @@ -154,6 +214,13 @@ void AssertIteratorMatch(std::vector expected, Iterator actual) { EXPECT_EQ(expected, IteratorToVector(std::move(actual))); } +template +void AssertAsyncGeneratorMatch(std::vector expected, AsyncGenerator actual) { + auto vec_future = CollectAsyncGenerator(std::move(actual)); + EXPECT_OK_AND_ASSIGN(auto vec, vec_future.result()); + EXPECT_EQ(expected, vec); +} + template void AssertIteratorNoMatch(std::vector expected, Iterator actual) { EXPECT_NE(expected, IteratorToVector(std::move(actual))); @@ -170,6 +237,9 @@ void AssertIteratorExhausted(Iterator& it) { AssertIteratorNext(IterationTraits::End(), it); } +// -------------------------------------------------------------------- +// Synchronous iterator tests + TEST(TestEmptyIterator, Basic) { AssertIteratorMatch({}, EmptyIt()); } TEST(TestVectorIterator, Basic) { @@ -214,6 +284,118 @@ TEST(TestVectorIterator, RangeForLoop) { ASSERT_EQ(ints_it, ints.end()); } +Transformer MakeFirstN(int n) { + int remaining = n; + return [remaining](TestInt next) mutable -> Result> { + if (remaining > 0) { + remaining--; + return TransformYield(TestStr(next)); + } + return TransformFinish(); + }; +} + +template +Transformer MakeFirstNGeneric(int n) { + int remaining = n; + return [remaining](T next) mutable -> Result> { + if (remaining > 0) { + remaining--; + return TransformYield(next); + } + return TransformFinish(); + }; +} + +TEST(TestIteratorTransform, Truncating) { + auto original = VectorIt({1, 2, 3}); + auto truncated = MakeTransformedIterator(std::move(original), MakeFirstN(2)); + AssertIteratorMatch({"1", "2"}, 
std::move(truncated)); +} + +TEST(TestIteratorTransform, TestPointer) { + auto original = VectorIt>( + {std::make_shared(1), std::make_shared(2), std::make_shared(3)}); + auto truncated = MakeTransformedIterator(std::move(original), + MakeFirstNGeneric>(2)); + ASSERT_OK_AND_ASSIGN(auto result, truncated.ToVector()); + ASSERT_EQ(2, result.size()); +} + +TEST(TestIteratorTransform, TruncatingShort) { + // Tests the failsafe case where we never call Finish + auto original = VectorIt({1}); + auto truncated = + MakeTransformedIterator(std::move(original), MakeFirstN(2)); + AssertIteratorMatch({"1"}, std::move(truncated)); +} + +Transformer MakeFilter(std::function filter) { + return [filter](TestInt next) -> Result> { + if (filter(next)) { + return TransformYield(TestStr(next)); + } else { + return TransformSkip(); + } + }; +} + +TEST(TestIteratorTransform, SkipSome) { + // Exercises TransformSkip + auto original = VectorIt({1, 2, 3}); + auto filter = MakeFilter([](TestInt& t) { return t.value != 2; }); + auto filtered = MakeTransformedIterator(std::move(original), filter); + AssertIteratorMatch({"1", "3"}, std::move(filtered)); +} + +TEST(TestIteratorTransform, SkipAll) { + // Exercises TransformSkip + auto original = VectorIt({1, 2, 3}); + auto filter = MakeFilter([](TestInt& t) { return false; }); + auto filtered = MakeTransformedIterator(std::move(original), filter); + AssertIteratorMatch({}, std::move(filtered)); +} + +Transformer MakeAbortOnSecond() { + int counter = 0; + return [counter](TestInt next) mutable -> Result> { + if (counter++ == 1) { + return Status::Invalid("X"); + } + return TransformYield(TestStr(next)); + }; +} + +TEST(TestIteratorTransform, Abort) { + auto original = VectorIt({1, 2, 3}); + auto transformed = MakeTransformedIterator(std::move(original), MakeAbortOnSecond()); + ASSERT_OK(transformed.Next()); + ASSERT_RAISES(Invalid, transformed.Next()); + ASSERT_OK_AND_ASSIGN(auto third, transformed.Next()); + ASSERT_EQ(IterationTraits::End(), 
third); +} + +template +Transformer MakeRepeatN(int repeat_count) { + int current_repeat = 0; + return [repeat_count, current_repeat](T next) mutable -> Result> { + current_repeat++; + bool ready_for_next = false; + if (current_repeat == repeat_count) { + current_repeat = 0; + ready_for_next = true; + } + return TransformYield(next, ready_for_next); + }; +} + +TEST(TestIteratorTransform, Repeating) { + auto original = VectorIt({1, 2, 3}); + auto repeated = MakeTransformedIterator(std::move(original), + MakeRepeatN(2)); + AssertIteratorMatch({1, 1, 2, 2, 3, 3}, std::move(repeated)); +} + TEST(TestFunctionIterator, RangeForLoop) { int i = 0; auto fails_at_3 = MakeFunctionIterator([&]() -> Result { @@ -295,13 +477,6 @@ TEST(FlattenVectorIterator, Pyramid) { AssertIteratorMatch({1, 2, 2, 3, 3, 3}, std::move(it)); } -TEST(ReadaheadIterator, DefaultConstructor) { - ReadaheadIterator it; - TestInt v{42}; - ASSERT_OK_AND_ASSIGN(v, it.Next()); - ASSERT_EQ(v, TestInt()); -} - TEST(ReadaheadIterator, Empty) { ASSERT_OK_AND_ASSIGN(auto it, MakeReadaheadIterator(VectorIt({}), 2)); AssertIteratorMatch({}, std::move(it)); @@ -329,13 +504,16 @@ TEST(ReadaheadIterator, Trace) { ASSERT_OK_AND_ASSIGN( auto it, MakeReadaheadIterator(Iterator(std::move(tracing_it)), 2)); - tracing->WaitForValues(2); - SleepABit(); // check no further value is emitted - tracing->AssertValuesEqual({1, 2}); + SleepABit(); // Background iterator won't start pumping until first request comes in + ASSERT_EQ(tracing->values().size(), 0); + + AssertIteratorNext({1}, it); // Once we ask for one value we should get that one value + // as well as 2 read ahead - AssertIteratorNext({1}, it); tracing->WaitForValues(3); - SleepABit(); + tracing->AssertValuesEqual({1, 2, 3}); + + SleepABit(); // No further values should be fetched tracing->AssertValuesEqual({1, 2, 3}); AssertIteratorNext({2}, it); @@ -383,13 +561,247 @@ TEST(ReadaheadIterator, NextError) { ASSERT_RAISES(IOError, it.Next().status()); - 
AssertIteratorNext({1}, it); - tracing->WaitForValues(3); + AssertIteratorExhausted(it); SleepABit(); - tracing->AssertValuesEqual({1, 2, 3}); - AssertIteratorNext({2}, it); - AssertIteratorNext({3}, it); + tracing->AssertValuesEqual({}); AssertIteratorExhausted(it); } +// -------------------------------------------------------------------- +// Asynchronous iterator tests + +TEST(TestAsyncUtil, Visit) { + auto generator = AsyncVectorIt({1, 2, 3}); + unsigned int sum = 0; + auto sum_future = VisitAsyncGenerator(generator, [&sum](TestInt item) { + sum += item.value; + return Status::OK(); + }); + ASSERT_TRUE(sum_future.is_finished()); + ASSERT_EQ(6, sum); +} + +TEST(TestAsyncUtil, Collect) { + std::vector expected = {1, 2, 3}; + auto generator = AsyncVectorIt(expected); + auto collected = CollectAsyncGenerator(generator); + ASSERT_FINISHES_OK_AND_ASSIGN(auto collected_val, collected); + ASSERT_EQ(expected, collected_val); +} + +TEST(TestAsyncUtil, SynchronousFinish) { + AsyncGenerator generator = []() { + return Future::MakeFinished(IterationTraits::End()); + }; + Transformer skip_all = [](TestInt value) { return TransformSkip(); }; + auto transformed = MakeAsyncGenerator(generator, skip_all); + auto future = CollectAsyncGenerator(transformed); + ASSERT_TRUE(future.is_finished()); + ASSERT_OK_AND_ASSIGN(auto actual, future.result()); + ASSERT_EQ(std::vector(), actual); +} + +TEST(TestAsyncUtil, GeneratorIterator) { + auto generator = BackgroundAsyncVectorIt({1, 2, 3}); + ASSERT_OK_AND_ASSIGN(auto iterator, MakeGeneratorIterator(std::move(generator))); + ASSERT_OK_AND_EQ(TestInt(1), iterator.Next()); + ASSERT_OK_AND_EQ(TestInt(2), iterator.Next()); + ASSERT_OK_AND_EQ(TestInt(3), iterator.Next()); + ASSERT_OK_AND_EQ(IterationTraits::End(), iterator.Next()); + ASSERT_OK_AND_EQ(IterationTraits::End(), iterator.Next()); +} + +TEST(TestAsyncUtil, MakeTransferredGenerator) { + std::mutex mutex; + std::condition_variable cv; + std::atomic finished(false); + + 
ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(1)); + + // Needs to be a slow source to ensure we don't call Then on a completed + AsyncGenerator slow_generator = [&]() { + return thread_pool + ->Submit([&] { + std::unique_lock lock(mutex); + cv.wait_for(lock, std::chrono::duration(30), + [&] { return finished.load(); }); + return IterationTraits::End(); + }) + .ValueOrDie(); + }; + + auto transferred = + MakeTransferredGenerator(std::move(slow_generator), thread_pool.get()); + + auto current_thread_id = std::this_thread::get_id(); + auto fut = transferred().Then([&current_thread_id](const Result& result) { + ASSERT_NE(current_thread_id, std::this_thread::get_id()); + }); + + { + std::lock_guard lg(mutex); + finished.store(true); + } + cv.notify_one(); + ASSERT_FINISHES_OK(fut); +} + +// This test is too slow for valgrind +#if !(defined(ARROW_VALGRIND) || defined(ADDRESS_SANITIZER)) + +TEST(TestAsyncUtil, StackOverflow) { + int counter = 0; + AsyncGenerator generator = [&counter]() { + if (counter < 1000000) { + return Future::MakeFinished(counter++); + } else { + return Future::MakeFinished(IterationTraits::End()); + } + }; + Transformer discard = + [](TestInt next) -> Result> { return TransformSkip(); }; + auto transformed = MakeAsyncGenerator(generator, discard); + auto collected_future = CollectAsyncGenerator(transformed); + ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, collected_future); + ASSERT_EQ(0, collected.size()); +} + +#endif + +TEST(TestAsyncUtil, Background) { + std::vector expected = {1, 2, 3}; + auto background = BackgroundAsyncVectorIt(expected); + auto future = CollectAsyncGenerator(background); + ASSERT_FINISHES_OK_AND_ASSIGN(auto collected, future); + ASSERT_EQ(expected, collected); +} + +struct SlowEmptyIterator { + Result Next() { + if (called_) { + return Status::Invalid("Should not have been called twice"); + } + SleepFor(0.1); + return IterationTraits::End(); + } + + private: + bool called_ = false; +}; + +TEST(TestAsyncUtil,
BackgroundRepeatEnd) { + // Ensure that the background generator properly fulfills the asyncgenerator contract + // and can be called after it ends. + ASSERT_OK_AND_ASSIGN(auto io_pool, internal::ThreadPool::Make(1)); + + auto iterator = Iterator(SlowEmptyIterator()); + ASSERT_OK_AND_ASSIGN(auto background_gen, + MakeBackgroundGenerator(std::move(iterator), io_pool.get())); + + background_gen = + MakeTransferredGenerator(std::move(background_gen), internal::GetCpuThreadPool()); + + auto one = background_gen(); + auto two = background_gen(); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto one_fin, one); + ASSERT_EQ(IterationTraits::End(), one_fin); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto two_fin, two); + ASSERT_EQ(IterationTraits::End(), two_fin); +} + +TEST(TestAsyncUtil, CompleteBackgroundStressTest) { + auto expected = RangeVector(20); + std::vector>> futures; + for (unsigned int i = 0; i < 20; i++) { + auto background = BackgroundAsyncVectorIt(expected); + futures.push_back(CollectAsyncGenerator(background)); + } + auto combined = All(futures); + ASSERT_FINISHES_OK_AND_ASSIGN(auto completed_vectors, combined); + for (std::size_t i = 0; i < completed_vectors.size(); i++) { + ASSERT_OK_AND_ASSIGN(auto vector, completed_vectors[i]); + ASSERT_EQ(vector, expected); + } +} + +TEST(TestAsyncUtil, Readahead) { + int num_delivered = 0; + auto source = [&num_delivered]() { + if (num_delivered < 5) { + return Future::MakeFinished(num_delivered++); + } else { + return Future::MakeFinished(IterationTraits::End()); + } + }; + auto readahead = MakeReadaheadGenerator(source, 10); + // Should not pump until first item requested + ASSERT_EQ(0, num_delivered); + + auto first = readahead(); + // At this point the pumping should have happened + ASSERT_EQ(5, num_delivered); + ASSERT_FINISHES_OK_AND_ASSIGN(auto first_val, first); + ASSERT_EQ(TestInt(0), first_val); + + // Read the rest + for (int i = 0; i < 4; i++) { + auto next = readahead(); + ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, 
next); + ASSERT_EQ(TestInt(i + 1), next_val); + } + + // Next should be end + auto last = readahead(); + ASSERT_FINISHES_OK_AND_ASSIGN(auto last_val, last); + ASSERT_EQ(IterationTraits::End(), last_val); +} + +TEST(TestAsyncUtil, ReadaheadFailed) { + ASSERT_OK_AND_ASSIGN(auto thread_pool, internal::ThreadPool::Make(4)); + std::atomic counter(0); + // All tasks are a little slow. The first task fails. + // The readahead will have spawned 9 more tasks and they + // should all pass + auto source = [thread_pool, &counter]() -> Future { + auto count = counter++; + return *thread_pool->Submit([count]() -> Result { + if (count == 0) { + return Status::Invalid("X"); + } + return TestInt(count); + }); + }; + auto readahead = MakeReadaheadGenerator(source, 10); + ASSERT_FINISHES_ERR(Invalid, readahead()); + SleepABit(); + + for (int i = 0; i < 9; i++) { + ASSERT_FINISHES_OK_AND_ASSIGN(auto next_val, readahead()); + ASSERT_EQ(TestInt(i + 1), next_val); + } + ASSERT_FINISHES_OK_AND_ASSIGN(auto after, readahead()); + + // It's possible that finished was set quickly and there + // are only 10 elements + if (after == IterationTraits::End()) { + return; + } + + // It's also possible that finished was too slow and there + // ended up being 11 elements + ASSERT_EQ(TestInt(10), after); + // There can't be 12 elements because SleepABit will prevent it + ASSERT_FINISHES_OK_AND_ASSIGN(auto definitely_last, readahead()); + ASSERT_EQ(IterationTraits::End(), definitely_last); +} + +TEST(TestAsyncIteratorTransform, SkipSome) { + auto original = AsyncVectorIt({1, 2, 3}); + auto filter = MakeFilter([](TestInt& t) { return t.value != 2; }); + auto filtered = MakeAsyncGenerator(std::move(original), filter); + AssertAsyncGeneratorMatch({"1", "3"}, std::move(filtered)); +} + } // namespace arrow diff --git a/cpp/src/arrow/util/task_group.cc b/cpp/src/arrow/util/task_group.cc index 87656024648..a7b55921d32 100644 --- a/cpp/src/arrow/util/task_group.cc +++ b/cpp/src/arrow/util/task_group.cc @@ 
-54,6 +54,8 @@ class SerialTaskGroup : public TaskGroup { return status_; } + Future<> FinishAsync() override { return Future<>::MakeFinished(Finish()); } + int parallelism() override { return 1; } Status status_; @@ -114,6 +116,18 @@ class ThreadedTaskGroup : public TaskGroup { return status_; } + Future<> FinishAsync() override { + std::lock_guard lock(mutex_); + if (!completion_future_.has_value()) { + if (nremaining_.load() == 0) { + completion_future_ = Future<>::MakeFinished(status_); + } else { + completion_future_ = Future<>::Make(); + } + } + return *completion_future_; + } + int parallelism() override { return executor_->GetCapacity(); } protected: @@ -135,6 +149,21 @@ class ThreadedTaskGroup : public TaskGroup { // before cv.notify_one() has returned std::unique_lock lock(mutex_); cv_.notify_one(); + if (completion_future_.has_value()) { + // MarkFinished could be slow. We don't want to call it while we are holding + // the lock. + auto& future = *completion_future_; + const auto finished = completion_future_->is_finished(); + const auto& status = status_; + // This will be redundant if the user calls Finish and not FinishAsync + if (!finished && !finished_) { + finished_ = true; + lock.unlock(); + future.MarkFinished(status); + } else { + lock.unlock(); + } + } } } @@ -148,6 +177,7 @@ class ThreadedTaskGroup : public TaskGroup { std::condition_variable cv_; Status status_; bool finished_ = false; + util::optional> completion_future_; }; std::shared_ptr TaskGroup::MakeSerial() { diff --git a/cpp/src/arrow/util/task_group.h b/cpp/src/arrow/util/task_group.h index db3265df1c3..a6df43f1131 100644 --- a/cpp/src/arrow/util/task_group.h +++ b/cpp/src/arrow/util/task_group.h @@ -63,6 +63,20 @@ class ARROW_EXPORT TaskGroup : public std::enable_shared_from_this { /// task (or subgroup). virtual Status Finish() = 0; + /// Returns a future that will complete the first time all tasks are finished. 
+ /// This should be called only after all top level tasks + /// have been added to the task group. + /// + /// If you are using a TaskGroup asynchronously there are a few considerations to keep + /// in mind. The tasks should not block on I/O, etc (defeats the purpose of using + /// futures) and should not be doing any nested locking or you run the risk of the tasks + /// getting stuck in the thread pool waiting for tasks which cannot get scheduled. + /// + /// Primarily this call is intended to help migrate existing work written with TaskGroup + /// in mind to using futures without having to do a complete conversion on the first + /// pass. + virtual Future<> FinishAsync() = 0; + /// The current aggregate error Status. Non-blocking, useful for stopping early. virtual Status current_status() = 0; diff --git a/cpp/src/arrow/util/task_group_test.cc b/cpp/src/arrow/util/task_group_test.cc index 1e47a341fd8..38f4b211820 100644 --- a/cpp/src/arrow/util/task_group_test.cc +++ b/cpp/src/arrow/util/task_group_test.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -243,6 +244,68 @@ void TestNoCopyTask(std::shared_ptr task_group) { ASSERT_EQ(0, *counter); } +void TestFinishNotSticky(std::function()> factory) { + // If a task is added that runs very quickly it might decrement the task counter back + // down to 0 and mark the completion future as complete before all tasks are added. + // The "finished future" of the task group could get stuck to complete. + // + // Instead the task group should not allow the finished future to be marked complete + // until after FinishAsync has been called. 
+ const int NTASKS = 100; + for (int i = 0; i < NTASKS; ++i) { + auto task_group = factory(); + // Add a task and let it complete + task_group->Append([] { return Status::OK(); }); + // Wait a little bit, if the task group was going to lock the finish hopefully it + // would do so here while we wait + SleepFor(1e-2); + + // Add a new task that will still be running + std::atomic ready(false); + std::mutex m; + std::condition_variable cv; + task_group->Append([&m, &cv, &ready] { + std::unique_lock lk(m); + cv.wait(lk, [&ready] { return ready.load(); }); + return Status::OK(); + }); + + // Ensure task group not finished already + auto finished = task_group->FinishAsync(); + ASSERT_FALSE(finished.is_finished()); + + std::unique_lock lk(m); + ready = true; + lk.unlock(); + cv.notify_one(); + + ASSERT_FINISHES_OK(finished); + } +} + +void TestFinishNeverStarted(std::shared_ptr task_group) { + // If we call FinishAsync we are done adding tasks so if we never added any it should be + // completed + auto finished = task_group->FinishAsync(); + ASSERT_TRUE(finished.Wait(1)); +} + +void TestFinishAlreadyCompleted(std::function()> factory) { + // If we call FinishAsync we are done adding tasks so even if no tasks are running we + // should still be completed + const int NTASKS = 100; + for (int i = 0; i < NTASKS; ++i) { + auto task_group = factory(); + // Add a task and let it complete + task_group->Append([] { return Status::OK(); }); + // Wait a little bit, hopefully enough time for the task to finish on one of these + // iterations + SleepFor(1e-2); + auto finished = task_group->FinishAsync(); + ASSERT_FINISHES_OK(finished); + } +} + TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); } TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); } @@ -251,6 +314,14 @@ TEST(SerialTaskGroup, TasksSpawnTasks) { TestTasksSpawnTasks(TaskGroup::MakeSeri TEST(SerialTaskGroup, NoCopyTask) { TestNoCopyTask(TaskGroup::MakeSerial()); } 
+TEST(SerialTaskGroup, FinishNeverStarted) { + TestFinishNeverStarted(TaskGroup::MakeSerial()); +} + +TEST(SerialTaskGroup, FinishAlreadyCompleted) { + TestFinishAlreadyCompleted([] { return TaskGroup::MakeSerial(); }); +} + TEST(ThreadedTaskGroup, Success) { auto task_group = TaskGroup::MakeThreaded(GetCpuThreadPool()); TestTaskGroupSuccess(task_group); @@ -291,5 +362,25 @@ TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) { [&] { return TaskGroup::MakeThreaded(thread_pool.get()); }); } +TEST(ThreadedTaskGroup, FinishNotSticky) { + std::shared_ptr thread_pool; + ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16)); + + TestFinishNotSticky([&] { return TaskGroup::MakeThreaded(thread_pool.get()); }); +} + +TEST(ThreadedTaskGroup, FinishNeverStarted) { + std::shared_ptr thread_pool; + ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(4)); + TestFinishNeverStarted(TaskGroup::MakeThreaded(thread_pool.get())); +} + +TEST(ThreadedTaskGroup, FinishAlreadyCompleted) { + std::shared_ptr thread_pool; + ASSERT_OK_AND_ASSIGN(thread_pool, ThreadPool::Make(16)); + + TestFinishAlreadyCompleted([&] { return TaskGroup::MakeThreaded(thread_pool.get()); }); +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index 03b925d7bb1..5db3a9a4722 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -86,6 +86,28 @@ class ARROW_EXPORT Executor { return SpawnReal(hints, std::forward(func)); } + // Transfers a future to this executor. Any continuations added to the + // returned future will run in this executor. Otherwise they would run + // on the same thread that called MarkFinished. + // + // This is necessary when (for example) an I/O task is completing a future. + // The continuations of that future should run on the CPU thread pool keeping + // CPU heavy work off the I/O thread pool. So the I/O task should transfer + // the future to the CPU executor before returning. 
+ template + Future Transfer(Future future) { + auto transferred = Future::Make(); + future.AddCallback([this, transferred](const Result& result) mutable { + auto spawn_status = Spawn([transferred, result]() mutable { + transferred.MarkFinished(std::move(result)); + }); + if (!spawn_status.ok()) { + transferred.MarkFinished(spawn_status); + } + }); + return transferred; + } + // Submit a callable and arguments for execution. Return a future that // will return the callable's result value once. // The callable's arguments are copied before execution. diff --git a/docs/source/cpp/csv.rst b/docs/source/cpp/csv.rst index 9f17d5692e6..44dc1498f18 100644 --- a/docs/source/cpp/csv.rst +++ b/docs/source/cpp/csv.rst @@ -42,6 +42,7 @@ A CSV file is read from a :class:`~arrow::io::InputStream`. { // ... arrow::MemoryPool* pool = default_memory_pool(); + arrow::io::AsyncContext async_context; std::shared_ptr input = ...; auto read_options = arrow::csv::ReadOptions::Defaults(); @@ -51,6 +52,7 @@ A CSV file is read from a :class:`~arrow::io::InputStream`. 
// Instantiate TableReader from input stream and options auto maybe_reader = arrow::csv::TableReader::Make(pool, + async_context, input, read_options, parse_options, diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index 34c6693c51e..4068a0b9141 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -700,6 +700,7 @@ def read_csv(input_file, read_options=None, parse_options=None, CCSVConvertOptions c_convert_options shared_ptr[CCSVReader] reader shared_ptr[CTable] table + CAsyncContext c_async_ctx = CAsyncContext() _get_reader(input_file, read_options, &stream) _get_read_options(read_options, &c_read_options) @@ -707,7 +708,7 @@ def read_csv(input_file, read_options=None, parse_options=None, _get_convert_options(convert_options, &c_convert_options) reader = GetResultValue(CCSVReader.Make( - maybe_unbox_memory_pool(memory_pool), stream, + maybe_unbox_memory_pool(memory_pool), c_async_ctx, stream, c_read_options, c_parse_options, c_convert_options)) with nogil: diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 41159bd142b..6c1c7f671c7 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1140,6 +1140,9 @@ cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil: ObjectType_FILE" arrow::io::ObjectType::FILE" ObjectType_DIRECTORY" arrow::io::ObjectType::DIRECTORY" + cdef cppclass CAsyncContext" arrow::io::AsyncContext": + CAsyncContext() + cdef cppclass FileStatistics: int64_t size ObjectType kind @@ -1618,7 +1621,7 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil: cdef cppclass CCSVReader" arrow::csv::TableReader": @staticmethod CResult[shared_ptr[CCSVReader]] Make( - CMemoryPool*, shared_ptr[CInputStream], + CMemoryPool*, CAsyncContext, shared_ptr[CInputStream], CCSVReadOptions, CCSVParseOptions, CCSVConvertOptions) CResult[shared_ptr[CTable]] Read() diff --git a/r/src/csv.cpp b/r/src/csv.cpp index 54d3abc3821..69b834a6be0 100644 
--- a/r/src/csv.cpp +++ b/r/src/csv.cpp @@ -141,8 +141,9 @@ std::shared_ptr csv___TableReader__Make( const std::shared_ptr& read_options, const std::shared_ptr& parse_options, const std::shared_ptr& convert_options) { - return ValueOrStop(arrow::csv::TableReader::Make(gc_memory_pool(), input, *read_options, - *parse_options, *convert_options)); + return ValueOrStop( + arrow::csv::TableReader::Make(gc_memory_pool(), arrow::io::AsyncContext(), input, + *read_options, *parse_options, *convert_options)); } // [[arrow::export]]