diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc
index 78f3d753711..8a469e3fe12 100644
--- a/cpp/src/arrow/compute/exec.cc
+++ b/cpp/src/arrow/compute/exec.cc
@@ -987,8 +987,9 @@ std::unique_ptr<KernelExecutor> KernelExecutor::MakeScalarAggregate() {
 
 }  // namespace detail
 
-ExecContext::ExecContext(MemoryPool* pool, FunctionRegistry* func_registry)
-    : pool_(pool) {
+ExecContext::ExecContext(MemoryPool* pool, ::arrow::internal::Executor* executor,
+                         FunctionRegistry* func_registry)
+    : pool_(pool), executor_(executor) {
   this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry;
 }
 
diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h
index e7015814d2a..77d04b86ceb 100644
--- a/cpp/src/arrow/compute/exec.h
+++ b/cpp/src/arrow/compute/exec.h
@@ -34,6 +34,7 @@
 #include "arrow/result.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/macros.h"
+#include "arrow/util/type_fwd.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -60,6 +61,7 @@ class ARROW_EXPORT ExecContext {
  public:
   // If no function registry passed, the default is used.
   explicit ExecContext(MemoryPool* pool = default_memory_pool(),
+                       ::arrow::internal::Executor* executor = NULLPTR,
                        FunctionRegistry* func_registry = NULLPTR);
 
   /// \brief The MemoryPool used for allocations, default is
@@ -68,6 +70,9 @@ class ARROW_EXPORT ExecContext {
 
   ::arrow::internal::CpuInfo* cpu_info() const;
 
+  /// \brief An Executor which may be used to parallelize execution.
+  ::arrow::internal::Executor* executor() const { return executor_; }
+
   /// \brief The FunctionRegistry for looking up functions by name and
   /// selecting kernels for execution. Defaults to the library-global function
   /// registry provided by GetFunctionRegistry.
@@ -114,6 +119,7 @@ class ARROW_EXPORT ExecContext {
 
  private:
   MemoryPool* pool_;
+  ::arrow::internal::Executor* executor_;
  FunctionRegistry* func_registry_;
   int64_t exec_chunksize_ = std::numeric_limits<int64_t>::max();
   bool preallocate_contiguous_ = true;
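The executor argument is inserted between the pool and the registry, so callers that passed a FunctionRegistry positionally must be updated (see the exec_test.cc change below). A minimal usage sketch, assuming the process-wide CPU pool accessor from arrow/util/thread_pool.h:

    #include "arrow/compute/exec.h"
    #include "arrow/util/thread_pool.h"

    // Build an ExecContext whose plan nodes may schedule work on the shared
    // CPU thread pool; passing nullptr instead keeps execution on the
    // calling thread.
    arrow::compute::ExecContext MakeThreadedContext() {
      return arrow::compute::ExecContext(arrow::default_memory_pool(),
                                         arrow::internal::GetCpuThreadPool());
    }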
diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index 2dcbfb24724..d0d50af1ac7 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -39,11 +39,13 @@ namespace compute {
 namespace {
 
 struct ExecPlanImpl : public ExecPlan {
-  ExecPlanImpl() = default;
+  explicit ExecPlanImpl(ExecContext* exec_context) : ExecPlan(exec_context) {}
 
   ~ExecPlanImpl() override {
-    if (started_ && !stopped_) {
+    if (started_ && !finished_.is_finished()) {
+      ARROW_LOG(WARNING) << "Plan was destroyed before finishing";
       StopProducing();
+      finished().Wait();
     }
   }
 
@@ -77,25 +79,40 @@ struct ExecPlanImpl : public ExecPlan {
     // producers precede consumers
     sorted_nodes_ = TopoSort();
 
-    for (size_t i = 0, rev_i = sorted_nodes_.size() - 1; i < sorted_nodes_.size();
-         ++i, --rev_i) {
-      auto st = sorted_nodes_[rev_i]->StartProducing();
-      if (st.ok()) continue;
-
-      // Stop nodes that successfully started, in reverse order
-      for (; rev_i < sorted_nodes_.size(); ++rev_i) {
-        sorted_nodes_[rev_i]->StopProducing();
+    std::vector<Future<>> futures;
+
+    Status st = Status::OK();
+
+    using rev_it = std::reverse_iterator<NodeVector::iterator>;
+    for (rev_it it(sorted_nodes_.end()), end(sorted_nodes_.begin()); it != end; ++it) {
+      auto node = *it;
+
+      st = node->StartProducing();
+      if (!st.ok()) {
+        // Stop nodes that successfully started, in reverse order
+        stopped_ = true;
+        StopProducingImpl(it.base(), sorted_nodes_.end());
+        break;
       }
-      return st;
+
+      futures.push_back(node->finished());
     }
-    return Status::OK();
+
+    finished_ = AllComplete(std::move(futures));
+    return st;
   }
 
   void StopProducing() {
     DCHECK(started_) << "stopped an ExecPlan which never started";
     stopped_ = true;
-    for (const auto& node : sorted_nodes_) {
+    StopProducingImpl(sorted_nodes_.begin(), sorted_nodes_.end());
+  }
+
+  template <typename It>
+  void StopProducingImpl(It begin, It end) {
+    for (auto it = begin; it != end; ++it) {
+      auto node = *it;
       node->StopProducing();
     }
   }
@@ -133,10 +150,11 @@ struct ExecPlanImpl : public ExecPlan {
     return std::move(Impl{nodes_}.sorted);
   }
 
+  Future<> finished_ = Future<>::MakeFinished();
   bool started_ = false, stopped_ = false;
   std::vector<std::unique_ptr<ExecNode>> nodes_;
-  NodeVector sorted_nodes_;
   NodeVector sources_, sinks_;
+  NodeVector sorted_nodes_;
 };
 
 ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast<ExecPlanImpl*>(ptr); }
@@ -155,8 +173,8 @@ util::optional<int> GetNodeIndex(const std::vector<ExecNode*>& nodes,
 
 }  // namespace
 
-Result<std::shared_ptr<ExecPlan>> ExecPlan::Make() {
-  return std::make_shared<ExecPlanImpl>();
+Result<std::shared_ptr<ExecPlan>> ExecPlan::Make(ExecContext* ctx) {
+  return std::shared_ptr<ExecPlanImpl>(new ExecPlanImpl{ctx});
 }
 
 ExecNode* ExecPlan::AddNode(std::unique_ptr<ExecNode> node) {
@@ -175,6 +193,8 @@ Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
 
 void ExecPlan::StopProducing() { ToDerived(this)->StopProducing(); }
 
+Future<> ExecPlan::finished() { return ToDerived(this)->finished_; }
+
 ExecNode::ExecNode(ExecPlan* plan, std::string label, NodeVector inputs,
                    std::vector<std::string> input_labels,
                    std::shared_ptr<Schema> output_schema, int num_outputs)
@@ -220,58 +240,61 @@ struct SourceNode : ExecNode {
 
   const char* kind_name() override { return "SourceNode"; }
 
-  static void NoInputs() { DCHECK(false) << "no inputs; this should never be called"; }
-  void InputReceived(ExecNode*, int, ExecBatch) override { NoInputs(); }
-  void ErrorReceived(ExecNode*, Status) override { NoInputs(); }
-  void InputFinished(ExecNode*, int) override { NoInputs(); }
+  [[noreturn]] static void NoInputs() {
+    DCHECK(false) << "no inputs; this should never be called";
+    std::abort();
+  }
+  [[noreturn]] void InputReceived(ExecNode*, int, ExecBatch) override { NoInputs(); }
+  [[noreturn]] void ErrorReceived(ExecNode*, Status) override { NoInputs(); }
+  [[noreturn]] void InputFinished(ExecNode*, int) override { NoInputs(); }
 
   Status StartProducing() override {
-    if (finished_) {
-      return Status::Invalid("Restarted SourceNode '", label(), "'");
+    DCHECK(!stop_requested_) << "Restarted SourceNode";
+
+    CallbackOptions options;
+    if (auto executor = plan()->exec_context()->executor()) {
+      // These options will transfer execution to the desired Executor if necessary.
+      // This can happen for in-memory scans where batches didn't require
+      // any CPU work to decode. Otherwise, parsing etc should already have
+      // placed us on the desired Executor and no queues will be pushed to.
+      options.executor = executor;
+      options.should_schedule = ShouldSchedule::IfDifferentExecutor;
     }
 
-    finished_fut_ =
-        Loop([this] {
-          std::unique_lock<std::mutex> lock(mutex_);
-          int seq = next_batch_index_++;
-          if (finished_) {
-            return Future<ControlFlow<int>>::MakeFinished(Break(seq));
-          }
-          lock.unlock();
-
-          return generator_().Then(
-              [=](const util::optional<ExecBatch>& batch) -> ControlFlow<int> {
-                std::unique_lock<std::mutex> lock(mutex_);
-                if (!batch || finished_) {
-                  finished_ = true;
-                  return Break(seq);
-                }
-                lock.unlock();
-
-                // TODO check if we are on the desired Executor and transfer if not.
-                // This can happen for in-memory scans where batches didn't require
-                // any CPU work to decode. Otherwise, parsing etc should have already
-                // been placed us on the thread pool
-                outputs_[0]->InputReceived(this, seq, *batch);
-                return Continue();
-              },
-              [=](const Status& error) -> ControlFlow<int> {
-                std::unique_lock<std::mutex> lock(mutex_);
-                if (!finished_) {
-                  finished_ = true;
-                  lock.unlock();
-                  // unless we were already finished, push the error to our output
-                  // XXX is this correct? Is it reasonable for a consumer to
-                  // ignore errors from a finished producer?
-                  outputs_[0]->ErrorReceived(this, error);
-                }
-                return Break(seq);
-              });
-        }).Then([&](int seq) {
-          /// XXX this is probably redundant: do we always call InputFinished after
-          /// ErrorReceived or will ErrorRecieved be sufficient?
-          outputs_[0]->InputFinished(this, seq);
-        });
+    finished_ = Loop([this, options] {
+                  std::unique_lock<std::mutex> lock(mutex_);
+                  int seq = batch_count_++;
+                  if (stop_requested_) {
+                    return Future<ControlFlow<int>>::MakeFinished(Break(seq));
+                  }
+                  lock.unlock();
+
+                  return generator_().Then(
+                      [=](const util::optional<ExecBatch>& batch) -> ControlFlow<int> {
+                        std::unique_lock<std::mutex> lock(mutex_);
+                        if (IsIterationEnd(batch) || stop_requested_) {
+                          stop_requested_ = true;
+                          return Break(seq);
+                        }
+                        lock.unlock();
+
+                        outputs_[0]->InputReceived(this, seq, *batch);
+                        return Continue();
+                      },
+                      [=](const Status& error) -> ControlFlow<int> {
+                        // NB: ErrorReceived is independent of InputFinished, but
+                        // ErrorReceived will usually prompt StopProducing which will
+                        // prompt InputFinished. ErrorReceived may still be called from a
+                        // node which was requested to stop (indeed, the request to stop
+                        // may prompt an error).
+                        std::unique_lock<std::mutex> lock(mutex_);
+                        stop_requested_ = true;
+                        lock.unlock();
+                        outputs_[0]->ErrorReceived(this, error);
+                        return Break(seq);
+                      },
+                      options);
+                }).Then([&](int seq) { outputs_[0]->InputFinished(this, seq); });
 
     return Status::OK();
   }
 
@@ -282,20 +305,21 @@ struct SourceNode : ExecNode {
 
   void StopProducing(ExecNode* output) override {
     DCHECK_EQ(output, outputs_[0]);
-    {
-      std::unique_lock<std::mutex> lock(mutex_);
-      finished_ = true;
-    }
-    finished_fut_.Wait();
+    StopProducing();
   }
 
-  void StopProducing() override { StopProducing(outputs_[0]); }
+  void StopProducing() override {
+    std::unique_lock<std::mutex> lock(mutex_);
+    stop_requested_ = true;
+  }
+
+  Future<> finished() override { return finished_; }
 
  private:
   std::mutex mutex_;
-  bool finished_{false};
-  int next_batch_index_{0};
-  Future<> finished_fut_ = Future<>::MakeFinished();
+  bool stop_requested_{false};
+  int batch_count_{0};
+  Future<> finished_ = Future<>::MakeFinished();
   AsyncGenerator<util::optional<ExecBatch>> generator_;
 };
 
@@ -319,8 +343,8 @@ struct FilterNode : ExecNode {
     ARROW_ASSIGN_OR_RAISE(Expression simplified_filter,
                           SimplifyWithGuarantee(filter_, target.guarantee));
 
-    // XXX get a non-default exec context
-    ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target));
+    ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target,
+                                                              plan()->exec_context()));
 
     if (mask.is_scalar()) {
       const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
@@ -331,6 +355,10 @@ struct FilterNode : ExecNode {
       return target.Slice(0, 0);
     }
 
+    // if the values are all scalar then the mask must also be
+    DCHECK(!std::all_of(target.values.begin(), target.values.end(),
+                        [](const Datum& value) { return value.is_scalar(); }));
+
     auto values = target.values;
     for (auto& value : values) {
       if (value.is_scalar()) continue;
@@ -345,7 +373,6 @@ struct FilterNode : ExecNode {
     auto maybe_filtered = DoFilter(std::move(batch));
     if (!maybe_filtered.ok()) {
       outputs_[0]->ErrorReceived(this, maybe_filtered.status());
-      inputs_[0]->StopProducing(this);
       return;
     }
 
@@ -356,7 +383,6 @@ struct FilterNode : ExecNode {
   void ErrorReceived(ExecNode* input, Status error) override {
     DCHECK_EQ(input, inputs_[0]);
     outputs_[0]->ErrorReceived(this, std::move(error));
-    inputs_[0]->StopProducing(this);
   }
 
   void InputFinished(ExecNode* input, int seq) override {
@@ -372,10 +398,12 @@ struct FilterNode : ExecNode {
 
   void StopProducing(ExecNode* output) override {
     DCHECK_EQ(output, outputs_[0]);
-    inputs_[0]->StopProducing(this);
+    StopProducing();
   }
 
-  void StopProducing() override { StopProducing(outputs_[0]); }
+  void StopProducing() override { inputs_[0]->StopProducing(this); }
+
+  Future<> finished() override { return inputs_[0]->finished(); }
 
  private:
   Expression filter_;
@@ -407,15 +435,15 @@ struct ProjectNode : ExecNode {
 
   const char* kind_name() override { return "ProjectNode"; }
 
   Result<ExecBatch> DoProject(const ExecBatch& target) {
-    // XXX get a non-default exec context
     std::vector<Datum> values{exprs_.size()};
     for (size_t i = 0; i < exprs_.size(); ++i) {
       ARROW_ASSIGN_OR_RAISE(Expression simplified_expr,
                             SimplifyWithGuarantee(exprs_[i], target.guarantee));
 
-      ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target));
+      ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target,
+                                                               plan()->exec_context()));
     }
-    return ExecBatch::Make(std::move(values));
+    return ExecBatch{std::move(values), target.length};
   }
 
   void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
@@ -424,7 +452,6 @@ struct ProjectNode : ExecNode {
     auto maybe_projected = DoProject(std::move(batch));
     if (!maybe_projected.ok()) {
       outputs_[0]->ErrorReceived(this, maybe_projected.status());
-      inputs_[0]->StopProducing(this);
       return;
     }
 
@@ -435,7 +462,6 @@ struct ProjectNode : ExecNode {
   void ErrorReceived(ExecNode* input, Status error) override {
     DCHECK_EQ(input, inputs_[0]);
     outputs_[0]->ErrorReceived(this, std::move(error));
-    inputs_[0]->StopProducing(this);
   }
 
   void InputFinished(ExecNode* input, int seq) override {
@@ -451,10 +477,12 @@ struct ProjectNode : ExecNode {
 
   void StopProducing(ExecNode* output) override {
     DCHECK_EQ(output, outputs_[0]);
-    inputs_[0]->StopProducing(this);
+    StopProducing();
   }
 
-  void StopProducing() override { StopProducing(outputs_[0]); }
+  void StopProducing() override { inputs_[0]->StopProducing(this); }
+
+  Future<> finished() override { return inputs_[0]->finished(); }
 
  private:
   std::vector<Expression> exprs_;
@@ -494,28 +522,38 @@ struct SinkNode : ExecNode {
 
   const char* kind_name() override { return "SinkNode"; }
 
-  Status StartProducing() override { return Status::OK(); }
+  Status StartProducing() override {
+    finished_ = Future<>::Make();
+    return Status::OK();
+  }
 
   // sink nodes have no outputs from which to feel backpressure
-  static void NoOutputs() { DCHECK(false) << "no outputs; this should never be called"; }
-  void ResumeProducing(ExecNode* output) override { NoOutputs(); }
-  void PauseProducing(ExecNode* output) override { NoOutputs(); }
-  void StopProducing(ExecNode* output) override { NoOutputs(); }
+  [[noreturn]] static void NoOutputs() {
+    DCHECK(false) << "no outputs; this should never be called";
+    std::abort();
+  }
+  [[noreturn]] void ResumeProducing(ExecNode* output) override { NoOutputs(); }
+  [[noreturn]] void PauseProducing(ExecNode* output) override { NoOutputs(); }
+  [[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); }
 
   void StopProducing() override {
-    std::unique_lock<std::mutex> lock(mutex_);
-    InputFinishedUnlocked();
+    Finish();
+    inputs_[0]->StopProducing(this);
   }
 
+  Future<> finished() override { return finished_; }
+
   void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) override {
     DCHECK_EQ(input, inputs_[0]);
 
     std::unique_lock<std::mutex> lock(mutex_);
-    if (stopped_) return;
+    if (finished_.is_finished()) return;
 
     ++num_received_;
     if (num_received_ == emit_stop_) {
-      InputFinishedUnlocked();
+      lock.unlock();
+      Finish();
+      lock.lock();
     }
 
     if (emit_stop_ != -1) {
@@ -529,23 +567,21 @@
   void ErrorReceived(ExecNode* input, Status error) override {
     DCHECK_EQ(input, inputs_[0]);
     producer_.Push(std::move(error));
-    std::unique_lock<std::mutex> lock(mutex_);
-    InputFinishedUnlocked();
+    Finish();
+    inputs_[0]->StopProducing(this);
   }
 
   void InputFinished(ExecNode* input, int seq_stop) override {
     std::unique_lock<std::mutex> lock(mutex_);
     emit_stop_ = seq_stop;
-    if (emit_stop_ == num_received_) {
-      InputFinishedUnlocked();
-    }
+    lock.unlock();
+    Finish();
   }
 
  private:
-  void InputFinishedUnlocked() {
-    if (!stopped_) {
-      stopped_ = true;
-      producer_.Close();
+  void Finish() {
+    if (producer_.Close()) {
+      finished_.MarkFinished();
     }
   }
 
@@ -553,7 +589,7 @@
   int num_received_ = 0;
   int emit_stop_ = -1;
 
-  bool stopped_ = false;
+  Future<> finished_ = Future<>::MakeFinished();
 
   PushGenerator<util::optional<ExecBatch>>::Producer producer_;
 };
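Taken together, StartProducing, StopProducing, and finished() define the plan lifecycle: sinks are started first, sources are stopped first, and finished() resolves once every node has drained. A sketch of a driver that honors that contract (hypothetical helper; consumption logic elided):

    arrow::Status RunToCompletion(arrow::compute::ExecPlan* plan) {
      RETURN_NOT_OK(plan->Validate());
      RETURN_NOT_OK(plan->StartProducing());  // reverse topological order
      // ... drain the sink generator(s) here ...
      plan->StopProducing();                  // topological order
      return plan->finished().status();       // blocks until all nodes finish
    }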
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 21a757af5a1..6c29ddfa7a6 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -22,36 +22,32 @@
 #include <memory>
 #include <vector>
 
+#include "arrow/compute/exec.h"
 #include "arrow/compute/type_fwd.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/optional.h"
 #include "arrow/util/visibility.h"
 
-// NOTES:
-// - ExecBatches only have arrays or scalars
-// - data streams may be ordered, so add input number?
-// - node to combine input needs to reorder
-
 namespace arrow {
 namespace compute {
 
-class ExecNode;
-
 class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
  public:
   using NodeVector = std::vector<ExecNode*>;
 
   virtual ~ExecPlan() = default;
 
+  ExecContext* exec_context() const { return exec_context_; }
+
   /// Make an empty exec plan
-  static Result<std::shared_ptr<ExecPlan>> Make();
+  static Result<std::shared_ptr<ExecPlan>> Make(ExecContext* = default_exec_context());
 
   ExecNode* AddNode(std::unique_ptr<ExecNode> node);
 
   template <typename Node, typename... Args>
   Node* EmplaceNode(Args&&... args) {
-    auto node = std::unique_ptr<Node>(new Node{std::forward<Args>(args)...});
+    std::unique_ptr<Node> node{new Node{std::forward<Args>(args)...}};
     auto out = node.get();
     AddNode(std::move(node));
     return out;
@@ -65,16 +61,24 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
 
   Status Validate();
 
-  /// Start producing on all nodes
+  /// \brief Start producing on all nodes
   ///
   /// Nodes are started in reverse topological order, such that any node
   /// is started before all of its inputs.
   Status StartProducing();
 
+  /// \brief Stop producing on all nodes
+  ///
+  /// Nodes are stopped in topological order, such that any node
+  /// is stopped before all of its outputs.
   void StopProducing();
 
+  /// \brief A future which will be marked finished when all nodes have stopped producing.
+  Future<> finished();
+
  protected:
-  ExecPlan() = default;
+  ExecContext* exec_context_;
+  explicit ExecPlan(ExecContext* exec_context) : exec_context_(exec_context) {}
 };
 
 class ARROW_EXPORT ExecNode {
@@ -203,14 +207,15 @@ class ARROW_EXPORT ExecNode {
   /// \brief Stop producing definitively to a single output
   ///
   /// This call is a hint that an output node has completed and is not willing
-  /// to not receive any further data.
+  /// to receive any further data.
   virtual void StopProducing(ExecNode* output) = 0;
 
-  /// \brief Stop producing definitively
-  ///
-  /// XXX maybe this should return a Future<>?
+  /// \brief Stop producing definitively to all outputs
   virtual void StopProducing() = 0;
 
+  /// \brief A future which will be marked finished when this node has stopped producing.
+  virtual Future<> finished() = 0;
+
  protected:
   ExecNode(ExecPlan* plan, std::string label, NodeVector inputs,
            std::vector<std::string> input_labels, std::shared_ptr<Schema> output_schema,
@@ -229,10 +234,10 @@ class ARROW_EXPORT ExecNode {
 
 /// \brief Adapt an AsyncGenerator as a source node
 ///
-/// TODO this should accept an Executor and explicitly handle batches
-/// as they are generated on each of the Executor's threads.
+/// plan->exec_context()->executor() is used to parallelize pushing to
+/// outputs, if provided.
 ARROW_EXPORT
-ExecNode* MakeSourceNode(ExecPlan*, std::string label,
+ExecNode* MakeSourceNode(ExecPlan* plan, std::string label,
                          std::shared_ptr<Schema> output_schema,
                          std::function<Future<util::optional<ExecBatch>>()>);
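For implementors of the new pure virtual finished(), the patterns in exec_plan.cc above are representative: single-input nodes simply forward inputs_[0]->finished(), while the plan aggregates every node's future. A condensed sketch of that aggregation, assuming sorted_nodes is already populated:

    std::vector<arrow::Future<>> futures;
    for (arrow::compute::ExecNode* node : sorted_nodes) {
      futures.push_back(node->finished());
    }
    // Resolves when every node's future resolves, collecting the first error.
    arrow::Future<> all_finished = arrow::AllComplete(std::move(futures));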
"process3", "sink")); @@ -226,18 +235,20 @@ Result MakeTestSourceNode(ExecPlan* plan, std::string label, std::move(gen)); } -Result> StartAndCollect( +Future> StartAndCollect( ExecPlan* plan, AsyncGenerator> gen) { RETURN_NOT_OK(plan->Validate()); RETURN_NOT_OK(plan->StartProducing()); - auto maybe_collected = CollectAsyncGenerator(gen).result(); - ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected); + auto collected_fut = CollectAsyncGenerator(gen); - plan->StopProducing(); - - return internal::MapVector( - [](util::optional batch) { return std::move(*batch); }, collected); + return AllComplete({plan->finished(), Future<>(collected_fut)}) + .Then([collected_fut]() -> Result> { + ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result()); + return internal::MapVector( + [](util::optional batch) { return std::move(*batch); }, + std::move(collected)); + }); } BatchesWithSchema MakeBasicBatches() { @@ -282,7 +293,7 @@ TEST(ExecPlanExecution, SourceSink) { auto sink_gen = MakeSinkNode(source, "sink"); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - ResultWith(UnorderedElementsAreArray(basic_data.batches))); + Finishes(ResultWith(UnorderedElementsAreArray(basic_data.batches)))); } } } @@ -304,7 +315,7 @@ TEST(ExecPlanExecution, SourceSinkError) { auto sink_gen = MakeSinkNode(source, "sink"); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - Raises(StatusCode::Invalid, HasSubstr("Artificial"))); + Finishes(Raises(StatusCode::Invalid, HasSubstr("Artificial")))); } TEST(ExecPlanExecution, StressSourceSink) { @@ -327,7 +338,37 @@ TEST(ExecPlanExecution, StressSourceSink) { auto sink_gen = MakeSinkNode(source, "sink"); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - ResultWith(UnorderedElementsAreArray(random_data.batches))); + Finishes(ResultWith(UnorderedElementsAreArray(random_data.batches)))); + } + } +} + +TEST(ExecPlanExecution, StressSourceSinkStopped) { + for (bool slow : {false, true}) { + SCOPED_TRACE(slow ? "slowed" : "unslowed"); + + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel" : "single threaded"); + + int num_batches = slow && !parallel ? 
30 : 300; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + auto random_data = MakeRandomBatches( + schema({field("a", int32()), field("b", boolean())}), num_batches); + + ASSERT_OK_AND_ASSIGN(auto source, MakeTestSourceNode(plan.get(), "source", + random_data, parallel, slow)); + + auto sink_gen = MakeSinkNode(source, "sink"); + + ASSERT_OK(plan->Validate()); + ASSERT_OK(plan->StartProducing()); + + EXPECT_THAT(sink_gen(), Finishes(ResultWith(Optional(random_data.batches[0])))); + + plan->StopProducing(); + ASSERT_THAT(plan->finished(), Finishes(Ok())); } } } @@ -349,9 +390,9 @@ TEST(ExecPlanExecution, SourceFilterSink) { auto sink_gen = MakeSinkNode(filter, "sink"); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - ResultWith(UnorderedElementsAreArray( + Finishes(ResultWith(UnorderedElementsAreArray( {ExecBatchFromJSON({int32(), boolean()}, "[]"), - ExecBatchFromJSON({int32(), boolean()}, "[[6, false]]")}))); + ExecBatchFromJSON({int32(), boolean()}, "[[6, false]]")})))); } TEST(ExecPlanExecution, SourceProjectSink) { @@ -376,10 +417,10 @@ TEST(ExecPlanExecution, SourceProjectSink) { auto sink_gen = MakeSinkNode(projection, "sink"); ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - ResultWith(UnorderedElementsAreArray( + Finishes(ResultWith(UnorderedElementsAreArray( {ExecBatchFromJSON({boolean(), int32()}, "[[false, null], [true, 5]]"), ExecBatchFromJSON({boolean(), int32()}, - "[[null, 6], [true, 7], [true, 8]]")}))); + "[[null, 6], [true, 7], [true, 8]]")})))); } } // namespace compute diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 6fbfa2a430c..bd203b354f0 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -88,14 +89,12 @@ struct DummyNode : ExecNode { } void StopProducing(ExecNode* output) override { - ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; + EXPECT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; AssertIsOutput(output); - StopProducing(); } void StopProducing() override { if (started_) { - started_ = false; for (const auto& input : inputs_) { input->StopProducing(this); } @@ -105,9 +104,12 @@ struct DummyNode : ExecNode { } } + Future<> finished() override { return Future<>::MakeFinished(); } + private: void AssertIsOutput(ExecNode* output) { - ASSERT_NE(std::find(outputs_.begin(), outputs_.end(), output), outputs_.end()); + auto it = std::find(outputs_.begin(), outputs_.end(), output); + ASSERT_NE(it, outputs_.end()); } std::shared_ptr dummy_schema() const { @@ -116,6 +118,7 @@ struct DummyNode : ExecNode { StartProducingFunc start_producing_; StopProducingFunc stop_producing_; + std::unordered_set requested_stop_; bool started_ = false; }; diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index 8ce7e52d252..2c145dadaeb 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -69,7 +69,7 @@ TEST(ExecContext, BasicWorkings) { // Now, let's customize all the things LoggingMemoryPool my_pool(default_memory_pool()); std::unique_ptr custom_reg = FunctionRegistry::Make(); - ExecContext ctx(&my_pool, custom_reg.get()); + ExecContext ctx(&my_pool, /*executor=*/nullptr, custom_reg.get()); ASSERT_EQ(custom_reg.get(), ctx.func_registry()); ASSERT_EQ(&my_pool, ctx.memory_pool()); diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc 
b/cpp/src/arrow/dataset/file_parquet_test.cc index ffa64e8ec10..eab80010c76 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -491,9 +491,6 @@ TEST_P(TestParquetFileFormatScan, PredicatePushdownRowGroupFragments) { auto all_row_groups = internal::Iota(static_cast(kNumRowGroups)); CountRowGroupsInFragment(fragment, all_row_groups, literal(true)); - // FIXME this is only meaningful if "not here" is a virtual column - // CountRowGroupsInFragment(fragment, all_row_groups, "not here"_ == 0); - for (int i = 0; i < kNumRowGroups; ++i) { CountRowGroupsInFragment(fragment, {i}, equal(field_ref("i64"), literal(i + 1))); } @@ -516,9 +513,10 @@ TEST_P(TestParquetFileFormatScan, PredicatePushdownRowGroupFragments) { fragment, {1, 3}, or_(equal(field_ref("i64"), literal(2)), equal(field_ref("i64"), literal(4)))); - // TODO(bkietz): better Assume support for InExpression - // auto set = ArrayFromJSON(int64(), "[2, 4]"); - // CountRowGroupsInFragment(fragment, {1, 3}, field_ref("i64").In(set)); + auto set = ArrayFromJSON(int64(), "[2, 4]"); + CountRowGroupsInFragment( + fragment, {1, 3}, + call("is_in", {field_ref("i64")}, compute::SetLookupOptions{set})); CountRowGroupsInFragment(fragment, {0, 1, 2, 3, 4}, less(field_ref("i64"), literal(6))); diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 58e96fdc113..cc2e5bcda66 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -24,6 +24,7 @@ #include #include "arrow/array/array_primitive.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" @@ -432,72 +433,9 @@ class ARROW_DS_EXPORT AsyncScanner : public Scanner, namespace { -inline Result DoFilterAndProjectRecordBatchAsync( - const std::shared_ptr& options, const EnumeratedRecordBatch& in) { - ARROW_ASSIGN_OR_RAISE( - compute::Expression simplified_filter, - SimplifyWithGuarantee(options->filter, in.fragment.value->partition_expression())); - - const auto& schema = *options->dataset_schema; - - compute::ExecContext exec_context{options->pool}; - ARROW_ASSIGN_OR_RAISE(Datum mask, - ExecuteScalarExpression(simplified_filter, schema, - in.record_batch.value, &exec_context)); - - Datum filtered; - if (mask.is_scalar()) { - const auto& mask_scalar = mask.scalar_as(); - if (mask_scalar.is_valid && mask_scalar.value) { - // filter matches entire table - filtered = in.record_batch.value; - } else { - // Filter matches nothing - filtered = in.record_batch.value->Slice(0, 0); - } - } else { - ARROW_ASSIGN_OR_RAISE( - filtered, compute::Filter(in.record_batch.value, mask, - compute::FilterOptions::Defaults(), &exec_context)); - } - - ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection, - SimplifyWithGuarantee(options->projection, - in.fragment.value->partition_expression())); - - ARROW_ASSIGN_OR_RAISE( - Datum projected, - ExecuteScalarExpression(simplified_projection, schema, filtered, &exec_context)); - - DCHECK_EQ(projected.type()->id(), Type::STRUCT); - if (projected.shape() == ValueDescr::SCALAR) { - // Only virtual columns are projected. 
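The revived assertion expresses set membership with the is_in compute function instead of the old InExpression. For reference, the same predicate can be evaluated directly; a sketch assuming a hypothetical int64 array my_array:

    #include "arrow/compute/api_scalar.h"

    // is_in(values, value_set=[2, 4]) yields a boolean selection that is true
    // for rows whose value is 2 or 4.
    auto set = ArrayFromJSON(int64(), "[2, 4]");
    ARROW_ASSIGN_OR_RAISE(Datum selection,
                          compute::IsIn(my_array, compute::SetLookupOptions{set}));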
diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc
index 58e96fdc113..cc2e5bcda66 100644
--- a/cpp/src/arrow/dataset/scanner.cc
+++ b/cpp/src/arrow/dataset/scanner.cc
@@ -24,6 +24,7 @@
 #include <mutex>
 
 #include "arrow/array/array_primitive.h"
+#include "arrow/compute/api_aggregate.h"
 #include "arrow/compute/api_scalar.h"
 #include "arrow/compute/api_vector.h"
 #include "arrow/compute/cast.h"
@@ -432,72 +433,9 @@ class ARROW_DS_EXPORT AsyncScanner : public Scanner,
 
 namespace {
 
-inline Result<EnumeratedRecordBatch> DoFilterAndProjectRecordBatchAsync(
-    const std::shared_ptr<ScanOptions>& options, const EnumeratedRecordBatch& in) {
-  ARROW_ASSIGN_OR_RAISE(
-      compute::Expression simplified_filter,
-      SimplifyWithGuarantee(options->filter, in.fragment.value->partition_expression()));
-
-  const auto& schema = *options->dataset_schema;
-
-  compute::ExecContext exec_context{options->pool};
-  ARROW_ASSIGN_OR_RAISE(Datum mask,
-                        ExecuteScalarExpression(simplified_filter, schema,
-                                                in.record_batch.value, &exec_context));
-
-  Datum filtered;
-  if (mask.is_scalar()) {
-    const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
-    if (mask_scalar.is_valid && mask_scalar.value) {
-      // filter matches entire table
-      filtered = in.record_batch.value;
-    } else {
-      // Filter matches nothing
-      filtered = in.record_batch.value->Slice(0, 0);
-    }
-  } else {
-    ARROW_ASSIGN_OR_RAISE(
-        filtered, compute::Filter(in.record_batch.value, mask,
-                                  compute::FilterOptions::Defaults(), &exec_context));
-  }
-
-  ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection,
-                        SimplifyWithGuarantee(options->projection,
-                                              in.fragment.value->partition_expression()));
-
-  ARROW_ASSIGN_OR_RAISE(
-      Datum projected,
-      ExecuteScalarExpression(simplified_projection, schema, filtered, &exec_context));
-
-  DCHECK_EQ(projected.type()->id(), Type::STRUCT);
-  if (projected.shape() == ValueDescr::SCALAR) {
-    // Only virtual columns are projected. Broadcast to an array
-    ARROW_ASSIGN_OR_RAISE(
-        projected,
-        MakeArrayFromScalar(*projected.scalar(), filtered.record_batch()->num_rows(),
-                            options->pool));
-  }
-  ARROW_ASSIGN_OR_RAISE(auto out,
-                        RecordBatch::FromStructArray(projected.array_as<StructArray>()));
-  auto projected_batch =
-      out->ReplaceSchemaMetadata(in.record_batch.value->schema()->metadata());
-
-  return EnumeratedRecordBatch{
-      {std::move(projected_batch), in.record_batch.index, in.record_batch.last},
-      in.fragment};
-}
-
-inline EnumeratedRecordBatchGenerator FilterAndProjectRecordBatchAsync(
-    const std::shared_ptr<ScanOptions>& options, EnumeratedRecordBatchGenerator rbs) {
-  auto mapper = [options](const EnumeratedRecordBatch& in) {
-    return DoFilterAndProjectRecordBatchAsync(options, in);
-  };
-  return MakeMappedGenerator(std::move(rbs), mapper);
-}
-
 Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
     const Enumerated<std::shared_ptr<Fragment>>& fragment,
-    const std::shared_ptr<ScanOptions>& options, bool filter_and_project = true) {
+    const std::shared_ptr<ScanOptions>& options) {
   ARROW_ASSIGN_OR_RAISE(auto batch_gen, fragment.value->ScanBatchesAsync(options));
   auto enumerated_batch_gen = MakeEnumeratedGenerator(std::move(batch_gen));
 
@@ -506,73 +444,116 @@ Result<EnumeratedRecordBatchGenerator> FragmentToBatches(
     return EnumeratedRecordBatch{record_batch, fragment};
   };
 
-  auto combined_gen = MakeMappedGenerator(enumerated_batch_gen, std::move(combine_fn));
-
-  if (filter_and_project) {
-    return FilterAndProjectRecordBatchAsync(options, std::move(combined_gen));
-  }
-  return combined_gen;
+  return MakeMappedGenerator(enumerated_batch_gen, std::move(combine_fn));
 }
 
 Result<AsyncGenerator<EnumeratedRecordBatchGenerator>> FragmentsToBatches(
-    FragmentGenerator fragment_gen, const std::shared_ptr<ScanOptions>& options,
-    bool filter_and_project = true) {
+    FragmentGenerator fragment_gen, const std::shared_ptr<ScanOptions>& options) {
   auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen));
   return MakeMappedGenerator(std::move(enumerated_fragment_gen),
                              [=](const Enumerated<std::shared_ptr<Fragment>>& fragment) {
-                               return FragmentToBatches(fragment, options,
-                                                        filter_and_project);
+                               return FragmentToBatches(fragment, options);
                              });
 }
 
-Result<AsyncGenerator<AsyncGenerator<util::optional<int64_t>>>> FragmentsToRowCount(
-    FragmentGenerator fragment_gen,
-    std::shared_ptr<ScanOptions> options_with_projection) {
-  // Must use optional<int64_t> to avoid breaking the pipeline on empty batches
-  auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen));
-
-  // Drop projection since we only need to count rows
-  auto options = std::make_shared<ScanOptions>(*options_with_projection);
-  RETURN_NOT_OK(SetProjection(options.get(), std::vector<std::string>()));
-
-  auto count_fragment_fn =
-      [options](const Enumerated<std::shared_ptr<Fragment>>& fragment)
-      -> Result<AsyncGenerator<util::optional<int64_t>>> {
-    auto count_fut = fragment.value->CountRows(options->filter, options);
-    return MakeFromFuture(
-        count_fut.Then([=](util::optional<int64_t> val)
-                           -> Result<AsyncGenerator<util::optional<int64_t>>> {
-          // Fast path
-          if (val.has_value()) {
-            return MakeSingleFutureGenerator(
-                Future<util::optional<int64_t>>::MakeFinished(val));
-          }
-          // Slow path
-          ARROW_ASSIGN_OR_RAISE(auto batch_gen, FragmentToBatches(fragment, options));
-          auto count_fn =
-              [](const EnumeratedRecordBatch& enumerated) -> util::optional<int64_t> {
-            return enumerated.record_batch.value->num_rows();
-          };
-          return MakeMappedGenerator(batch_gen, std::move(count_fn));
-        }));
-  };
-  return MakeMappedGenerator(std::move(enumerated_fragment_gen),
-                             std::move(count_fragment_fn));
-}
-
-Result<EnumeratedRecordBatchGenerator> ScanBatchesUnorderedAsyncImpl(
-    const std::shared_ptr<ScanOptions>& options, FragmentGenerator fragment_gen,
-    internal::Executor* cpu_executor, bool filter_and_project = true) {
-  ARROW_ASSIGN_OR_RAISE(
-      auto batch_gen_gen,
-      FragmentsToBatches(std::move(fragment_gen), options, filter_and_project));
-  auto batch_gen_gen_readahead =
-      MakeSerialReadaheadGenerator(std::move(batch_gen_gen), options->fragment_readahead);
-  auto merged_batch_gen = MakeMergedGenerator(std::move(batch_gen_gen_readahead),
-                                              options->fragment_readahead);
-  return MakeReadaheadGenerator(std::move(merged_batch_gen), options->fragment_readahead);
+Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan* plan,
+                                        FragmentGenerator fragment_gen,
+                                        std::shared_ptr<ScanOptions> options) {
+  if (!options->use_async) {
+    return Status::NotImplemented("ScanNodes without asynchrony");
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto batch_gen_gen,
+                        FragmentsToBatches(std::move(fragment_gen), options));
+
+  auto merged_batch_gen =
+      MakeMergedGenerator(std::move(batch_gen_gen), options->fragment_readahead);
+
+  auto batch_gen =
+      MakeReadaheadGenerator(std::move(merged_batch_gen), options->fragment_readahead);
+
+  auto gen = MakeMappedGenerator(
+      std::move(batch_gen),
+      [options](const EnumeratedRecordBatch& partial)
+          -> Result<util::optional<compute::ExecBatch>> {
+        ARROW_ASSIGN_OR_RAISE(
+            util::optional<compute::ExecBatch> batch,
+            compute::MakeExecBatch(*options->dataset_schema, partial.record_batch.value));
+        // TODO(ARROW-13263) fragments may be able to attach more guarantees to batches
+        // than this, for example parquet's row group stats. Failing to do this leaves
+        // perf on the table because row group stats could be used to skip kernel execs in
+        // FilterNode.
+        //
+        // Additionally, if a fragment failed to perform projection pushdown there may be
+        // unnecessarily materialized columns in batch. We could drop them now instead of
+        // letting them coast through the rest of the plan.
+        batch->guarantee = partial.fragment.value->partition_expression();
+
+        // tag rows with fragment- and batch-of-origin
+        batch->values.emplace_back(partial.fragment.index);
+        batch->values.emplace_back(partial.record_batch.index);
+        batch->values.emplace_back(partial.record_batch.last);
+        return batch;
+      });
+
+  auto augmented_fields = options->dataset_schema->fields();
+  augmented_fields.push_back(field("__fragment_index", int32()));
+  augmented_fields.push_back(field("__batch_index", int32()));
+  augmented_fields.push_back(field("__last_in_fragment", boolean()));
+  return compute::MakeSourceNode(plan, "dataset_scan",
+                                 schema(std::move(augmented_fields)), std::move(gen));
 }
 
+class OneShotScanTask : public ScanTask {
+ public:
+  OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
+                  std::shared_ptr<Fragment> fragment)
+      : ScanTask(std::move(options), std::move(fragment)),
+        batch_it_(std::move(batch_it)) {}
+  Result<RecordBatchIterator> Execute() override {
+    if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
+    return std::move(batch_it_);
+  }
+
+ private:
+  RecordBatchIterator batch_it_;
+};
+
+class OneShotFragment : public Fragment {
+ public:
+  OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
+      : Fragment(compute::literal(true), std::move(schema)),
+        batch_it_(std::move(batch_it)) {
+    DCHECK_NE(physical_schema_, nullptr);
+  }
+  Status CheckConsumed() {
+    if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
+    return Status::OK();
+  }
+  Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
+    RETURN_NOT_OK(CheckConsumed());
+    ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
+        std::move(batch_it_), std::move(options), shared_from_this())};
+    return MakeVectorIterator(std::move(tasks));
+  }
+  Result<RecordBatchGenerator> ScanBatchesAsync(
+      const std::shared_ptr<ScanOptions>& options) override {
+    RETURN_NOT_OK(CheckConsumed());
+    ARROW_ASSIGN_OR_RAISE(
+        auto background_gen,
+        MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor()));
+    return MakeTransferredGenerator(std::move(background_gen),
+                                    internal::GetCpuThreadPool());
+  }
+  std::string type_name() const override { return "one-shot"; }
+
+ protected:
+  Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
+    return physical_schema_;
+  }
+
+  RecordBatchIterator batch_it_;
+};
 }  // namespace
 
 Result<FragmentGenerator> AsyncScanner::GetFragments() const {
@@ -604,11 +585,88 @@ Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync()
   return ScanBatchesUnorderedAsync(internal::GetCpuThreadPool());
 }
 
+namespace {
+Result<EnumeratedRecordBatch> ToEnumeratedRecordBatch(
+    const util::optional<compute::ExecBatch>& batch, const ScanOptions& options,
+    const FragmentVector& fragments) {
+  int num_fields = options.projected_schema->num_fields();
+
+  ArrayVector columns(num_fields);
+  for (size_t i = 0; i < columns.size(); ++i) {
+    const Datum& value = batch->values[i];
+    if (value.is_array()) {
+      columns[i] = value.make_array();
+      continue;
+    }
+    ARROW_ASSIGN_OR_RAISE(
+        columns[i], MakeArrayFromScalar(*value.scalar(), batch->length, options.pool));
+  }
+
+  EnumeratedRecordBatch out;
+  out.fragment.index = batch->values[num_fields].scalar_as<Int32Scalar>().value;
+  out.fragment.value = fragments[out.fragment.index];
+  out.fragment.last = false;  // ignored during reordering
+
+  out.record_batch.index = batch->values[num_fields + 1].scalar_as<Int32Scalar>().value;
+  out.record_batch.value =
+      RecordBatch::Make(options.projected_schema, batch->length, std::move(columns));
+  out.record_batch.last = batch->values[num_fields + 2].scalar_as<BooleanScalar>().value;
+
+  return out;
+}
+}  // namespace
+
 Result<EnumeratedRecordBatchGenerator> AsyncScanner::ScanBatchesUnorderedAsync(
     internal::Executor* cpu_executor) {
-  ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments());
-  return ScanBatchesUnorderedAsyncImpl(scan_options_, std::move(fragment_gen),
-                                       cpu_executor);
+  if (!scan_options_->use_threads) {
+    cpu_executor = nullptr;
+  }
+
+  auto exec_context =
+      std::make_shared<compute::ExecContext>(scan_options_->pool, cpu_executor);
+
+  ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context.get()));
+
+  ARROW_ASSIGN_OR_RAISE(auto scan, MakeScanNode(plan.get(), dataset_, scan_options_));
+
+  ARROW_ASSIGN_OR_RAISE(auto filter,
+                        compute::MakeFilterNode(scan, "filter", scan_options_->filter));
+
+  auto exprs = scan_options_->projection.call()->arguments;
+  exprs.push_back(compute::field_ref("__fragment_index"));
+  exprs.push_back(compute::field_ref("__batch_index"));
+  exprs.push_back(compute::field_ref("__last_in_fragment"));
+  ARROW_ASSIGN_OR_RAISE(auto project,
+                        compute::MakeProjectNode(filter, "project", std::move(exprs)));
+
+  AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen =
+      compute::MakeSinkNode(project, "sink");
+
+  RETURN_NOT_OK(plan->StartProducing());
+
+  auto options = scan_options_;
+  ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset_->GetFragments(scan_options_->filter));
+  ARROW_ASSIGN_OR_RAISE(auto fragments, fragments_it.ToVector());
+  auto shared_fragments = std::make_shared<FragmentVector>(std::move(fragments));
+
+  // If the generator is destroyed before being completely drained, inform plan
+  std::shared_ptr<void> stop_producing{
+      nullptr, [plan, exec_context](...) {
+        bool not_finished_yet = plan->finished().TryAddCallback(
+            [&plan, &exec_context] { return [plan, exec_context](const Status&) {}; });
+
+        if (not_finished_yet) {
+          plan->StopProducing();
+        }
+      }};
+
+  return MakeMappedGenerator(
+      std::move(sink_gen),
+      [sink_gen, options, stop_producing,
+       shared_fragments](const util::optional<compute::ExecBatch>& batch)
+          -> Future<EnumeratedRecordBatch> {
+        return ToEnumeratedRecordBatch(batch, *options, *shared_fragments);
+      });
 }
 
 Result<TaggedRecordBatchGenerator> AsyncScanner::ScanBatchesAsync() {
@@ -729,20 +787,75 @@ Future<std::shared_ptr<Table>> AsyncScanner::ToTableAsync(
       });
 }
 
+namespace {
+Result<int64_t> GetSelectionSize(const Datum& selection, int64_t length) {
+  if (length == 0) return 0;
+
+  if (selection.is_scalar()) {
+    if (!selection.scalar()->is_valid) return 0;
+    if (!selection.scalar_as<BooleanScalar>().value) return 0;
+    return length;
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto count, compute::Sum(selection));
+  return static_cast<int64_t>(count.scalar_as<Int64Scalar>().value);
+}
+}  // namespace
+
 Result<int64_t> AsyncScanner::CountRows() {
   ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments());
-  ARROW_ASSIGN_OR_RAISE(auto count_gen_gen,
-                        FragmentsToRowCount(std::move(fragment_gen), scan_options_));
-  auto count_gen = MakeConcatenatedGenerator(std::move(count_gen_gen));
-  int64_t total = 0;
-  auto sum_fn = [&total](util::optional<int64_t> count) -> Status {
-    if (count.has_value()) total += *count;
-    return Status::OK();
-  };
-  RETURN_NOT_OK(VisitAsyncGenerator<util::optional<int64_t>>(std::move(count_gen),
-                                                             std::move(sum_fn))
-                    .status());
-  return total;
+
+  ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make());
+
+  // Drop projection since we only need to count rows
+  auto options = std::make_shared<ScanOptions>(*scan_options_);
+  RETURN_NOT_OK(SetProjection(options.get(), std::vector<std::string>()));
+
+  std::atomic<int64_t> total{0};
+
+  fragment_gen = MakeMappedGenerator(
+      std::move(fragment_gen), [&](const std::shared_ptr<Fragment>& fragment) {
+        return fragment->CountRows(scan_options_->filter, scan_options_)
+            .Then([&, fragment](util::optional<int64_t> fast_count) mutable
                      -> std::shared_ptr<Fragment> {
+              if (fast_count) {
+                // fast path: got row count directly; skip scanning this fragment
+                total += *fast_count;
+                return std::make_shared<OneShotFragment>(
+                    options->dataset_schema,
+                    MakeEmptyIterator<std::shared_ptr<RecordBatch>>());
+              }
+
+              // slow path: actually filter this fragment's batches
+              return std::move(fragment);
+            });
+      });
+
+  ARROW_ASSIGN_OR_RAISE(auto scan,
+                        MakeScanNode(plan.get(), std::move(fragment_gen), options));
+
+  ARROW_ASSIGN_OR_RAISE(
+      auto get_selection,
+      compute::MakeProjectNode(scan, "get_selection", {options->filter}));
+
+  AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen =
+      compute::MakeSinkNode(get_selection, "sink");
+
+  RETURN_NOT_OK(plan->StartProducing());
+
+  RETURN_NOT_OK(
+      VisitAsyncGenerator(std::move(sink_gen),
+                          [&](const util::optional<compute::ExecBatch>& batch) {
+                            // TODO replace with scalar aggregation node
+                            ARROW_ASSIGN_OR_RAISE(
+                                int64_t slow_count,
+                                GetSelectionSize(batch->values[0], batch->length));
+                            total += slow_count;
+                            return Status::OK();
+                          })
+          .status());
+
+  plan->finished().Wait();
+
+  return total.load();
 }
 
 ScannerBuilder::ScannerBuilder(std::shared_ptr<Dataset> dataset)
@@ -762,59 +875,6 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr<Schema> schema,
                               std::move(schema), FragmentVector{std::move(fragment)}),
                           std::move(scan_options)) {}
 
-namespace {
-class OneShotScanTask : public ScanTask {
- public:
-  OneShotScanTask(RecordBatchIterator batch_it, std::shared_ptr<ScanOptions> options,
-                  std::shared_ptr<Fragment> fragment)
-      : ScanTask(std::move(options), std::move(fragment)),
-        batch_it_(std::move(batch_it)) {}
-  Result<RecordBatchIterator> Execute() override {
-    if (!batch_it_) return Status::Invalid("OneShotScanTask was already scanned");
-    return std::move(batch_it_);
-  }
-
- private:
-  RecordBatchIterator batch_it_;
-};
-
-class OneShotFragment : public Fragment {
- public:
-  OneShotFragment(std::shared_ptr<Schema> schema, RecordBatchIterator batch_it)
-      : Fragment(compute::literal(true), std::move(schema)),
-        batch_it_(std::move(batch_it)) {
-    DCHECK_NE(physical_schema_, nullptr);
-  }
-  Status CheckConsumed() {
-    if (!batch_it_) return Status::Invalid("OneShotFragment was already scanned");
-    return Status::OK();
-  }
-  Result<ScanTaskIterator> Scan(std::shared_ptr<ScanOptions> options) override {
-    RETURN_NOT_OK(CheckConsumed());
-    ScanTaskVector tasks{std::make_shared<OneShotScanTask>(
-        std::move(batch_it_), std::move(options), shared_from_this())};
-    return MakeVectorIterator(std::move(tasks));
-  }
-  Result<RecordBatchGenerator> ScanBatchesAsync(
-      const std::shared_ptr<ScanOptions>& options) override {
-    RETURN_NOT_OK(CheckConsumed());
-    ARROW_ASSIGN_OR_RAISE(
-        auto background_gen,
-        MakeBackgroundGenerator(std::move(batch_it_), options->io_context.executor()));
-    return MakeTransferredGenerator(std::move(background_gen),
-                                    internal::GetCpuThreadPool());
-  }
-  std::string type_name() const override { return "one-shot"; }
-
- protected:
-  Result<std::shared_ptr<Schema>> ReadPhysicalSchemaImpl() override {
-    return physical_schema_;
-  }
-
-  RecordBatchIterator batch_it_;
-};
-}  // namespace
-
 std::shared_ptr<ScannerBuilder> ScannerBuilder::FromRecordBatchReader(
     std::shared_ptr<RecordBatchReader> reader) {
   auto batch_it = MakeIteratorFromReader(reader);
@@ -1108,47 +1168,12 @@ Result<int64_t> SyncScanner::CountRows() {
 
 Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan* plan,
                                         std::shared_ptr<Dataset> dataset,
                                         std::shared_ptr<ScanOptions> scan_options) {
-  if (!scan_options->use_async) {
-    return Status::NotImplemented("ScanNodes without asynchrony");
-  }
-
   // using a generator for speculative forward compatibility with async fragment discovery
-  ARROW_ASSIGN_OR_RAISE(scan_options->filter,
-                        scan_options->filter.Bind(*dataset->schema()));
   ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset->GetFragments(scan_options->filter));
   ARROW_ASSIGN_OR_RAISE(auto fragments_vec, fragments_it.ToVector());
   auto fragments_gen = MakeVectorGenerator(std::move(fragments_vec));
 
-  ARROW_ASSIGN_OR_RAISE(auto batch_gen,
-                        ScanBatchesUnorderedAsyncImpl(
-                            scan_options, std::move(fragments_gen),
-                            internal::GetCpuThreadPool(), /*filter_and_project=*/false));
-
-  auto gen = MakeMappedGenerator(
-      std::move(batch_gen),
-      [dataset](const EnumeratedRecordBatch& partial)
-          -> Result<util::optional<compute::ExecBatch>> {
-        ARROW_ASSIGN_OR_RAISE(
-            util::optional<compute::ExecBatch> batch,
-            compute::MakeExecBatch(*dataset->schema(), partial.record_batch.value));
-
-        // TODO fragments may be able to attach more guarantees to batches than this,
-        // for example parquet's row group stats.
-        batch->guarantee = partial.fragment.value->partition_expression();
-
-        // tag rows with fragment- and batch-of-origin
-        batch->values.emplace_back(partial.fragment.index);
-        batch->values.emplace_back(partial.record_batch.index);
-        batch->values.emplace_back(partial.record_batch.last);
-        return batch;
-      });
-
-  auto augmented_fields = dataset->schema()->fields();
-  augmented_fields.push_back(field("__fragment_index", int32()));
-  augmented_fields.push_back(field("__batch_index", int32()));
-  augmented_fields.push_back(field("__last_in_fragment", boolean()));
-  return compute::MakeSourceNode(plan, "dataset_scan",
-                                 schema(std::move(augmented_fields)), std::move(gen));
+  return MakeScanNode(plan, std::move(fragments_gen), std::move(scan_options));
 }
 
 }  // namespace dataset
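ScanBatchesUnorderedAsync and CountRows above are now thin wrappers over a scan -> filter -> project -> sink plan. Stripped of the fragment bookkeeping, the construction pattern they share looks like this (sketch using the factories from this diff; error handling elided):

    ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context));
    ARROW_ASSIGN_OR_RAISE(auto scan, MakeScanNode(plan.get(), dataset, scan_options));
    ARROW_ASSIGN_OR_RAISE(auto filter,
                          compute::MakeFilterNode(scan, "filter", scan_options->filter));
    ARROW_ASSIGN_OR_RAISE(auto project,
                          compute::MakeProjectNode(filter, "project", exprs));
    auto sink_gen = compute::MakeSinkNode(project, "sink");
    RETURN_NOT_OK(plan->StartProducing());
    // ... drain sink_gen ...
    RETURN_NOT_OK(plan->finished().status());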
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index bed276b1bff..f567054bf91 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -1094,19 +1094,20 @@ TEST(ScanOptions, TestMaterializedFields) {
 
 namespace {
 
-static Result<std::vector<compute::ExecBatch>> StartAndCollect(
+Future<std::vector<compute::ExecBatch>> StartAndCollect(
     compute::ExecPlan* plan, AsyncGenerator<util::optional<compute::ExecBatch>> gen) {
   RETURN_NOT_OK(plan->Validate());
   RETURN_NOT_OK(plan->StartProducing());
 
-  auto maybe_collected = CollectAsyncGenerator(gen).result();
-  ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected);
+  auto collected_fut = CollectAsyncGenerator(gen);
 
-  plan->StopProducing();
-
-  return internal::MapVector(
-      [](util::optional<compute::ExecBatch> batch) { return std::move(*batch); },
-      collected);
+  return AllComplete({plan->finished(), Future<>(collected_fut)})
+      .Then([collected_fut]() -> Result<std::vector<compute::ExecBatch>> {
+        ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result());
+        return internal::MapVector(
+            [](util::optional<compute::ExecBatch> batch) { return std::move(*batch); },
+            std::move(collected));
+      });
 }
 
 struct DatasetAndBatches {
@@ -1183,6 +1184,7 @@ TEST(ScanNode, Schema) {
 
   auto options = std::make_shared<ScanOptions>();
   options->use_async = true;
+  options->dataset_schema = basic.dataset->schema();
 
   ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options));
 
@@ -1200,6 +1202,7 @@ TEST(ScanNode, Trivial) {
 
   auto options = std::make_shared<ScanOptions>();
   options->use_async = true;
+  options->dataset_schema = basic.dataset->schema();
 
   ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options));
   auto sink_gen = MakeSinkNode(scan, "sink");
@@ -1207,7 +1210,7 @@ TEST(ScanNode, Trivial) {
   // trivial scan: the batches are returned unmodified
   auto expected = basic.batches;
   ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
-              ResultWith(UnorderedElementsAreArray(expected)));
+              Finishes(ResultWith(UnorderedElementsAreArray(expected))));
 }
 
 TEST(ScanNode, FilteredOnVirtualColumn) {
@@ -1217,7 +1220,9 @@ TEST(ScanNode, FilteredOnVirtualColumn) {
 
   auto options = std::make_shared<ScanOptions>();
   options->use_async = true;
-  options->filter = less(field_ref("c"), literal(30));
+  options->dataset_schema = basic.dataset->schema();
+  ASSERT_OK_AND_ASSIGN(options->filter,
+                       less(field_ref("c"), literal(30)).Bind(*basic.dataset->schema()));
 
   ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options));
 
@@ -1230,7 +1235,7 @@ TEST(ScanNode, FilteredOnVirtualColumn) {
   expected.pop_back();
 
   ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
-              ResultWith(UnorderedElementsAreArray(expected)));
+              Finishes(ResultWith(UnorderedElementsAreArray(expected))));
 }
 
 TEST(ScanNode, DeferredFilterOnPhysicalColumn) {
@@ -1240,7 +1245,10 @@ TEST(ScanNode, DeferredFilterOnPhysicalColumn) {
 
   auto options = std::make_shared<ScanOptions>();
   options->use_async = true;
-  options->filter = greater(field_ref("a"), literal(4));
+  options->dataset_schema = basic.dataset->schema();
+  ASSERT_OK_AND_ASSIGN(
+      options->filter,
+      greater(field_ref("a"), literal(4)).Bind(*basic.dataset->schema()));
 
   ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options));
 
@@ -1251,11 +1259,36 @@ TEST(ScanNode, DeferredFilterOnPhysicalColumn) {
 
   auto expected = basic.batches;
   ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
-              ResultWith(UnorderedElementsAreArray(expected)));
+              Finishes(ResultWith(UnorderedElementsAreArray(expected))));
 }
 
-TEST(ScanNode, ProjectionPushdown) {
-  // ensure non-projected columns are dropped
+TEST(ScanNode, DISABLED_ProjectionPushdown) {
+  // ARROW-13263
+  ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make());
+
+  auto basic = MakeBasicDataset();
+
+  auto options = std::make_shared<ScanOptions>();
+  options->use_async = true;
+  options->dataset_schema = basic.dataset->schema();
+  ASSERT_OK(SetProjection(options.get(), {field_ref("b")}, {"b"}));
+
+  ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options));
+
+  auto sink_gen = MakeSinkNode(scan, "sink");
+
+  auto expected = basic.batches;
+
+  int a_index = basic.dataset->schema()->GetFieldIndex("a");
+  int c_index = basic.dataset->schema()->GetFieldIndex("c");
+  for (auto& batch : expected) {
+    // "a", "c" were not projected or filtered so they are dropped eagerly
+    batch.values[a_index] = MakeNullScalar(batch.values[a_index].type());
+    batch.values[c_index] = MakeNullScalar(batch.values[c_index].type());
+  }
+
+  ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+              Finishes(ResultWith(UnorderedElementsAreArray(expected))));
 }
 
 TEST(ScanNode, MaterializationOfVirtualColumn) {
@@ -1265,6 +1298,7 @@ TEST(ScanNode, MaterializationOfVirtualColumn) {
 
   auto options = std::make_shared<ScanOptions>();
   options->use_async = true;
+  options->dataset_schema = basic.dataset->schema();
 
   ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options));
 
@@ -1286,106 +1320,7 @@ TEST(ScanNode, MaterializationOfVirtualColumn) {
   }
 
   ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
-              ResultWith(UnorderedElementsAreArray(expected)));
-}
-
-TEST(ScanNode, CompareToScanner) {
-  ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make());
-
-  auto basic = MakeBasicDataset();
-
-  ScannerBuilder builder(basic.dataset);
-  ASSERT_OK(builder.UseAsync(true));
-  ASSERT_OK(builder.UseThreads(true));
-  ASSERT_OK(builder.Filter(greater(field_ref("c"), literal(30))));
-  ASSERT_OK(builder.Project(
-      {field_ref("c"), call("multiply", {field_ref("a"), literal(2)})}, {"c", "a * 2"}));
-  ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
-
-  ASSERT_OK_AND_ASSIGN(auto fragments_it,
-                       basic.dataset->GetFragments(scanner->options()->filter));
-  ASSERT_OK_AND_ASSIGN(auto fragments, fragments_it.ToVector());
-
-  auto options = scanner->options();
-
-  ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options));
-
-  ASSERT_OK_AND_ASSIGN(auto filter,
-                       compute::MakeFilterNode(scan, "filter", options->filter));
-
-  auto exprs = options->projection.call()->arguments;
-  exprs.push_back(compute::field_ref("__fragment_index"));
-  exprs.push_back(compute::field_ref("__batch_index"));
-  exprs.push_back(compute::field_ref("__last_in_fragment"));
-  ASSERT_OK_AND_ASSIGN(auto project, compute::MakeProjectNode(filter, "project", exprs));
-
-  AsyncGenerator<util::optional<compute::ExecBatch>> sink_gen =
-      compute::MakeSinkNode(project, "sink");
-
-  ASSERT_OK(plan->StartProducing());
-
-  auto from_plan =
-      CollectAsyncGenerator(
-          MakeMappedGenerator(
-              sink_gen,
-              [&](const util::optional<compute::ExecBatch>& batch)
-                  -> Result<EnumeratedRecordBatch> {
-                int num_fields = options->projected_schema->num_fields();
-
-                ArrayVector columns(num_fields);
-                for (size_t i = 0; i < columns.size(); ++i) {
-                  const Datum& value = batch->values[i];
-                  if (value.is_array()) {
-                    columns[i] = value.make_array();
-                    continue;
-                  }
-                  ARROW_ASSIGN_OR_RAISE(
-                      columns[i],
-                      MakeArrayFromScalar(*value.scalar(), batch->length, options->pool));
-                }
-
-                EnumeratedRecordBatch out;
-                out.fragment.index =
-                    batch->values[num_fields].scalar_as<Int32Scalar>().value;
-                out.fragment.value = fragments[out.fragment.index];
-                out.fragment.last = false;  // ignored during reordering
-
-                out.record_batch.index =
-                    batch->values[num_fields + 1].scalar_as<Int32Scalar>().value;
-                out.record_batch.value = RecordBatch::Make(
-                    options->projected_schema, batch->length, std::move(columns));
-                out.record_batch.last =
-                    batch->values[num_fields + 2].scalar_as<BooleanScalar>().value;
-
-                return out;
-              }))
-          .result();
-
-  ASSERT_OK_AND_ASSIGN(auto from_scanner_gen, scanner->ScanBatchesUnorderedAsync());
-  auto from_scanner = CollectAsyncGenerator(from_scanner_gen).result();
-
-  auto less = [](const EnumeratedRecordBatch& l, const EnumeratedRecordBatch& r) {
-    if (l.fragment.index < r.fragment.index) return true;
-    return l.record_batch.index < r.record_batch.index;
-  };
-
-  ASSERT_OK(from_plan);
-  std::sort(from_plan->begin(), from_plan->end(), less);
-
-  ASSERT_OK(from_scanner);
-  std::sort(from_scanner->begin(), from_scanner->end(), less);
-
-  ASSERT_EQ(from_plan->size(), from_scanner->size());
-  for (size_t i = 0; i < from_plan->size(); ++i) {
-    const auto& p = from_plan->at(i);
-    const auto& s = from_scanner->at(i);
-    SCOPED_TRACE(i);
-    ASSERT_EQ(p.fragment.index, s.fragment.index);
-    ASSERT_EQ(p.fragment.value, s.fragment.value);
-    ASSERT_EQ(p.record_batch.last, s.record_batch.last);
-    ASSERT_EQ(p.record_batch.index, s.record_batch.index);
-    AssertBatchesEqual(*p.record_batch.value, *s.record_batch.value);
-  }
+              Finishes(ResultWith(UnorderedElementsAreArray(expected))));
 }
 
 }  // namespace dataset
diff --git a/cpp/src/arrow/testing/future_util.h b/cpp/src/arrow/testing/future_util.h
index 878840587ff..2ca70d05402 100644
--- a/cpp/src/arrow/testing/future_util.h
+++ b/cpp/src/arrow/testing/future_util.h
@@ -21,21 +21,21 @@
 #include "arrow/util/future.h"
 
 // This macro should be called by futures that are expected to
-// complete pretty quickly.  2 seconds is the default max wait
-// here.  Anything longer than that and it's a questionable
-// unit test anyways.
-#define ASSERT_FINISHES_IMPL(fut)                            \
-  do {                                                       \
-    ASSERT_TRUE(fut.Wait(300));                              \
-    if (!fut.is_finished()) {                                \
-      FAIL() << "Future did not finish in a timely fashion"; \
-    }                                                        \
+// complete pretty quickly. arrow::kDefaultAssertFinishesWaitSeconds is the
+// default max wait here. Anything longer than that and it's a questionable unit test
+// anyways.
+#define ASSERT_FINISHES_IMPL(fut)                                      \
+  do {                                                                 \
+    ASSERT_TRUE(fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)); \
+    if (!fut.is_finished()) {                                          \
+      FAIL() << "Future did not finish in a timely fashion";           \
+    }                                                                  \
   } while (false)
 
 #define ASSERT_FINISHES_OK(expr)                                       \
   do {                                                                 \
     auto&& _fut = (expr);                                              \
-    ASSERT_TRUE(_fut.Wait(300));                                       \
+    ASSERT_TRUE(_fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)); \
     if (!_fut.is_finished()) {                                         \
       FAIL() << "Future did not finish in a timely fashion";           \
     }                                                                  \
@@ -74,12 +74,12 @@
     ASSERT_EQ(expected, _actual);                                      \
   } while (0)
 
-#define EXPECT_FINISHES_IMPL(fut)                                   \
-  do {                                                              \
-    EXPECT_TRUE(fut.Wait(300));                                     \
-    if (!fut.is_finished()) {                                       \
-      ADD_FAILURE() << "Future did not finish in a timely fashion"; \
-    }                                                               \
+#define EXPECT_FINISHES_IMPL(fut)                                      \
+  do {                                                                 \
+    EXPECT_TRUE(fut.Wait(::arrow::kDefaultAssertFinishesWaitSeconds)); \
+    if (!fut.is_finished()) {                                          \
+      ADD_FAILURE() << "Future did not finish in a timely fashion";    \
+    }                                                                  \
  } while (false)
 
 #define ON_FINISH_ASSIGN_OR_HANDLE_ERROR_IMPL(handle_error, future_name, lhs, rexpr) \
@@ -105,6 +105,8 @@
 
 namespace arrow {
 
+constexpr double kDefaultAssertFinishesWaitSeconds = 64;
+
 template <typename T>
 void AssertNotFinished(const Future<T>& fut) {
   ASSERT_FALSE(IsFutureFinished(fut.state()));
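With the timeout centralized in kDefaultAssertFinishesWaitSeconds, every FINISHES macro shares one knob. A usage sketch (SomeAsyncOperation is a hypothetical future-returning call):

    // Fails the test if the future is still pending after 64 seconds,
    // then checks that it finished with an ok Status.
    arrow::Future<> fut = SomeAsyncOperation();
    ASSERT_FINISHES_OK(fut);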
Result& maybe_value) { - return maybe_value.ValueOrDie(); - } - - template - static const T& GetValue(const Future& value_fut) { - return GetValue(value_fut.result()); - } - const ValueMatcher value_matcher_; }; -class StatusMatcher { +class ErrorMatcher { public: - explicit StatusMatcher(StatusCode code, - util::optional> message_matcher) + explicit ErrorMatcher(StatusCode code, + util::optional> message_matcher) : code_(code), message_matcher_(std::move(message_matcher)) {} template @@ -115,7 +156,7 @@ class StatusMatcher { bool MatchAndExplain(const Res& maybe_value, testing::MatchResultListener* listener) const override { - const Status& status = GetStatus(maybe_value); + const Status& status = internal::GenericToStatus(maybe_value); testing::StringMatchResultListener value_listener; bool match = status.code() == code_; @@ -138,40 +179,62 @@ class StatusMatcher { } private: - static const Status& GetStatus(const Status& status) { return status; } + const StatusCode code_; + const util::optional> message_matcher_; +}; - template - static const Status& GetStatus(const Result& maybe_value) { - return maybe_value.status(); - } +class OkMatcher { + public: + template + operator testing::Matcher() const { // NOLINT runtime/explicit + struct Impl : testing::MatcherInterface { + void DescribeTo(::std::ostream* os) const override { *os << "is ok"; } - template - static const Status& GetStatus(const Future& value_fut) { - return value_fut.status(); - } + void DescribeNegationTo(::std::ostream* os) const override { *os << "is not ok"; } - const StatusCode code_; - const util::optional> message_matcher_; + bool MatchAndExplain(const Res& maybe_value, + testing::MatchResultListener* listener) const override { + const Status& status = internal::GenericToStatus(maybe_value); + testing::StringMatchResultListener value_listener; + + const bool match = status.ok(); + *listener << "whose value " << testing::PrintToString(status.ToString()) + << (match ? " matches" : " doesn't match"); + testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); + return match; + } + }; + + return testing::Matcher(new Impl()); + } }; -// Returns a matcher that matches the value of a successful Result or Future. -// (Future will be waited upon to acquire its result for matching.) +// Returns a matcher that waits on a Future (by default for 16 seconds) +// then applies a matcher to the result. +template +FutureMatcher Finishes( + const ResultMatcher& result_matcher, + double wait_seconds = kDefaultAssertFinishesWaitSeconds) { + return FutureMatcher(result_matcher, wait_seconds); +} + +// Returns a matcher that matches the value of a successful Result. template ResultMatcher ResultWith(const ValueMatcher& value_matcher) { return ResultMatcher(value_matcher); } -// Returns a matcher that matches the StatusCode of a Status, Result, or Future. -// (Future will be waited upon to acquire its result for matching.) -inline StatusMatcher Raises(StatusCode code) { - return StatusMatcher(code, util::nullopt); -} +// Returns a matcher that matches an ok Status or Result. +inline OkMatcher Ok() { return {}; } + +// Returns a matcher that matches the StatusCode of a Status or Result. +// Do not use Raises(StatusCode::OK) to match a non error code. +inline ErrorMatcher Raises(StatusCode code) { return ErrorMatcher(code, util::nullopt); } -// Returns a matcher that matches the StatusCode and message of a Status, Result, or -// Future. (Future will be waited upon to acquire its result for matching.) 
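Taken together, these entry points compose with ordinary gmock matchers. A hypothetical sketch (test name and values assumed, not part of this diff) of Ok(), Raises(), and Finishes() with an explicit wait cap:

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "arrow/result.h"
#include "arrow/testing/matchers.h"
#include "arrow/util/future.h"

TEST(MatcherSketch, OkRaisesFinishes) {
  // Ok() spans both Status and Result via its conversion operator.
  EXPECT_THAT(arrow::Status::OK(), arrow::Ok());

  arrow::Result<int> res = arrow::Status::Invalid("nope");
  EXPECT_THAT(res, arrow::Raises(arrow::StatusCode::Invalid,
                                 testing::HasSubstr("nope")));

  // An already-finished future needs no 64 second allowance; the wait cap
  // can be tightened explicitly.
  EXPECT_THAT(arrow::Future<int>::MakeFinished(42),
              arrow::Finishes(arrow::ResultWith(testing::Eq(42)),
                              /*wait_seconds=*/1.0));
}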
+// Returns a matcher that matches the StatusCode and message of a Status or Result. template -StatusMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) { - return StatusMatcher(code, testing::MatcherCast(message_matcher)); +ErrorMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) { + return ErrorMatcher(code, testing::MatcherCast(message_matcher)); } } // namespace arrow diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index 5a6321fd418..18149884204 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -77,16 +77,15 @@ Future AsyncGeneratorEnd() { } /// returning a future that completes when all have been visited -template -Future<> VisitAsyncGenerator(AsyncGenerator generator, - std::function visitor) { +template +Future<> VisitAsyncGenerator(AsyncGenerator generator, Visitor visitor) { struct LoopBody { struct Callback { - Result> operator()(const T& result) { - if (IsIterationEnd(result)) { + Result> operator()(const T& next) { + if (IsIterationEnd(next)) { return Break(); } else { - auto visited = visitor(result); + auto visited = visitor(next); if (visited.ok()) { return Continue(); } else { @@ -95,7 +94,7 @@ Future<> VisitAsyncGenerator(AsyncGenerator generator, } } - std::function visitor; + Visitor visitor; }; Future> operator()() { @@ -105,7 +104,7 @@ Future<> VisitAsyncGenerator(AsyncGenerator generator, } AsyncGenerator generator; - std::function visitor; + Visitor visitor; }; return Loop(LoopBody{std::move(generator), std::move(visitor)}); @@ -775,7 +774,7 @@ class PushGenerator { /// Producer API for PushGenerator class Producer { public: - explicit Producer(const std::shared_ptr state) : weak_state_(state) {} + explicit Producer(const std::shared_ptr& state) : weak_state_(state) {} /// \brief Push a value on the queue /// diff --git a/cpp/src/arrow/util/future.cc b/cpp/src/arrow/util/future.cc index b329f99ed17..f288a15be3f 100644 --- a/cpp/src/arrow/util/future.cc +++ b/cpp/src/arrow/util/future.cc @@ -272,6 +272,8 @@ class ConcreteFutureImpl : public FutureImpl { return true; case ShouldSchedule::IfUnfinished: return !in_add_callback; + case ShouldSchedule::IfDifferentExecutor: + return !callback_record.options.executor->OwnsThisThread(); default: DCHECK(false) << "Unrecognized ShouldSchedule option"; return false; @@ -309,7 +311,7 @@ class ConcreteFutureImpl : public FutureImpl { } cv_.notify_all(); - // run callbacks, lock not needed since the future is finsihed by this + // run callbacks, lock not needed since the future is finished by this // point so nothing else can modify the callbacks list and it is safe // to iterate. // diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index c7c5ba802f9..d9e0a939f25 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -66,10 +66,9 @@ using first_arg_is_status = std::is_same>::type, Status>; -template -struct has_no_args { - static constexpr bool value = internal::call_traits::argument_count::value == 0; -}; +template > +using if_has_no_args = typename std::conditional::type; /// Creates a callback that can be added to a future to mark a `dest` future finished template callback{std::move(next)}; signal_to_complete_next.AddCallback(std::move(callback)); } + + /// Helpers to conditionally ignore arguments to ContinueFunc + template + void IgnoringArgsIf(std::true_type, NextFuture&& next, ContinueFunc&& f, + Args&&...) 
const { + operator()(std::forward(next), std::forward(f)); + } + template + void IgnoringArgsIf(std::false_type, NextFuture&& next, ContinueFunc&& f, + Args&&... a) const { + operator()(std::forward(next), std::forward(f), + std::forward(a)...); + } }; /// Helper struct which tells us what kind of Future gets returned from `Then` based on @@ -213,7 +225,10 @@ enum class ShouldSchedule { /// callback is added IfUnfinished = 1, /// Always schedule the callback as a new task - Always = 2 + Always = 2, + /// Schedule a new task only if it would run on an executor other than + /// the specified executor. + IfDifferentExecutor = 3, }; /// \brief Options that control how a continuation is run @@ -222,9 +237,9 @@ struct CallbackOptions { ShouldSchedule should_schedule = ShouldSchedule::Never; /// If the callback is scheduled then this is the executor it should be scheduled /// on. If this is NULL then should_schedule must be Never - internal::Executor* executor = NULL; + internal::Executor* executor = NULLPTR; - static CallbackOptions Defaults() { return CallbackOptions(); } + static CallbackOptions Defaults() { return {}; } }; // Untyped private implementation @@ -343,7 +358,7 @@ class ARROW_EXPORT FutureWaiter { /// to complete, or wait on multiple Futures at once (using WaitForAll, /// WaitForAny or AsCompletedIterator). template -class Future { +class ARROW_MUST_USE_TYPE Future { public: using ValueType = T; using SyncType = typename detail::SyncType::type; @@ -464,6 +479,34 @@ class Future { return MakeFinished(E::ToResult(std::move(s))); } + struct WrapResultyOnComplete { + template + struct Callback { + void operator()(const FutureImpl& impl) && { + std::move(on_complete)(*impl.CastResult()); + } + OnComplete on_complete; + }; + }; + + struct WrapStatusyOnComplete { + template + struct Callback { + static_assert(std::is_same::value, + "Only callbacks for Future<> should accept Status and not Result"); + + void operator()(const FutureImpl& impl) && { + std::move(on_complete)(impl.CastResult()->status()); + } + OnComplete on_complete; + }; + }; + + template + using WrapOnComplete = typename std::conditional< + detail::first_arg_is_status::value, WrapStatusyOnComplete, + WrapResultyOnComplete>::type::template Callback; + /// \brief Consumer API: Register a callback to run when this future completes /// /// The callback should receive the result of the future (const Result&) @@ -485,35 +528,12 @@ class Future { /// /// In this example `fut` falls out of scope but is not destroyed because it holds a /// cyclic reference to itself through the callback. - template - typename std::enable_if::value>::type - AddCallback(OnComplete on_complete, - CallbackOptions opts = CallbackOptions::Defaults()) const { + template > + void AddCallback(OnComplete on_complete, + CallbackOptions opts = CallbackOptions::Defaults()) const { // We know impl_ will not be dangling when invoking callbacks because at least one // thread will be waiting for MarkFinished to return. 
Thus it's safe to keep a // weak reference to impl_ here - struct Callback { - void operator()(const FutureImpl& impl) && { - std::move(on_complete)(*impl.CastResult()); - } - OnComplete on_complete; - }; - impl_->AddCallback(Callback{std::move(on_complete)}, opts); - } - - /// Overload for callbacks accepting a Status - template - typename std::enable_if::value>::type - AddCallback(OnComplete on_complete, - CallbackOptions opts = CallbackOptions::Defaults()) const { - static_assert(std::is_same::value, - "Callbacks for Future<> should accept Status and not Result"); - struct Callback { - void operator()(const FutureImpl& impl) && { - std::move(on_complete)(impl.CastResult()->status()); - } - OnComplete on_complete; - }; impl_->AddCallback(Callback{std::move(on_complete)}, opts); } @@ -531,36 +551,62 @@ class Future { /// Returns true if a callback was actually added and false if the callback failed /// to add because the future was marked complete. template > - typename std::enable_if::value, bool>::type - TryAddCallback(const CallbackFactory& callback_factory, - CallbackOptions opts = CallbackOptions::Defaults()) const { - struct Callback { - void operator()(const FutureImpl& impl) && { - std::move(on_complete)(*static_cast*>(impl.result_.get())); - } - OnComplete on_complete; - }; - return impl_->TryAddCallback( - [&callback_factory]() { return Callback{callback_factory()}; }, opts); + typename OnComplete = detail::result_of_t, + typename Callback = WrapOnComplete> + bool TryAddCallback(const CallbackFactory& callback_factory, + CallbackOptions opts = CallbackOptions::Defaults()) const { + return impl_->TryAddCallback([&]() { return Callback{callback_factory()}; }, opts); } - template > - typename std::enable_if::value, bool>::type - TryAddCallback(const CallbackFactory& callback_factory, - CallbackOptions opts = CallbackOptions::Defaults()) const { - struct Callback { - void operator()(const FutureImpl& impl) && { - std::move(on_complete)( - static_cast*>(impl.result_.get())->status()); - } - OnComplete on_complete; + template + struct ThenOnComplete { + static constexpr bool has_no_args = + internal::call_traits::argument_count::value == 0; + + using ContinuedFuture = detail::ContinueFuture::ForSignature< + detail::if_has_no_args>; + + static_assert( + std::is_same, + ContinuedFuture>::value, + "OnSuccess and OnFailure must continue with the same future type"); + + struct DummyOnSuccess { + void operator()(const T&); }; + using OnSuccessArg = typename std::decay>>::type; - return impl_->TryAddCallback( - [&callback_factory]() { return Callback{callback_factory()}; }, opts); - } + static_assert( + !std::is_same::type>::value, + "OnSuccess' argument should not be a Result"); + + void operator()(const Result& result) && { + detail::ContinueFuture continue_future; + if (ARROW_PREDICT_TRUE(result.ok())) { + // move on_failure to a(n immediately destroyed) temporary to free its resources + ARROW_UNUSED(OnFailure(std::move(on_failure))); + continue_future.IgnoringArgsIf( + detail::if_has_no_args{}, + std::move(next), std::move(on_success), result.ValueOrDie()); + } else { + ARROW_UNUSED(OnSuccess(std::move(on_success))); + continue_future(std::move(next), std::move(on_failure), result.status()); + } + } + + OnSuccess on_success; + OnFailure on_failure; + ContinuedFuture next; + }; + + template + struct PassthruOnFailure { + using ContinuedFuture = detail::ContinueFuture::ForSignature< + detail::if_has_no_args>; + + Result operator()(const Status& s) { return s; } + }; /// \brief Consumer 
API: Register a continuation to run when this future completes /// @@ -573,6 +619,7 @@ class Future { /// - OnSuccess, called with the result (const ValueType&) on successful completion. /// for an empty future this will be called with no arguments /// - OnFailure, called with the error (const Status&) on failed completion. + /// This callback is optional and defaults to a passthru of any errors. /// /// Then() returns a Future whose ValueType is derived from the return type of the /// callbacks. If a callback returns: @@ -595,114 +642,18 @@ /// and the returned future may already be marked complete. /// /// See AddCallback for general considerations when writing callbacks. - template > - ContinuedFuture Then( - OnSuccess on_success, OnFailure on_failure, - typename std::enable_if::value>::type* = - NULLPTR) const { - static_assert( - std::is_same, - ContinuedFuture>::value, - "OnSuccess and OnFailure must continue with the same future type"); - using OnSuccessArg = - typename std::decay>::type; - static_assert( - !std::is_same::type>::value, - "OnSuccess' argument should not be a Result"); - - auto next = ContinuedFuture::Make(); - - struct Callback { - void operator()(const Result& result) && { - detail::ContinueFuture continue_future; - if (ARROW_PREDICT_TRUE(result.ok())) { - // move on_failure to a(n immediately destroyed) temporary to free its resources - ARROW_UNUSED(OnFailure(std::move(on_failure))); - continue_future(std::move(next), std::move(on_success), result.ValueOrDie()); - } else { - ARROW_UNUSED(OnSuccess(std::move(on_success))); - continue_future(std::move(next), std::move(on_failure), result.status()); - } - } - - OnSuccess on_success; - OnFailure on_failure; - ContinuedFuture next; - }; - - AddCallback(Callback{std::forward(on_success), - std::forward(on_failure), next}); - - return next; - } - - /// \brief Overload for callbacks which ignore the value - template < - typename OnSuccess, typename OnFailure, - typename ContinuedFuture = detail::ContinueFuture::ForSignature> - ContinuedFuture Then( - OnSuccess on_success, OnFailure on_failure, - typename std::enable_if::value>::type* = - NULLPTR) const { - static_assert( - std::is_same, - ContinuedFuture>::value, - "OnSuccess and OnFailure must continue with the same future type"); - + template , + typename OnComplete = ThenOnComplete, + typename ContinuedFuture = typename OnComplete::ContinuedFuture> + ContinuedFuture Then(OnSuccess on_success, OnFailure on_failure = {}, + CallbackOptions options = CallbackOptions::Defaults()) const { auto next = ContinuedFuture::Make(); - - struct Callback { - void operator()(const Result& result) && { - detail::ContinueFuture continue_future; - if (ARROW_PREDICT_TRUE(result.ok())) { - // move on_failure to a(n immediately destroyed) temporary to free its resources - ARROW_UNUSED(OnFailure(std::move(on_failure))); - continue_future(std::move(next), std::move(on_success)); - } else { - ARROW_UNUSED(OnSuccess(std::move(on_success))); - continue_future(std::move(next), std::move(on_failure), result.status()); - } - } - - OnSuccess on_success; - OnFailure on_failure; - ContinuedFuture next; - }; - - AddCallback(Callback{std::forward(on_success), - std::forward(on_failure), next}); - + AddCallback(OnComplete{std::forward(on_success), + std::forward(on_failure), next}, + options); return next; } - /// \brief Overload without OnFailure. Failures will be passed through unchanged.
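The three former overloads collapse into this single signature. A hypothetical caller-side sketch (the lambdas and enclosing function are assumed, not part of this diff) of the call shapes it now accepts:

#include "arrow/result.h"
#include "arrow/util/future.h"
#include "arrow/util/thread_pool.h"

void ThenSketch() {
  arrow::Future<int> fut = arrow::Future<int>::Make();

  // OnSuccess only: a failed `fut` propagates its error via the
  // PassthruOnFailure default.
  arrow::Future<int> doubled = fut.Then([](const int& v) { return v * 2; });

  // OnSuccess + OnFailure: both must continue with the same future type.
  arrow::Future<int> recovered = fut.Then(
      [](const int& v) { return v * 2; },
      [](const arrow::Status&) -> arrow::Result<int> { return 0; });

  // CallbackOptions now thread through the same overload.
  arrow::CallbackOptions opts;
  opts.should_schedule = arrow::ShouldSchedule::Always;
  opts.executor = arrow::internal::GetCpuThreadPool();
  arrow::Future<int> scheduled = fut.Then(
      [](const int& v) { return v * 2; },
      [](const arrow::Status& s) -> arrow::Result<int> { return s; }, opts);
}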
- template , - typename E = ValueType> - typename std::enable_if::value, ContinuedFuture>::type - Then(OnSuccess&& on_success) const { - return Then(std::forward(on_success), [](const Status& s) { - return Result(s); - }); - } - - /// \brief Statusy overload without OnFailure - template < - typename OnSuccess, - typename ContinuedFuture = detail::ContinueFuture::ForSignature, - typename E = ValueType> - typename std::enable_if::value, ContinuedFuture>::type - Then(OnSuccess&& on_success) const { - static_assert(std::is_same::value, - "Then callback OnSuccess must receive const T&"); - return Then(std::forward(on_success), [](const Status& s) { - return Result(s); - }); - } - /// \brief Implicit constructor to create a finished future from a value Future(ValueType val) : Future() { // NOLINT runtime/explicit impl_ = FutureImpl::MakeFinished(FutureState::SUCCESS); diff --git a/cpp/src/arrow/util/future_test.cc b/cpp/src/arrow/util/future_test.cc index b25d77c48cd..0db355433e8 100644 --- a/cpp/src/arrow/util/future_test.cc +++ b/cpp/src/arrow/util/future_test.cc @@ -1052,6 +1052,59 @@ TEST_F(FutureSchedulingTest, ScheduleIfUnfinished) { } } +TEST_F(FutureSchedulingTest, ScheduleIfDifferentExecutor) { + struct : internal::Executor { + int GetCapacity() override { return pool_->GetCapacity(); } + + bool OwnsThisThread() override { return pool_->OwnsThisThread(); } + + Status SpawnReal(internal::TaskHints hints, internal::FnOnce task, + StopToken stop_token, StopCallback&& stop_callback) override { + ++spawn_count; + return pool_->Spawn(hints, std::move(task), std::move(stop_token), + std::move(stop_callback)); + } + + std::atomic spawn_count{0}; + internal::Executor* pool_ = internal::GetCpuThreadPool(); + } executor; + + CallbackOptions options; + options.executor = &executor; + options.should_schedule = ShouldSchedule::IfDifferentExecutor; + auto pass_err = [](const Status& s) { return s; }; + + std::atomic fut0_on_executor{false}; + std::atomic fut1_on_executor{false}; + + auto fut0 = Future<>::Make(); + auto fut1 = Future<>::Make(); + + auto fut0_done = fut0.Then( + [&] { + // marked finished on main thread -> must be scheduled to executor + fut0_on_executor.store(executor.OwnsThisThread()); + + fut1.MarkFinished(); + }, + pass_err, options); + + auto fut1_done = fut1.Then( + [&] { + // marked finished on executor -> no need to schedule + fut1_on_executor.store(executor.OwnsThisThread()); + }, + pass_err, options); + + fut0.MarkFinished(); + + AllComplete({fut0_done, fut1_done}).Wait(); + + ASSERT_EQ(executor.spawn_count, 1); + ASSERT_TRUE(fut0_on_executor); + ASSERT_TRUE(fut1_on_executor); +} + TEST_F(FutureSchedulingTest, ScheduleAlwaysKeepsFutureAliveUntilCallback) { CallbackOptions options; options.should_schedule = ShouldSchedule::Always; @@ -1708,25 +1761,26 @@ TEST(FnOnceTest, MoveOnlyDataType) { TEST(FutureTest, MatcherExamples) { EXPECT_THAT(Future::MakeFinished(Status::Invalid("arbitrary error")), - Raises(StatusCode::Invalid)); + Finishes(Raises(StatusCode::Invalid))); EXPECT_THAT(Future::MakeFinished(Status::Invalid("arbitrary error")), - Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"))); + Finishes(Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary")))); // message doesn't match, so no match - EXPECT_THAT( - Future::MakeFinished(Status::Invalid("arbitrary error")), - testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable")))); + EXPECT_THAT(Future::MakeFinished(Status::Invalid("arbitrary error")), + Finishes(testing::Not( + 
Raises(StatusCode::Invalid, testing::HasSubstr("reasonable"))))); // different error code, so no match EXPECT_THAT(Future::MakeFinished(Status::TypeError("arbitrary error")), - testing::Not(Raises(StatusCode::Invalid))); + Finishes(testing::Not(Raises(StatusCode::Invalid)))); // not an error, so no match - EXPECT_THAT(Future::MakeFinished(333), testing::Not(Raises(StatusCode::Invalid))); + EXPECT_THAT(Future::MakeFinished(333), + Finishes(testing::Not(Raises(StatusCode::Invalid)))); EXPECT_THAT(Future::MakeFinished("hello world"), - ResultWith(testing::HasSubstr("hello"))); + Finishes(ResultWith(testing::HasSubstr("hello")))); // Matcher waits on Futures auto string_fut = Future::Make(); @@ -1734,15 +1788,15 @@ TEST(FutureTest, MatcherExamples) { SleepABit(); string_fut.MarkFinished("hello world"); }); - EXPECT_THAT(string_fut, ResultWith(testing::HasSubstr("hello"))); + EXPECT_THAT(string_fut, Finishes(ResultWith(testing::HasSubstr("hello")))); finisher.join(); EXPECT_THAT(Future::MakeFinished(Status::Invalid("XXX")), - testing::Not(ResultWith(testing::HasSubstr("hello")))); + Finishes(testing::Not(ResultWith(testing::HasSubstr("hello"))))); // holds a value, but that value doesn't match the given pattern EXPECT_THAT(Future::MakeFinished("foo bar"), - testing::Not(ResultWith(testing::HasSubstr("hello")))); + Finishes(testing::Not(ResultWith(testing::HasSubstr("hello"))))); } } // namespace internal diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index febbc997852..9ac8e36a3d8 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -82,17 +82,31 @@ class ARROW_EXPORT Executor { // Spawn a fire-and-forget task. template - Status Spawn(Function&& func, StopToken stop_token = StopToken::Unstoppable()) { + Status Spawn(Function&& func) { + return SpawnReal(TaskHints{}, std::forward(func), StopToken::Unstoppable(), + StopCallback{}); + } + template + Status Spawn(Function&& func, StopToken stop_token) { return SpawnReal(TaskHints{}, std::forward(func), std::move(stop_token), StopCallback{}); } - template - Status Spawn(TaskHints hints, Function&& func, - StopToken stop_token = StopToken::Unstoppable()) { + Status Spawn(TaskHints hints, Function&& func) { + return SpawnReal(hints, std::forward(func), StopToken::Unstoppable(), + StopCallback{}); + } + template + Status Spawn(TaskHints hints, Function&& func, StopToken stop_token) { return SpawnReal(hints, std::forward(func), std::move(stop_token), StopCallback{}); } + template + Status Spawn(TaskHints hints, Function&& func, StopToken stop_token, + StopCallback stop_callback) { + return SpawnReal(hints, std::forward(func), std::move(stop_token), + std::move(stop_callback)); + } // Transfers a future to this executor. Any continuations added to the // returned future will run in this executor. 
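A hypothetical sketch of the disambiguated Spawn() overloads (DoWork, the enclosing function, and the cancel.h include are assumed; all overloads forward to SpawnReal):

#include "arrow/status.h"
#include "arrow/util/cancel.h"
#include "arrow/util/thread_pool.h"

arrow::Status SpawnSketch(arrow::internal::Executor* exec, void (*DoWork)()) {
  // Fire-and-forget, unstoppable.
  ARROW_RETURN_NOT_OK(exec->Spawn([DoWork] { DoWork(); }));

  // Cancellable via a StopToken (in real code stop_source must outlive the
  // spawned tasks).
  arrow::StopSource stop_source;
  ARROW_RETURN_NOT_OK(exec->Spawn([DoWork] { DoWork(); }, stop_source.token()));

  // Hints plus token plus a callback invoked if the task is stopped.
  ARROW_RETURN_NOT_OK(exec->Spawn(
      arrow::internal::TaskHints{}, [DoWork] { DoWork(); }, stop_source.token(),
      [](const arrow::Status& st) { /* observe the cancellation status */ }));
  return arrow::Status::OK();
}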
Otherwise they would run @@ -237,7 +251,7 @@ class ARROW_EXPORT SerialExecutor : public Executor { template using TopLevelTask = internal::FnOnce(Executor*)>; - ~SerialExecutor(); + ~SerialExecutor() override; int GetCapacity() override { return 1; }; Status SpawnReal(TaskHints hints, FnOnce task, StopToken, diff --git a/dev/archery/archery/lang/cpp.py b/dev/archery/archery/lang/cpp.py index 045d23b56b1..c2b1ca68001 100644 --- a/dev/archery/archery/lang/cpp.py +++ b/dev/archery/archery/lang/cpp.py @@ -42,7 +42,7 @@ def __init__(self, cc=None, cxx=None, cxx_flags=None, build_type=None, warn_level=None, cpp_package_prefix=None, install_prefix=None, use_conda=None, - build_static=False, build_shared=True, + build_static=False, build_shared=True, build_unity=True, # tests & examples with_tests=None, with_benchmarks=None, with_examples=None, with_integration=None, @@ -76,6 +76,7 @@ def __init__(self, self._use_conda = use_conda self.build_static = build_static self.build_shared = build_shared + self.build_unity = build_unity self.with_tests = with_tests self.with_benchmarks = with_benchmarks @@ -176,7 +177,6 @@ def _gen_defs(self): yield ("CMAKE_EXPORT_COMPILE_COMMANDS", truthifier(True)) yield ("CMAKE_BUILD_TYPE", self.build_type) - yield ("CMAKE_UNITY_BUILD", True) if not self.with_lint_only: yield ("BUILD_WARNING_LEVEL", @@ -195,6 +195,7 @@ def _gen_defs(self): yield ("ARROW_BUILD_STATIC", truthifier(self.build_static)) yield ("ARROW_BUILD_SHARED", truthifier(self.build_shared)) + yield ("CMAKE_UNITY_BUILD", truthifier(self.build_unity)) # Tests and benchmarks yield ("ARROW_BUILD_TESTS", truthifier(self.with_tests))
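Finally, a hypothetical sketch (mirroring the pattern in the ScheduleIfDifferentExecutor test above; the function name is assumed) of wiring ShouldSchedule::IfDifferentExecutor into CallbackOptions:

#include "arrow/util/future.h"
#include "arrow/util/thread_pool.h"

// IfDifferentExecutor spawns the callback onto `executor` only when the
// thread completing the future is not already owned by it, avoiding a
// redundant re-queue.
void TransferSketch() {
  arrow::internal::Executor* executor = arrow::internal::GetCpuThreadPool();

  arrow::CallbackOptions options;
  options.executor = executor;
  options.should_schedule = arrow::ShouldSchedule::IfDifferentExecutor;

  auto fut = arrow::Future<>::Make();
  fut.AddCallback(
      [](const arrow::Status&) {
        // Runs on `executor`; scheduled as a new task only if MarkFinished()
        // happened on some other thread.
      },
      options);
  fut.MarkFinished();  // completes on this thread -> callback is spawned
}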