From 9ec897874e33ac108712de2821198b6a7b2dc346 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 17 Mar 2021 17:10:30 +0100 Subject: [PATCH 1/9] ARROW-11928: [C++] Execution engine API --- ci/appveyor-cpp-build.bat | 1 + ci/scripts/cpp_build.sh | 1 + cpp/CMakeLists.txt | 4 + cpp/cmake_modules/DefineOptions.cmake | 2 + cpp/src/arrow/CMakeLists.txt | 9 + cpp/src/arrow/compute/type_fwd.h | 1 + cpp/src/arrow/engine/CMakeLists.txt | 25 ++ cpp/src/arrow/engine/api.h | 20 ++ cpp/src/arrow/engine/exec_plan.cc | 211 ++++++++++++++ cpp/src/arrow/engine/exec_plan.h | 245 ++++++++++++++++ cpp/src/arrow/engine/plan_test.cc | 391 ++++++++++++++++++++++++++ cpp/src/arrow/engine/test_util.cc | 379 +++++++++++++++++++++++++ cpp/src/arrow/engine/test_util.h | 72 +++++ cpp/src/arrow/type_fwd.h | 1 + cpp/src/arrow/util/iterator.h | 6 + cpp/src/arrow/util/iterator_test.cc | 3 + 16 files changed, 1371 insertions(+) create mode 100644 cpp/src/arrow/engine/CMakeLists.txt create mode 100644 cpp/src/arrow/engine/api.h create mode 100644 cpp/src/arrow/engine/exec_plan.cc create mode 100644 cpp/src/arrow/engine/exec_plan.h create mode 100644 cpp/src/arrow/engine/plan_test.cc create mode 100644 cpp/src/arrow/engine/test_util.cc create mode 100644 cpp/src/arrow/engine/test_util.h diff --git a/ci/appveyor-cpp-build.bat b/ci/appveyor-cpp-build.bat index 6b930939660..534f73c2d50 100644 --- a/ci/appveyor-cpp-build.bat +++ b/ci/appveyor-cpp-build.bat @@ -97,6 +97,7 @@ cmake -G "%GENERATOR%" %CMAKE_ARGS% ^ -DARROW_CXXFLAGS="%ARROW_CXXFLAGS%" ^ -DARROW_DATASET=ON ^ -DARROW_ENABLE_TIMING_TESTS=OFF ^ + -DARROW_ENGINE=ON ^ -DARROW_FLIGHT=%ARROW_BUILD_FLIGHT% ^ -DARROW_GANDIVA=%ARROW_BUILD_GANDIVA% ^ -DARROW_MIMALLOC=ON ^ diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 8a1e4f32f3a..d47a6696e8f 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -59,6 +59,7 @@ cmake -G "${CMAKE_GENERATOR:-Ninja}" \ -DARROW_CUDA=${ARROW_CUDA:-OFF} \ -DARROW_CXXFLAGS=${ARROW_CXXFLAGS:-} \ -DARROW_DATASET=${ARROW_DATASET:-ON} \ + -DARROW_ENGINE=${ARROW_ENGINE:-ON} \ -DARROW_DEPENDENCY_SOURCE=${ARROW_DEPENDENCY_SOURCE:-AUTO} \ -DARROW_EXTRA_ERROR_CONTEXT=${ARROW_EXTRA_ERROR_CONTEXT:-OFF} \ -DARROW_ENABLE_TIMING_TESTS=${ARROW_ENABLE_TIMING_TESTS:-ON} \ diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a6946403deb..a31af74f68e 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -343,6 +343,10 @@ if(ARROW_CUDA set(ARROW_IPC ON) endif() +if(ARROW_ENGINE) + set(ARROW_COMPUTE ON) +endif() + if(ARROW_DATASET) set(ARROW_COMPUTE ON) set(ARROW_FILESYSTEM ON) diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 0e92811da8c..b2423cf3c76 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -211,6 +211,8 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_DATASET "Build the Arrow Dataset Modules" OFF) + define_option(ARROW_ENGINE "Build the Arrow Execution Engine" OFF) + define_option(ARROW_FILESYSTEM "Build the Arrow Filesystem Layer" OFF) define_option(ARROW_FLIGHT diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 62ea94b8d02..01994316310 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -414,6 +414,11 @@ if(ARROW_COMPUTE) endif() endif() +if(ARROW_ENGINE) + list(APPEND ARROW_SRCS engine/exec_plan.cc) + list(APPEND ARROW_TESTING_SRCS engine/test_util.cc) +endif() + if(ARROW_FILESYSTEM) if(ARROW_HDFS) 
add_definitions(-DARROW_HDFS) @@ -679,6 +684,10 @@ if(ARROW_DATASET) add_subdirectory(dataset) endif() +if(ARROW_ENGINE) + add_subdirectory(engine) +endif() + if(ARROW_FILESYSTEM) add_subdirectory(filesystem) endif() diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 4f4393486ff..5370837f1b9 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -29,6 +29,7 @@ struct FunctionOptions; struct CastOptions; +struct ExecBatch; class ExecContext; class KernelContext; diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt new file mode 100644 index 00000000000..f34ae549df5 --- /dev/null +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Headers: top level +arrow_install_all_headers("arrow/engine") + +add_arrow_test(engine-plan-test + SOURCES + plan_test.cc + EXTRA_LABELS + engine) diff --git a/cpp/src/arrow/engine/api.h b/cpp/src/arrow/engine/api.h new file mode 100644 index 00000000000..22b7f46181f --- /dev/null +++ b/cpp/src/arrow/engine/api.h @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/engine/exec_plan.h" // IWYU pragma: export diff --git a/cpp/src/arrow/engine/exec_plan.cc b/cpp/src/arrow/engine/exec_plan.cc new file mode 100644 index 00000000000..960ac109228 --- /dev/null +++ b/cpp/src/arrow/engine/exec_plan.cc @@ -0,0 +1,211 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/exec_plan.h" + +#include + +#include "arrow/datum.h" +#include "arrow/result.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" + +namespace arrow { + +using internal::checked_cast; + +namespace engine { + +namespace { + +struct ExecPlanImpl : public ExecPlan { + ExecPlanImpl() = default; + + ~ExecPlanImpl() = default; + + void AddNode(std::unique_ptr node) { + if (node->num_inputs() == 0) { + sources_.push_back(node.get()); + } + if (node->num_outputs() == 0) { + sinks_.push_back(node.get()); + } + nodes_.push_back(std::move(node)); + } + + Status Validate() const { + if (nodes_.empty()) { + return Status::Invalid("ExecPlan has no node"); + } + for (const auto& node : nodes_) { + RETURN_NOT_OK(node->Validate()); + } + return Status::OK(); + } + + Status StartProducing() { + ARROW_ASSIGN_OR_RAISE(auto sorted_nodes, ReverseTopoSort()); + Status st; + auto it = sorted_nodes.begin(); + while (it != sorted_nodes.end() && st.ok()) { + st &= (*it++)->StartProducing(); + } + if (!st.ok()) { + // Stop nodes that successfully started, in reverse order + // (`it` now points after the node that failed starting, so need to rewind) + --it; + while (it != sorted_nodes.begin()) { + (*--it)->StopProducing(); + } + } + return st; + } + + Result ReverseTopoSort() { + struct ReverseTopoSort { + const std::vector>& nodes; + std::unordered_set visited; + std::unordered_set visiting; + NodeVector sorted; + + explicit ReverseTopoSort(const std::vector>& nodes) + : nodes(nodes) { + visited.reserve(nodes.size()); + sorted.reserve(nodes.size()); + } + + Status Sort() { + for (const auto& node : nodes) { + RETURN_NOT_OK(Visit(node.get())); + } + DCHECK_EQ(sorted.size(), nodes.size()); + DCHECK_EQ(visited.size(), nodes.size()); + DCHECK_EQ(visiting.size(), 0); + return Status::OK(); + } + + Status Visit(ExecNode* node) { + if (visited.count(node) != 0) { + return Status::OK(); + } + if (!visiting.insert(node).second) { + // Insertion failed => node is already being visited + return Status::Invalid("Cycle detected in execution plan"); + } + for (const auto& out : node->outputs()) { + RETURN_NOT_OK(Visit(out.output)); + } + visiting.erase(node); + visited.insert(node); + sorted.push_back(node); + return Status::OK(); + } + } topo_sort(nodes_); + + RETURN_NOT_OK(topo_sort.Sort()); + return std::move(topo_sort.sorted); + } + + std::vector> nodes_; + NodeVector sources_; + NodeVector sinks_; +}; + +ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast(ptr); } + +const ExecPlanImpl* ToDerived(const ExecPlan* ptr) { + return checked_cast(ptr); +} + +} // namespace + +Result> ExecPlan::Make() { + return std::make_shared(); +} + +void ExecPlan::AddNode(std::unique_ptr node) { + ToDerived(this)->AddNode(std::move(node)); +} + +const ExecPlan::NodeVector& ExecPlan::sources() const { + return ToDerived(this)->sources_; +} + +const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; } + +Status ExecPlan::Validate() { return ToDerived(this)->Validate(); } + +Status ExecPlan::StartProducing() { return 
ToDerived(this)->StartProducing(); } + +ExecNode::~ExecNode() = default; + +ExecNode::ExecNode(ExecPlan* plan, std::string label) + : plan_(plan), label_(std::move(label)) {} + +Status ExecNode::Validate() const { + if (inputs_.size() != static_cast(num_inputs())) { + return Status::Invalid("Invalid number of inputs for '", label(), "' (expected ", + num_inputs(), ", actual ", inputs_.size(), ")"); + } + if (input_descrs_.size() != static_cast(num_inputs())) { + return Status::Invalid("Invalid number of input descrs for '", label(), + "' (expected ", num_inputs(), ", actual ", + input_descrs_.size(), ")"); + } + if (outputs_.size() != static_cast(num_outputs())) { + return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ", + num_outputs(), ", actual ", outputs_.size(), ")"); + } + if (output_descrs_.size() != static_cast(num_outputs())) { + return Status::Invalid("Invalid number of output descrs for '", label(), + "' (expected ", num_outputs(), ", actual ", + output_descrs_.size(), ")"); + } + for (size_t i = 0; i < outputs_.size(); ++i) { + const auto& out = outputs_[i]; + if (out.input_index >= static_cast(out.output->inputs_.size()) || + out.input_index >= static_cast(out.output->input_descrs_.size()) || + this != out.output->inputs_[out.input_index]) { + return Status::Invalid("Output node configuration for '", label(), + "' inconsistent with input node configuration for '", + out.output->label(), "'"); + } + const auto& out_descr = output_descrs_[i]; + const auto& in_descr = out.output->input_descrs_[out.input_index]; + if (in_descr != out_descr) { + return Status::Invalid( + "Output node produces batches with type '", ValueDescr::ToString(out_descr), + "' inconsistent with input node configuration for '", out.output->label(), "'"); + } + } + return Status::OK(); +} + +void ExecNode::PauseProducing() { + for (const auto& node : inputs_) { + node->PauseProducing(); + } +} + +void ExecNode::ResumeProducing() { + for (const auto& node : inputs_) { + node->ResumeProducing(); + } +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/exec_plan.h b/cpp/src/arrow/engine/exec_plan.h new file mode 100644 index 00000000000..b28fac08369 --- /dev/null +++ b/cpp/src/arrow/engine/exec_plan.h @@ -0,0 +1,245 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" +#include "arrow/util/visibility.h" + +// NOTES: +// - ExecBatches only have arrays or scalars +// - data streams may be ordered, so add input number? 
+// - node to combine input needs to reorder
+
+namespace arrow {
+namespace engine {
+
+class ExecNode;
+
+class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this<ExecPlan> {
+ public:
+  using NodeVector = std::vector<ExecNode*>;
+
+  virtual ~ExecPlan() = default;
+
+  /// Make an empty exec plan
+  static Result<std::shared_ptr<ExecPlan>> Make();
+
+  void AddNode(std::unique_ptr<ExecNode> node);
+
+  /// The initial inputs
+  const NodeVector& sources() const;
+
+  /// The final outputs
+  const NodeVector& sinks() const;
+
+  // XXX API question:
+  // There are clearly two phases in the ExecPlan lifecycle:
+  // - one construction phase where AddNode() and ExecNode::Bind() is called
+  //   (with optional validation at the end)
+  // - one execution phase where the nodes are topo-sorted and then started
+  //
+  // => Should we separate out those APIs? e.g. have an ExecPlanBuilder
+  // for the first phase.
+
+  Status Validate();
+
+  /// Start producing on all nodes
+  ///
+  /// Nodes are started in reverse topological order, such that any node
+  /// is started before all of its inputs.
+  Status StartProducing();
+
+  // XXX should we also have `void StopProducing()`?
+
+ protected:
+  ExecPlan() = default;
+};
+
+class ARROW_EXPORT ExecNode {
+ public:
+  struct OutputNode {
+    ExecNode* output;
+    // Index of corresponding input in `output` node
+    int input_index;
+  };
+
+  using NodeVector = std::vector<ExecNode*>;
+  using OutputNodeVector = std::vector<OutputNode>;
+  using BatchDescr = std::vector<ValueDescr>;
+
+  virtual ~ExecNode();
+
+  virtual const char* kind_name() = 0;
+  // The number of inputs and outputs expected by this node
+  // XXX should these simply return `input_descrs_.size()`
+  // (`output_descrs_.size()` respectively)?
+  virtual int num_inputs() const = 0;
+  virtual int num_outputs() const = 0;
+
+  /// This node's predecessors in the exec plan
+  const NodeVector& inputs() const { return inputs_; }
+
+  /// The datatypes for each input
+  // XXX Should it be std::vector?
+  const std::vector<BatchDescr>& input_descrs() const { return input_descrs_; }
+
+  /// This node's successors in the exec plan
+  const OutputNodeVector& outputs() const { return outputs_; }
+
+  /// The datatypes for each output
+  // XXX Should it be std::vector?
+  const std::vector<BatchDescr>& output_descrs() const { return output_descrs_; }
+
+  /// This node's exec plan
+  ExecPlan* plan() { return plan_; }
+  std::shared_ptr<ExecPlan> plan_ref() { return plan_->shared_from_this(); }
+
+  /// \brief An optional label, for display and debugging
+  ///
+  /// There is no guarantee that this value is non-empty or unique.
+  const std::string& label() const { return label_; }
+
+  int AddInput(ExecNode* node) {
+    inputs_.push_back(node);
+    return static_cast<int>(inputs_.size() - 1);
+  }
+
+  void AddOutput(ExecNode* node, int input_index) {
+    outputs_.push_back({node, input_index});
+  }
+
+  static void Bind(ExecNode* input, ExecNode* output) {
+    input->AddOutput(output, output->AddInput(input));
+  }
+
+  Status Validate() const;
+
+  /// Upstream API:
+  /// These functions are called by input nodes that want to inform this node
+  /// about an updated condition (a new input batch, an error, an impending
+  /// end of stream).
+  ///
+  /// Implementation rules:
+  /// - these may be called anytime after StartProducing() has succeeded
+  ///   (and even during or after StopProducing())
+  /// - these may be called concurrently
+  /// - these are allowed to call back into PauseProducing(), ResumeProducing()
+  ///   and StopProducing()
+
+  /// Transfer input batch to ExecNode
+  virtual void InputReceived(int input_index, int seq_num, compute::ExecBatch batch) = 0;
+
+  /// Signal error to ExecNode
+  virtual void ErrorReceived(int input_index, Status error) = 0;
+
+  /// Mark the inputs finished after the given number of batches.
+  ///
+  /// This may be called before all inputs are received. This simply fixes
+  /// the total number of incoming batches for an input, so that the ExecNode
+  /// knows when it has received all input, regardless of order.
+  virtual void InputFinished(int input_index, int seq_stop) = 0;
+
+  /// Lifecycle API:
+  /// - start / stop to initiate and terminate production
+  /// - pause / resume to apply backpressure
+  ///
+  /// Implementation rules:
+  /// - StartProducing() should not recurse into the inputs, as it is
+  ///   handled by ExecPlan::StartProducing()
+  /// - PauseProducing(), ResumeProducing(), StopProducing() may be called
+  ///   concurrently (but only after StartProducing() has returned successfully)
+  /// - PauseProducing(), ResumeProducing(), StopProducing() may be called
+  ///   by the downstream nodes' InputReceived(), ErrorReceived(), InputFinished()
+  ///   methods
+  /// - StopProducing() should recurse into the inputs
+  /// - StopProducing() must be idempotent
+
+  // XXX What happens if StartProducing() calls an output's InputReceived()
+  // synchronously, and InputReceived() decides to call back into StopProducing()
+  // (or PauseProducing()) because it received enough data?
+  //
+  // Right now, since synchronous calls happen in both directions (input to
+  // output and then output to input), a node must be careful to be reentrant
+  // against synchronous calls from its output, *and* also concurrent calls from
+  // other threads. The most reliable solution is to update the internal state
+  // first, and notify outputs only at the end.
+  //
+  // Alternate rules:
+  // - StartProducing(), ResumeProducing() can call synchronously into
+  //   their outputs' consuming methods (InputReceived() etc.)
+  // - InputReceived(), ErrorReceived(), InputFinished() can call asynchronously
+  //   into their inputs' PauseProducing(), StopProducing()
+  //
+  // Alternate API:
+  // - InputReceived(), ErrorReceived(), InputFinished() return a ProductionHint
+  //   enum: either None (default), PauseProducing, ResumeProducing, StopProducing
+  // - A method allows passing a ProductionHint asynchronously from an output node
+  //   (replacing PauseProducing(), ResumeProducing(), StopProducing())
+
+  // TODO PauseProducing() etc. should probably take the index of the output which calls
+  // them?
+
+  /// \brief Start producing
+  ///
+  /// This must only be called once. If this fails, then other lifecycle
+  /// methods must not be called.
+  ///
+  /// This is typically called automatically by ExecPlan::StartProducing().
+  virtual Status StartProducing() = 0;
+
+  /// \brief Pause producing temporarily
+  ///
+  /// This call is a hint that an output node is currently not willing
+  /// to receive data.
+  ///
+  /// This may be called any number of times after StartProducing() succeeds.
+  /// However, the node is still free to produce data (which may be difficult
+  /// to prevent anyway if data is produced using multiple threads).
+ virtual void PauseProducing(); + + /// \brief Resume producing after a temporary pause + /// + /// This call is a hint that an output node is willing to receive data again. + /// + /// This may be called any number of times after StartProducing() succeeds. + /// This may also be called concurrently with PauseProducing(), which suggests + /// the implementation may use an atomic counter. + virtual void ResumeProducing(); + + /// \brief Stop producing definitively + virtual void StopProducing() = 0; + + protected: + ExecNode(ExecPlan* plan, std::string label); + + ExecPlan* plan_; + std::string label_; + NodeVector inputs_; + OutputNodeVector outputs_; + std::vector input_descrs_; + std::vector output_descrs_; +}; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/plan_test.cc b/cpp/src/arrow/engine/plan_test.cc new file mode 100644 index 00000000000..282b3f7b395 --- /dev/null +++ b/cpp/src/arrow/engine/plan_test.cc @@ -0,0 +1,391 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
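// A minimal usage sketch of the ExecPlan/ExecNode API from exec_plan.h, as
// exercised by the tests below (using the MakeDummyNode() helper declared in
// test_util.h): a plan is built, validated, then started.
//
//   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
//   ExecNode* source = MakeDummyNode(plan.get(), "source",
//                                    /*num_inputs=*/0, /*num_outputs=*/1);
//   ExecNode* sink = MakeDummyNode(plan.get(), "sink",
//                                  /*num_inputs=*/1, /*num_outputs=*/0);
//   ExecNode::Bind(source, sink);       // wire source's output to sink's input
//   ASSERT_OK(plan->Validate());        // arities and batch descriptors must agree
//   ASSERT_OK(plan->StartProducing());  // sinks start first, then their inputs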
+ +#include + +#include +#include + +#include "arrow/engine/exec_plan.h" +#include "arrow/engine/test_util.h" +#include "arrow/record_batch.h" +#include "arrow/testing/future_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/logging.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { + +using internal::Executor; + +namespace engine { + +void AssertBatchesEqual(const RecordBatchVector& expected, + const RecordBatchVector& actual) { + ASSERT_EQ(expected.size(), actual.size()); + for (size_t i = 0; i < expected.size(); ++i) { + AssertBatchesEqual(*expected[i], *actual[i]); + } +} + +TEST(ExecPlanConstruction, Empty) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + ASSERT_RAISES(Invalid, plan->Validate()); +} + +TEST(ExecPlanConstruction, SingleNode) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/0); + ASSERT_OK(plan->Validate()); + ASSERT_THAT(plan->sources(), ::testing::ElementsAre(node)); + ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(node)); + + ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); + node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/0); + // Input not bound + ASSERT_RAISES(Invalid, plan->Validate()); + + ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); + node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/1); + // Output not bound + ASSERT_RAISES(Invalid, plan->Validate()); +} + +TEST(ExecPlanConstruction, SourceSink) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); + // Input / output not bound + ASSERT_RAISES(Invalid, plan->Validate()); + + ExecNode::Bind(source, sink); + ASSERT_OK(plan->Validate()); + ASSERT_THAT(plan->sources(), ::testing::ElementsAre(source)); + ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(sink)); +} + +TEST(ExecPlanConstruction, MultipleNode) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto source1 = + MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2); + auto source2 = + MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1); + auto process1 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2); + auto process2 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/1); + ExecNode::Bind(source1, process1); + ExecNode::Bind(source1, process2); + ExecNode::Bind(source2, process2); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1); + ExecNode::Bind(process1, process3); + ExecNode::Bind(process1, process3); + ExecNode::Bind(process2, process3); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); + ExecNode::Bind(process3, sink); + + ASSERT_OK(plan->Validate()); + ASSERT_THAT(plan->sources(), ::testing::ElementsAre(source1, source2)); + ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(sink)); +} + +struct StartStopTracker { + std::vector started; + std::vector stopped; + + StartProducingFunc start_producing_func(Status st = Status::OK()) { + return [this, st](ExecNode* node) { + started.push_back(node->label()); + return st; + }; + } + + StopProducingFunc stop_producing_func() { + return [this](ExecNode* node) { stopped.push_back(node->label()); }; + } +}; + +TEST(ExecPlan, 
DummyStartProducing) { + StartStopTracker t; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2, + t.start_producing_func(), t.stop_producing_func()); + auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + auto process1 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2, + t.start_producing_func(), t.stop_producing_func()); + auto process2 = + MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + ExecNode::Bind(source1, process1); + ExecNode::Bind(process1, process2); + ExecNode::Bind(source2, process2); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + ExecNode::Bind(process1, process3); + ExecNode::Bind(source1, process3); + ExecNode::Bind(process2, process3); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0, + t.start_producing_func(), t.stop_producing_func()); + ExecNode::Bind(process3, sink); + + ASSERT_OK(plan->Validate()); + ASSERT_EQ(t.started.size(), 0); + ASSERT_EQ(t.stopped.size(), 0); + + ASSERT_OK(plan->StartProducing()); + // Note that any correct reverse topological order may do + ASSERT_THAT(t.started, ::testing::ElementsAre("sink", "process3", "process2", + "process1", "source1", "source2")); + ASSERT_EQ(t.stopped.size(), 0); +} + +TEST(ExecPlan, DummyStartProducingCycle) { + // A trivial cycle + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/1); + ExecNode::Bind(node, node); + ASSERT_OK(plan->Validate()); + ASSERT_RAISES(Invalid, plan->StartProducing()); + + // A less trivial one + ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); + auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); + auto process1 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/2); + auto process2 = + MakeDummyNode(plan.get(), "process2", /*num_inputs=*/1, /*num_outputs=*/1); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/2, /*num_outputs=*/2); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); + ExecNode::Bind(source, process1); + ExecNode::Bind(process1, process2); + ExecNode::Bind(process2, process3); + ExecNode::Bind(process1, process3); + ExecNode::Bind(process3, process1); + ExecNode::Bind(process3, sink); + ASSERT_OK(plan->Validate()); + ASSERT_RAISES(Invalid, plan->StartProducing()); +} + +TEST(ExecPlan, DummyStartProducingError) { + StartStopTracker t; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2, + t.start_producing_func(Status::NotImplemented("zzz")), + t.stop_producing_func()); + auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + auto process1 = MakeDummyNode( + plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2, + t.start_producing_func(Status::IOError("xxx")), t.stop_producing_func()); + auto process2 = + MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + ExecNode::Bind(source1, process1); 
+ ExecNode::Bind(process1, process2); + ExecNode::Bind(source2, process2); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + ExecNode::Bind(process1, process3); + ExecNode::Bind(source1, process3); + ExecNode::Bind(process2, process3); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0, + t.start_producing_func(), t.stop_producing_func()); + ExecNode::Bind(process3, sink); + + ASSERT_OK(plan->Validate()); + ASSERT_EQ(t.started.size(), 0); + ASSERT_EQ(t.stopped.size(), 0); + + // `process1` raises IOError + ASSERT_RAISES(IOError, plan->StartProducing()); + ASSERT_THAT(t.started, + ::testing::ElementsAre("sink", "process3", "process2", "process1")); + // Nodes that started successfully were stopped in reverse order + ASSERT_THAT(t.stopped, ::testing::ElementsAre("process2", "process3", "sink")); +} + +// TODO move this to gtest_util.h? + +class SlowRecordBatchReader : public RecordBatchReader { + public: + explicit SlowRecordBatchReader(std::shared_ptr reader) + : reader_(std::move(reader)) {} + + std::shared_ptr schema() const override { return reader_->schema(); } + + Status ReadNext(std::shared_ptr* batch) override { + SleepABit(); + return reader_->ReadNext(batch); + } + + static Result> Make( + RecordBatchVector batches, std::shared_ptr schema = nullptr) { + ARROW_ASSIGN_OR_RAISE(auto reader, + RecordBatchReader::Make(std::move(batches), std::move(schema))); + return std::make_shared(std::move(reader)); + } + + protected: + std::shared_ptr reader_; +}; + +static Result MakeSlowRecordBatchGenerator( + RecordBatchVector batches, std::shared_ptr schema) { + auto gen = MakeVectorGenerator(batches); + // TODO move this into testing/async_generator_util.h? 
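  // The mapping below wraps each batch in a Future that is only marked
  // finished after an asynchronous sleep, so each pull from the resulting
  // generator completes with a delay, emulating a slow I/O source.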
+ auto delayed_gen = MakeMappedGenerator>( + std::move(gen), [](const std::shared_ptr& batch) { + auto fut = Future>::Make(); + SleepABitAsync().AddCallback( + [fut, batch](const Result<::arrow::detail::Empty>&) mutable { + fut.MarkFinished(batch); + }); + return fut; + }); + // Adding readahead implicitly adds parallelism by pulling reentrantly from + // the delayed generator + return MakeReadaheadGenerator(std::move(delayed_gen), /*max_readahead=*/64); +} + +class TestExecPlanExecution : public ::testing::Test { + public: + void SetUp() override { + ASSERT_OK_AND_ASSIGN(io_executor_, internal::ThreadPool::Make(8)); + } + + RecordBatchVector MakeRandomBatches(const std::shared_ptr& schema, + int num_batches = 10, int batch_size = 4) { + random::RandomArrayGenerator rng(42); + RecordBatchVector batches; + batches.reserve(num_batches); + for (int i = 0; i < num_batches; ++i) { + batches.push_back(rng.BatchOf(schema->fields(), batch_size)); + } + return batches; + } + + struct CollectorPlan { + std::shared_ptr plan; + RecordBatchCollectNode* sink; + }; + + Result MakeSourceSink(std::shared_ptr reader, + const std::shared_ptr& schema) { + ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make()); + auto source = + MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); + auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", schema); + ExecNode::Bind(source, sink); + return CollectorPlan{plan, sink}; + } + + Result MakeSourceSink(RecordBatchGenerator generator, + const std::shared_ptr& schema) { + ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make()); + auto source = MakeRecordBatchReaderNode(plan.get(), "source", schema, generator, + io_executor_.get()); + auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", schema); + ExecNode::Bind(source, sink); + return CollectorPlan{plan, sink}; + } + + Result MakeSourceSink(const RecordBatchVector& batches, + const std::shared_ptr& schema) { + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make(batches, schema)); + return MakeSourceSink(std::move(reader), schema); + } + + Result StartAndCollect(ExecPlan* plan, + RecordBatchCollectNode* sink) { + RETURN_NOT_OK(plan->StartProducing()); + auto fut = CollectAsyncGenerator(sink->generator()); + return fut.result(); + } + + template + void TestSourceSink(RecordBatchReaderFactory batch_factory) { + auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); + // clang-format off + RecordBatchVector batches{ + RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, + {"a": 4, "b": false}])"), + RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, + {"a": 6, "b": false}, + {"a": 7, "b": false}])") + }; + // clang-format on + + ASSERT_OK_AND_ASSIGN(auto reader, batch_factory(batches, schema)); + ASSERT_OK_AND_ASSIGN(auto cp, MakeSourceSink(reader, schema)); + ASSERT_OK(cp.plan->Validate()); + + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(cp.plan.get(), cp.sink)); + AssertBatchesEqual(batches, got_batches); + } + + template + void TestStressSourceSink(int num_batches, RecordBatchReaderFactory batch_factory) { + auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); + auto batches = MakeRandomBatches(schema, num_batches); + + ASSERT_OK_AND_ASSIGN(auto reader, batch_factory(batches, schema)); + ASSERT_OK_AND_ASSIGN(auto cp, MakeSourceSink(reader, schema)); + ASSERT_OK(cp.plan->Validate()); + + ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(cp.plan.get(), cp.sink)); + AssertBatchesEqual(batches, got_batches); + } + + protected: + 
std::shared_ptr io_executor_; +}; + +TEST_F(TestExecPlanExecution, SourceSink) { TestSourceSink(RecordBatchReader::Make); } + +TEST_F(TestExecPlanExecution, SlowSourceSink) { + TestSourceSink(SlowRecordBatchReader::Make); +} + +TEST_F(TestExecPlanExecution, SlowSourceSinkParallel) { + TestSourceSink(MakeSlowRecordBatchGenerator); +} + +TEST_F(TestExecPlanExecution, StressSourceSink) { + TestStressSourceSink(/*num_batches=*/200, RecordBatchReader::Make); +} + +TEST_F(TestExecPlanExecution, StressSlowSourceSink) { + // This doesn't create parallelism as the RecordBatchReader is iterated serially. + TestStressSourceSink(/*num_batches=*/30, SlowRecordBatchReader::Make); +} + +TEST_F(TestExecPlanExecution, StressSlowSourceSinkParallel) { + TestStressSourceSink(/*num_batches=*/300, MakeSlowRecordBatchGenerator); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/test_util.cc b/cpp/src/arrow/engine/test_util.cc new file mode 100644 index 00000000000..001517d708d --- /dev/null +++ b/cpp/src/arrow/engine/test_util.cc @@ -0,0 +1,379 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/test_util.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/compute/exec.h" +#include "arrow/datum.h" +#include "arrow/engine/exec_plan.h" +#include "arrow/record_batch.h" +#include "arrow/type.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/iterator.h" +#include "arrow/util/logging.h" +#include "arrow/util/optional.h" + +namespace arrow { + +using internal::Executor; + +namespace engine { +namespace { + +// TODO expose this as `static ValueDescr::FromSchemaColumns`? 
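// Map each field of `schema` to an array-shaped ValueDescr of the field's
// type, i.e. the batch descriptor for record batches of this schema.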
+std::vector DescrFromSchemaColumns(const Schema& schema) { + std::vector descr; + descr.reserve(schema.num_fields()); + std::transform(schema.fields().begin(), schema.fields().end(), + std::back_inserter(descr), [](const std::shared_ptr& field) { + return ValueDescr::Array(field->type()); + }); + return descr; +} + +struct DummyNode : ExecNode { + DummyNode(ExecPlan* plan, std::string label, int num_inputs, int num_outputs, + StartProducingFunc start_producing, StopProducingFunc stop_producing) + : ExecNode(plan, std::move(label)), + num_inputs_(num_inputs), + num_outputs_(num_outputs), + start_producing_(std::move(start_producing)), + stop_producing_(std::move(stop_producing)) { + input_descrs_.assign(num_inputs, descr()); + output_descrs_.assign(num_outputs, descr()); + } + + const char* kind_name() override { return "RecordBatchReader"; } + + int num_inputs() const override { return num_inputs_; } + + int num_outputs() const override { return num_outputs_; } + + void InputReceived(int input_index, int seq_num, compute::ExecBatch batch) override {} + + void ErrorReceived(int input_index, Status error) override {} + + void InputFinished(int input_index, int seq_stop) override {} + + Status StartProducing() override { + if (start_producing_) { + RETURN_NOT_OK(start_producing_(this)); + } + started_ = true; + return Status::OK(); + } + + void StopProducing() override { + if (started_) { + started_ = false; + for (const auto& input : inputs_) { + input->StopProducing(); + } + if (stop_producing_) { + stop_producing_(this); + } + } + } + + private: + BatchDescr descr() const { return std::vector{ValueDescr(null())}; } + + int num_inputs_; + int num_outputs_; + StartProducingFunc start_producing_; + StopProducingFunc stop_producing_; + bool started_ = false; +}; + +struct RecordBatchReaderNode : ExecNode { + RecordBatchReaderNode(ExecPlan* plan, std::string label, + std::shared_ptr reader, Executor* io_executor) + : ExecNode(plan, std::move(label)), + schema_(reader->schema()), + reader_(std::move(reader)), + io_executor_(io_executor) { + output_descrs_.push_back(DescrFromSchemaColumns(*schema_)); + } + + RecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr schema, + RecordBatchGenerator generator, Executor* io_executor) + : ExecNode(plan, std::move(label)), + schema_(std::move(schema)), + io_executor_(io_executor), + generator_(std::move(generator)) { + output_descrs_.push_back(DescrFromSchemaColumns(*schema_)); + } + + const char* kind_name() override { return "RecordBatchReader"; } + + int num_inputs() const override { return 0; } + + int num_outputs() const override { return 1; } + + void InputReceived(int input_index, int seq_num, compute::ExecBatch batch) override {} + + void ErrorReceived(int input_index, Status error) override {} + + void InputFinished(int input_index, int seq_stop) override {} + + Status StartProducing() override { + next_batch_index_ = 0; + if (!generator_) { + auto it = MakePointerIterator(reader_.get()); + ARROW_ASSIGN_OR_RAISE(generator_, + MakeBackgroundGenerator(std::move(it), io_executor_)); + } + GenerateOne(std::unique_lock{mutex_}); + return Status::OK(); + } + + void StopProducing() override { + std::unique_lock lock(mutex_); + generator_ = nullptr; // null function + } + + // TODO implement PauseProducing / ResumeProducing + + private: + void GenerateOne(std::unique_lock&& lock) { + if (!generator_) { + // Stopped + return; + } + auto plan = plan_ref(); + auto fut = generator_(); + const auto batch_index = next_batch_index_++; + + 
lock.unlock(); + // TODO we want to transfer always here + io_executor_->Transfer(std::move(fut)) + .AddCallback( + [plan, batch_index, this](const Result>& res) { + std::unique_lock lock(mutex_); + DCHECK_EQ(outputs_.size(), 1); + OutputNode* out = &outputs_[0]; + if (!res.ok()) { + out->output->ErrorReceived(out->input_index, res.status()); + return; + } + const auto& batch = *res; + if (IsIterationEnd(batch)) { + lock.unlock(); + out->output->InputFinished(out->input_index, batch_index); + } else { + lock.unlock(); + out->output->InputReceived(out->input_index, batch_index, + compute::ExecBatch(*batch)); + lock.lock(); + GenerateOne(std::move(lock)); + } + }); + } + + const std::shared_ptr schema_; + const std::shared_ptr reader_; + Executor* const io_executor_; + + std::mutex mutex_; + RecordBatchGenerator generator_; + int next_batch_index_; +}; + +struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { + RecordBatchCollectNodeImpl(ExecPlan* plan, std::string label, + const std::shared_ptr& schema) + : RecordBatchCollectNode(plan, std::move(label)), schema_(schema) { + input_descrs_.push_back(DescrFromSchemaColumns(*schema_)); + } + + RecordBatchGenerator generator() override { return generator_; } + + const char* kind_name() override { return "RecordBatchReader"; } + + int num_inputs() const override { return 1; } + + int num_outputs() const override { return 0; } + + Status StartProducing() override { + num_received_ = 0; + num_emitted_ = 0; + emit_stop_ = -1; + stopped_ = false; + producer_.emplace(generator_.producer()); + return Status::OK(); + } + + void StopProducing() override { + std::unique_lock lock(mutex_); + StopProducing(&lock); + } + + void InputReceived(int input_index, int seq_num, + compute::ExecBatch exec_batch) override { + std::unique_lock lock(mutex_); + if (stopped_) { + return; + } + auto maybe_batch = MakeBatch(std::move(exec_batch)); + if (!maybe_batch.ok()) { + lock.unlock(); + producer_->Push(std::move(maybe_batch)); + return; + } + + // TODO would be nice to factor this out in a ReorderQueue + auto batch = *std::move(maybe_batch); + if (seq_num <= static_cast(received_batches_.size())) { + received_batches_.resize(seq_num + 1, nullptr); + } + DCHECK_EQ(received_batches_[seq_num], nullptr); + received_batches_[seq_num] = std::move(batch); + ++num_received_; + + if (seq_num != num_emitted_) { + // Cannot emit yet as there is a hole at `num_emitted_` + DCHECK_GT(seq_num, num_emitted_); + DCHECK_EQ(received_batches_[num_emitted_], nullptr); + return; + } + if (num_received_ == emit_stop_) { + StopProducing(&lock); + } + + // Emit batches in order as far as possible + // First collect these batches, then unlock before producing. + const auto seq_start = seq_num; + while (seq_num < static_cast(received_batches_.size()) && + received_batches_[seq_num] != nullptr) { + ++seq_num; + } + DCHECK_GT(seq_num, seq_start); + // By moving the values now, we make sure another thread won't emit the same values + // below + RecordBatchVector to_emit( + std::make_move_iterator(received_batches_.begin() + seq_start), + std::make_move_iterator(received_batches_.begin() + seq_num)); + + lock.unlock(); + for (auto&& batch : to_emit) { + producer_->Push(std::move(batch)); + } + lock.lock(); + + DCHECK_EQ(seq_start, num_emitted_); // num_emitted_ wasn't bumped in the meantime + num_emitted_ = seq_num; + } + + void ErrorReceived(int input_index, Status error) override { + // XXX do we care about properly sequencing the error? 
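    // As written, the error is pushed to the output generator immediately and
    // is not ordered relative to batches still waiting in received_batches_.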
+ producer_->Push(std::move(error)); + StopProducing(); + } + + void InputFinished(int input_index, int seq_stop) override { + std::unique_lock lock(mutex_); + DCHECK_GE(seq_stop, static_cast(received_batches_.size())); + received_batches_.reserve(seq_stop); + emit_stop_ = seq_stop; + if (emit_stop_ == num_received_) { + DCHECK_EQ(emit_stop_, num_emitted_); + StopProducing(&lock); + } + } + + private: + void StopProducing(std::unique_lock* lock) { + if (!stopped_) { + stopped_ = true; + producer_->Close(); + inputs_[0]->StopProducing(); + } + } + + // TODO factor this out as ExecBatch::ToRecordBatch()? + Result> MakeBatch(compute::ExecBatch&& exec_batch) { + ArrayDataVector columns; + columns.reserve(exec_batch.values.size()); + for (auto&& value : exec_batch.values) { + if (!value.is_array()) { + return Status::TypeError("Expected array input"); + } + columns.push_back(std::move(value).array()); + } + return RecordBatch::Make(schema_, exec_batch.length, std::move(columns)); + } + + const std::shared_ptr schema_; + + std::mutex mutex_; + RecordBatchVector received_batches_; + int num_received_; + int num_emitted_; + int emit_stop_; + bool stopped_; + + PushGenerator> generator_; + util::optional>::Producer> producer_; +}; + +} // namespace + +ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, + std::shared_ptr reader, + Executor* io_executor) { + auto ptr = + new RecordBatchReaderNode(plan, std::move(label), std::move(reader), io_executor); + plan->AddNode(std::unique_ptr{ptr}); + return ptr; +} + +ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, + std::shared_ptr schema, + RecordBatchGenerator generator, + ::arrow::internal::Executor* io_executor) { + auto ptr = new RecordBatchReaderNode(plan, std::move(label), std::move(schema), + std::move(generator), io_executor); + plan->AddNode(std::unique_ptr{ptr}); + return ptr; +} + +ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, + int num_outputs, StartProducingFunc start_producing, + StopProducingFunc stop_producing) { + auto ptr = new DummyNode(plan, std::move(label), num_inputs, num_outputs, + std::move(start_producing), std::move(stop_producing)); + plan->AddNode(std::unique_ptr{ptr}); + return ptr; +} + +RecordBatchCollectNode* MakeRecordBatchCollectNode( + ExecPlan* plan, std::string label, const std::shared_ptr& schema) { + auto ptr = new RecordBatchCollectNodeImpl(plan, std::move(label), schema); + plan->AddNode(std::unique_ptr{ptr}); + return ptr; +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/test_util.h b/cpp/src/arrow/engine/test_util.h new file mode 100644 index 00000000000..e5d19859b26 --- /dev/null +++ b/cpp/src/arrow/engine/test_util.h @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include + +#include "arrow/engine/exec_plan.h" +#include "arrow/record_batch.h" +#include "arrow/testing/visibility.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/type_fwd.h" + +namespace arrow { +namespace engine { + +using StartProducingFunc = std::function; +using StopProducingFunc = std::function; + +// Make a dummy node that has no execution behaviour +ARROW_TESTING_EXPORT +ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, + int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); + +using RecordBatchGenerator = AsyncGenerator>; + +// Make a source node that produces record batches by reading in the background +// from a RecordBatchReader. +// 0 input +// 1 output (N columns) +ARROW_TESTING_EXPORT +ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, + std::shared_ptr reader, + ::arrow::internal::Executor* io_executor); + +ARROW_TESTING_EXPORT +ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, + std::shared_ptr schema, + RecordBatchGenerator generator, + ::arrow::internal::Executor* io_executor); + +class RecordBatchCollectNode : public ExecNode { + public: + virtual RecordBatchGenerator generator() = 0; + + protected: + using ExecNode::ExecNode; +}; + +ARROW_TESTING_EXPORT +RecordBatchCollectNode* MakeRecordBatchCollectNode(ExecPlan* plan, std::string label, + const std::shared_ptr& schema); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 7eb318c8b41..d541209a314 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -81,6 +81,7 @@ class RecordBatchReader; class Table; struct Datum; +struct ValueDescr; using ChunkedArrayVector = std::vector>; using RecordBatchVector = std::vector>; diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h index b82021e4b21..97ad0d73c35 100644 --- a/cpp/src/arrow/util/iterator.h +++ b/cpp/src/arrow/util/iterator.h @@ -370,6 +370,12 @@ Iterator MakeErrorIterator(Status s) { }); } +template ().Next())::ValueType> +Iterator MakePointerIterator(It* it) { + return MakeFunctionIterator([it]() -> Result { return it->Next(); }); +} + /// \brief Simple iterator which yields the elements of a std::vector template class VectorIterator { diff --git a/cpp/src/arrow/util/iterator_test.cc b/cpp/src/arrow/util/iterator_test.cc index 60b57dea1e2..dc2a2398729 100644 --- a/cpp/src/arrow/util/iterator_test.cc +++ b/cpp/src/arrow/util/iterator_test.cc @@ -31,6 +31,9 @@ #include "arrow/util/iterator.h" #include "arrow/util/test_common.h" #include "arrow/util/vector.h" + +// TODO add test for MakePointerIterator + namespace arrow { template From 1b7cbf47140e687b8a0ee041595a372fd6c680ff Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 6 May 2021 13:44:31 -0400 Subject: [PATCH 2/9] remove engine component --- cpp/cmake_modules/DefineOptions.cmake | 2 -- cpp/src/arrow/CMakeLists.txt | 11 +++----- cpp/src/arrow/compute/exec/CMakeLists.txt | 2 ++ .../{engine => compute/exec}/exec_plan.cc | 2 +- .../{engine => compute/exec}/exec_plan.h | 0 .../{engine => compute/exec}/plan_test.cc | 4 +-- .../{engine => compute/exec}/test_util.cc | 4 +-- .../{engine => compute/exec}/test_util.h | 2 +- cpp/src/arrow/engine/CMakeLists.txt | 25 ------------------- cpp/src/arrow/engine/api.h | 20 --------------- 10 files changed, 11 
insertions(+), 61 deletions(-) rename cpp/src/arrow/{engine => compute/exec}/exec_plan.cc (99%) rename cpp/src/arrow/{engine => compute/exec}/exec_plan.h (100%) rename cpp/src/arrow/{engine => compute/exec}/plan_test.cc (99%) rename cpp/src/arrow/{engine => compute/exec}/test_util.cc (99%) rename cpp/src/arrow/{engine => compute/exec}/test_util.h (98%) delete mode 100644 cpp/src/arrow/engine/CMakeLists.txt delete mode 100644 cpp/src/arrow/engine/api.h diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index b2423cf3c76..0e92811da8c 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -211,8 +211,6 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}") define_option(ARROW_DATASET "Build the Arrow Dataset Modules" OFF) - define_option(ARROW_ENGINE "Build the Arrow Execution Engine" OFF) - define_option(ARROW_FILESYSTEM "Build the Arrow Filesystem Layer" OFF) define_option(ARROW_FLIGHT diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 01994316310..bee14ae4ce3 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -367,6 +367,7 @@ if(ARROW_COMPUTE) compute/api_vector.cc compute/cast.cc compute/exec.cc + compute/exec/exec_plan.cc compute/exec/expression.cc compute/function.cc compute/kernel.cc @@ -405,6 +406,7 @@ if(ARROW_COMPUTE) set_source_files_properties(compute/kernels/aggregate_basic_avx2.cc PROPERTIES COMPILE_FLAGS ${ARROW_AVX2_FLAG}) endif() + if(ARROW_HAVE_RUNTIME_AVX512) list(APPEND ARROW_SRCS compute/kernels/aggregate_basic_avx512.cc) set_source_files_properties(compute/kernels/aggregate_basic_avx512.cc PROPERTIES @@ -412,11 +414,8 @@ if(ARROW_COMPUTE) set_source_files_properties(compute/kernels/aggregate_basic_avx512.cc PROPERTIES COMPILE_FLAGS ${ARROW_AVX512_FLAG}) endif() -endif() -if(ARROW_ENGINE) - list(APPEND ARROW_SRCS engine/exec_plan.cc) - list(APPEND ARROW_TESTING_SRCS engine/test_util.cc) + list(APPEND ARROW_TESTING_SRCS compute/exec/test_util.cc) endif() if(ARROW_FILESYSTEM) @@ -684,10 +683,6 @@ if(ARROW_DATASET) add_subdirectory(dataset) endif() -if(ARROW_ENGINE) - add_subdirectory(engine) -endif() - if(ARROW_FILESYSTEM) add_subdirectory(filesystem) endif() diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt index a10c1dad469..ac6ddc51dff 100644 --- a/cpp/src/arrow/compute/exec/CMakeLists.txt +++ b/cpp/src/arrow/compute/exec/CMakeLists.txt @@ -19,4 +19,6 @@ arrow_install_all_headers("arrow/compute/exec") add_arrow_compute_test(expression_test PREFIX "arrow-compute") +add_arrow_compute_test(plan_test PREFIX "arrow-compute") + add_arrow_benchmark(expression_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/engine/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc similarity index 99% rename from cpp/src/arrow/engine/exec_plan.cc rename to cpp/src/arrow/compute/exec/exec_plan.cc index 960ac109228..3ddcdae2932 100644 --- a/cpp/src/arrow/engine/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-#include "arrow/engine/exec_plan.h" +#include "arrow/compute/exec/exec_plan.h" #include diff --git a/cpp/src/arrow/engine/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h similarity index 100% rename from cpp/src/arrow/engine/exec_plan.h rename to cpp/src/arrow/compute/exec/exec_plan.h diff --git a/cpp/src/arrow/engine/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc similarity index 99% rename from cpp/src/arrow/engine/plan_test.cc rename to cpp/src/arrow/compute/exec/plan_test.cc index 282b3f7b395..9bb9d43a5fc 100644 --- a/cpp/src/arrow/engine/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -20,8 +20,8 @@ #include #include -#include "arrow/engine/exec_plan.h" -#include "arrow/engine/test_util.h" +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/test_util.h" #include "arrow/record_batch.h" #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" diff --git a/cpp/src/arrow/engine/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc similarity index 99% rename from cpp/src/arrow/engine/test_util.cc rename to cpp/src/arrow/compute/exec/test_util.cc index 001517d708d..787049bcf5e 100644 --- a/cpp/src/arrow/engine/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/engine/test_util.h" +#include "arrow/compute/exec/test_util.h" #include #include @@ -27,7 +27,7 @@ #include "arrow/compute/exec.h" #include "arrow/datum.h" -#include "arrow/engine/exec_plan.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/record_batch.h" #include "arrow/type.h" #include "arrow/util/async_generator.h" diff --git a/cpp/src/arrow/engine/test_util.h b/cpp/src/arrow/compute/exec/test_util.h similarity index 98% rename from cpp/src/arrow/engine/test_util.h rename to cpp/src/arrow/compute/exec/test_util.h index e5d19859b26..520e4f0f867 100644 --- a/cpp/src/arrow/engine/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -22,7 +22,7 @@ #include #include -#include "arrow/engine/exec_plan.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/record_batch.h" #include "arrow/testing/visibility.h" #include "arrow/util/async_generator.h" diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt deleted file mode 100644 index f34ae549df5..00000000000 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# Headers: top level -arrow_install_all_headers("arrow/engine") - -add_arrow_test(engine-plan-test - SOURCES - plan_test.cc - EXTRA_LABELS - engine) diff --git a/cpp/src/arrow/engine/api.h b/cpp/src/arrow/engine/api.h deleted file mode 100644 index 22b7f46181f..00000000000 --- a/cpp/src/arrow/engine/api.h +++ /dev/null @@ -1,20 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include "arrow/engine/exec_plan.h" // IWYU pragma: export From 10119aa8f71ef5a08009c4f746dfb56c63b20152 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 6 May 2021 13:55:16 -0400 Subject: [PATCH 3/9] ~namespace engine --- cpp/src/arrow/compute/exec/exec_plan.cc | 4 ++-- cpp/src/arrow/compute/exec/exec_plan.h | 4 ++-- cpp/src/arrow/compute/exec/plan_test.cc | 4 ++-- cpp/src/arrow/compute/exec/test_util.cc | 4 ++-- cpp/src/arrow/compute/exec/test_util.h | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 3ddcdae2932..33fd65f488e 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -28,7 +28,7 @@ namespace arrow { using internal::checked_cast; -namespace engine { +namespace compute { namespace { @@ -207,5 +207,5 @@ void ExecNode::ResumeProducing() { } } -} // namespace engine +} // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index b28fac08369..dd133bb5c24 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -32,7 +32,7 @@ // - node to combine input needs to reorder namespace arrow { -namespace engine { +namespace compute { class ExecNode; @@ -241,5 +241,5 @@ class ARROW_EXPORT ExecNode { std::vector output_descrs_; }; -} // namespace engine +} // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 9bb9d43a5fc..3066eda283d 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -33,7 +33,7 @@ namespace arrow { using internal::Executor; -namespace engine { +namespace compute { void AssertBatchesEqual(const RecordBatchVector& expected, const RecordBatchVector& actual) { @@ -387,5 +387,5 @@ TEST_F(TestExecPlanExecution, StressSlowSourceSinkParallel) { TestStressSourceSink(/*num_batches=*/300, MakeSlowRecordBatchGenerator); } -} // namespace engine +} // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 787049bcf5e..d099feb5255 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ 
-39,7 +39,7 @@ namespace arrow { using internal::Executor; -namespace engine { +namespace compute { namespace { // TODO expose this as `static ValueDescr::FromSchemaColumns`? @@ -375,5 +375,5 @@ RecordBatchCollectNode* MakeRecordBatchCollectNode( return ptr; } -} // namespace engine +} // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 520e4f0f867..c4b4e5d79e6 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -29,7 +29,7 @@ #include "arrow/util/type_fwd.h" namespace arrow { -namespace engine { +namespace compute { using StartProducingFunc = std::function; using StopProducingFunc = std::function; @@ -68,5 +68,5 @@ ARROW_TESTING_EXPORT RecordBatchCollectNode* MakeRecordBatchCollectNode(ExecPlan* plan, std::string label, const std::shared_ptr& schema); -} // namespace engine +} // namespace compute } // namespace arrow From 6ccdb61349c4115302b6e6873303da8eb91f1ac6 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 6 May 2021 16:31:30 -0400 Subject: [PATCH 4/9] refactor to singular outputs --- cpp/src/arrow/compute/exec/exec_plan.cc | 101 +++++++------ cpp/src/arrow/compute/exec/exec_plan.h | 76 +++++----- cpp/src/arrow/compute/exec/plan_test.cc | 192 ++++++++++++------------ cpp/src/arrow/compute/exec/test_util.cc | 93 +++++------- cpp/src/arrow/compute/exec/test_util.h | 8 +- 5 files changed, 226 insertions(+), 244 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 33fd65f488e..360065227c1 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -23,6 +23,7 @@ #include "arrow/result.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" +#include "arrow/util/optional.h" namespace arrow { @@ -35,16 +36,14 @@ namespace { struct ExecPlanImpl : public ExecPlan { ExecPlanImpl() = default; - ~ExecPlanImpl() = default; + ~ExecPlanImpl() override = default; - void AddNode(std::unique_ptr node) { + ExecNode* AddNode(std::unique_ptr node) { if (node->num_inputs() == 0) { sources_.push_back(node.get()); } - if (node->num_outputs() == 0) { - sinks_.push_back(node.get()); - } nodes_.push_back(std::move(node)); + return nodes_.back().get(); } Status Validate() const { @@ -76,13 +75,13 @@ struct ExecPlanImpl : public ExecPlan { } Result ReverseTopoSort() { - struct ReverseTopoSort { + struct TopoSort { const std::vector>& nodes; std::unordered_set visited; std::unordered_set visiting; NodeVector sorted; - explicit ReverseTopoSort(const std::vector>& nodes) + explicit TopoSort(const std::vector>& nodes) : nodes(nodes) { visited.reserve(nodes.size()); sorted.reserve(nodes.size()); @@ -102,27 +101,36 @@ struct ExecPlanImpl : public ExecPlan { if (visited.count(node) != 0) { return Status::OK(); } - if (!visiting.insert(node).second) { + + auto it_success = visiting.insert(node); + if (!it_success.second) { // Insertion failed => node is already being visited return Status::Invalid("Cycle detected in execution plan"); } - for (const auto& out : node->outputs()) { - RETURN_NOT_OK(Visit(out.output)); + + for (auto input : node->inputs()) { + // Ensure that producers are inserted before this consumer + RETURN_NOT_OK(Visit(input)); } - visiting.erase(node); + + visiting.erase(it_success.first); visited.insert(node); sorted.push_back(node); return Status::OK(); } + + NodeVector Reverse() { + std::reverse(sorted.begin(), sorted.end()); + return 
std::move(sorted); + } } topo_sort(nodes_); RETURN_NOT_OK(topo_sort.Sort()); - return std::move(topo_sort.sorted); + return topo_sort.Reverse(); } std::vector> nodes_; NodeVector sources_; - NodeVector sinks_; }; ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast(ptr); } @@ -131,67 +139,68 @@ const ExecPlanImpl* ToDerived(const ExecPlan* ptr) { return checked_cast(ptr); } +util::optional GetNodeIndex(const std::vector& nodes, + const ExecNode* node) { + for (int i = 0; i < static_cast(nodes.size()); ++i) { + if (nodes[i] == node) return i; + } + return util::nullopt; +} + } // namespace Result> ExecPlan::Make() { return std::make_shared(); } -void ExecPlan::AddNode(std::unique_ptr node) { - ToDerived(this)->AddNode(std::move(node)); +ExecNode* ExecPlan::AddNode(std::unique_ptr node) { + return ToDerived(this)->AddNode(std::move(node)); } const ExecPlan::NodeVector& ExecPlan::sources() const { return ToDerived(this)->sources_; } -const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; } +ExecPlan::NodeVector ExecPlan::sinks() const { + NodeVector sinks; + for (const auto& node : ToDerived(this)->nodes_) { + if (node->output() == nullptr) { + sinks.push_back(node.get()); + } + } + return sinks; +} Status ExecPlan::Validate() { return ToDerived(this)->Validate(); } Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); } -ExecNode::~ExecNode() = default; - ExecNode::ExecNode(ExecPlan* plan, std::string label) : plan_(plan), label_(std::move(label)) {} Status ExecNode::Validate() const { - if (inputs_.size() != static_cast(num_inputs())) { + if (inputs_.size() != input_descrs_.size()) { return Status::Invalid("Invalid number of inputs for '", label(), "' (expected ", num_inputs(), ", actual ", inputs_.size(), ")"); } - if (input_descrs_.size() != static_cast(num_inputs())) { - return Status::Invalid("Invalid number of input descrs for '", label(), - "' (expected ", num_inputs(), ", actual ", - input_descrs_.size(), ")"); - } - if (outputs_.size() != static_cast(num_outputs())) { - return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ", - num_outputs(), ", actual ", outputs_.size(), ")"); - } - if (output_descrs_.size() != static_cast(num_outputs())) { - return Status::Invalid("Invalid number of output descrs for '", label(), - "' (expected ", num_outputs(), ", actual ", - output_descrs_.size(), ")"); - } - for (size_t i = 0; i < outputs_.size(); ++i) { - const auto& out = outputs_[i]; - if (out.input_index >= static_cast(out.output->inputs_.size()) || - out.input_index >= static_cast(out.output->input_descrs_.size()) || - this != out.output->inputs_[out.input_index]) { - return Status::Invalid("Output node configuration for '", label(), - "' inconsistent with input node configuration for '", - out.output->label(), "'"); + + if (output_) { + auto input_index = GetNodeIndex(output_->inputs(), this); + if (!input_index) { + return Status::Invalid("Node '", label(), "' outputs to node '", output_->label(), + "' but is not listed as an input."); } - const auto& out_descr = output_descrs_[i]; - const auto& in_descr = out.output->input_descrs_[out.input_index]; - if (in_descr != out_descr) { + + const auto& in_descr = output_->input_descrs_[*input_index]; + if (in_descr != output_descr_) { return Status::Invalid( - "Output node produces batches with type '", ValueDescr::ToString(out_descr), - "' inconsistent with input node configuration for '", out.output->label(), "'"); + "Node '", label(), "' (bound to input 
", input_labels_[*input_index], + ") produces batches with type '", ValueDescr::ToString(output_descr_), + "' inconsistent with consumer '", output_->label(), "' which accepts '", + ValueDescr::ToString(in_descr), "'"); } } + return Status::OK(); } diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index dd133bb5c24..1b113008639 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -45,13 +45,18 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { /// Make an empty exec plan static Result> Make(); - void AddNode(std::unique_ptr node); + ExecNode* AddNode(std::unique_ptr node); + + template + ExecNode* EmplaceNode(Args&&... args) { + return AddNode(std::unique_ptr(new Node{std::forward(args)...})); + } /// The initial inputs const NodeVector& sources() const; /// The final outputs - const NodeVector& sinks() const; + NodeVector sinks() const; // XXX API question: // There are clearly two phases in the ExecPlan lifecycle: @@ -78,59 +83,44 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { class ARROW_EXPORT ExecNode { public: - struct OutputNode { - ExecNode* output; - // Index of corresponding input in `output` node - int input_index; - }; - using NodeVector = std::vector; - using OutputNodeVector = std::vector; using BatchDescr = std::vector; - virtual ~ExecNode(); + virtual ~ExecNode() = default; virtual const char* kind_name() = 0; - // The number of inputs and outputs expected by this node - // XXX should these simply return `input_descrs_.size()` - // (`output_descrs_.size()` respectively)? - virtual int num_inputs() const = 0; - virtual int num_outputs() const = 0; + + // The number of inputs expected by this node + int num_inputs() const { return static_cast(input_descrs_.size()); } /// This node's predecessors in the exec plan const NodeVector& inputs() const { return inputs_; } - /// The datatypes for each input - // XXX Should it be std::vector? + /// The datatypes accepted by this node for each input const std::vector& input_descrs() const { return input_descrs_; } - /// This node's successors in the exec plan - const OutputNodeVector& outputs() const { return outputs_; } + /// \brief Labels identifying the function of each input. + /// + /// For example, FilterNode accepts "target" and "filter" inputs. + const std::string& input_labels() const { return input_labels_; } + + /// This node's successor in the exec plan + ExecNode* output() const { return output_; } - /// The datatypes for each output - // XXX Should it be std::vector? - const std::vector& output_descrs() const { return output_descrs_; } + /// The datatypes for batches produced by this node + const BatchDescr& output_descr() const { return output_descr_; } /// This node's exec plan ExecPlan* plan() { return plan_; } - std::shared_ptr plan_ref() { return plan_->shared_from_this(); } /// \brief An optional label, for display and debugging /// /// There is no guarantee that this value is non-empty or unique. 
const std::string& label() const { return label_; } - int AddInput(ExecNode* node) { - inputs_.push_back(node); - return static_cast(inputs_.size() - 1); - } - - void AddOutput(ExecNode* node, int input_index) { - outputs_.push_back({node, input_index}); - } - - static void Bind(ExecNode* input, ExecNode* output) { - input->AddOutput(output, output->AddInput(input)); + void AddInput(ExecNode* input) { + inputs_.push_back(input); + input->output_ = this; } Status Validate() const; @@ -148,17 +138,17 @@ class ARROW_EXPORT ExecNode { /// and StopProducing() /// Transfer input batch to ExecNode - virtual void InputReceived(int input_index, int seq_num, compute::ExecBatch batch) = 0; + virtual void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) = 0; /// Signal error to ExecNode - virtual void ErrorReceived(int input_index, Status error) = 0; + virtual void ErrorReceived(ExecNode* input, Status error) = 0; /// Mark the inputs finished after the given number of batches. /// /// This may be called before all inputs are received. This simply fixes /// the total number of incoming batches for an input, so that the ExecNode /// knows when it has received all input, regardless of order. - virtual void InputFinished(int input_index, int seq_stop) = 0; + virtual void InputFinished(ExecNode* input, int seq_stop) = 0; /// Lifecycle API: /// - start / stop to initiate and terminate production @@ -215,7 +205,7 @@ class ARROW_EXPORT ExecNode { /// /// This may be called any number of times after StartProducing() succeeds. /// However, the node is still free to produce data (which may be difficult - /// to prevent anyway if data is producer using multiple threads). + /// to prevent anyway if data is produced using multiple threads). virtual void PauseProducing(); /// \brief Resume producing after a temporary pause @@ -234,11 +224,15 @@ class ARROW_EXPORT ExecNode { ExecNode(ExecPlan* plan, std::string label); ExecPlan* plan_; + std::string label_; - NodeVector inputs_; - OutputNodeVector outputs_; + std::vector input_descrs_; - std::vector output_descrs_; + std::string input_labels_; + NodeVector inputs_; + + BatchDescr output_descr_; + ExecNode* output_ = NULLPTR; }; } // namespace compute diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 3066eda283d..ef039fb8070 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -51,55 +51,58 @@ TEST(ExecPlanConstruction, Empty) { TEST(ExecPlanConstruction, SingleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/0); + auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0); ASSERT_OK(plan->Validate()); ASSERT_THAT(plan->sources(), ::testing::ElementsAre(node)); ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(node)); ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/0); + node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1); // Input not bound ASSERT_RAISES(Invalid, plan->Validate()); - - ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/1); - // Output not bound - ASSERT_RAISES(Invalid, plan->Validate()); } TEST(ExecPlanConstruction, SourceSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); - auto sink 
= MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); - // Input / output not bound + auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1); + EXPECT_EQ(source->num_inputs(), 0); + EXPECT_EQ(sink->num_inputs(), 1); + EXPECT_EQ(sink->inputs().size(), 0); + // Sink's input not bound ASSERT_RAISES(Invalid, plan->Validate()); - ExecNode::Bind(source, sink); + sink->AddInput(source); ASSERT_OK(plan->Validate()); - ASSERT_THAT(plan->sources(), ::testing::ElementsAre(source)); - ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(sink)); + EXPECT_THAT(plan->sources(), ::testing::ElementsAre(source)); + EXPECT_THAT(plan->sinks(), ::testing::ElementsAre(sink)); } TEST(ExecPlanConstruction, MultipleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = - MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2); - auto source2 = - MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1); - auto process1 = - MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2); - auto process2 = - MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/1); - ExecNode::Bind(source1, process1); - ExecNode::Bind(source1, process2); - ExecNode::Bind(source2, process2); - auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1); - ExecNode::Bind(process1, process3); - ExecNode::Bind(process1, process3); - ExecNode::Bind(process2, process3); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); - ExecNode::Bind(process3, sink); + + auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0); + + auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0); + + auto process1 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1); + + auto process2 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2); + + auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3); + + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1); + + sink->AddInput(process3); + + process3->AddInput(process1); + process3->AddInput(process2); + process3->AddInput(process1); + + process2->AddInput(source1); + process2->AddInput(source2); + + process1->AddInput(source1); ASSERT_OK(plan->Validate()); ASSERT_THAT(plan->sources(), ::testing::ElementsAre(source1, source2)); @@ -107,8 +110,7 @@ TEST(ExecPlanConstruction, MultipleNode) { } struct StartStopTracker { - std::vector started; - std::vector stopped; + std::vector started, stopped; StartProducingFunc start_producing_func(Status st = Status::OK()) { return [this, st](ExecNode* node) { @@ -126,28 +128,32 @@ TEST(ExecPlan, DummyStartProducing) { StartStopTracker t; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2, + + auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, t.start_producing_func(), t.stop_producing_func()); - auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1, + + auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, t.start_producing_func(), t.stop_producing_func()); - auto process1 = - MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2, - t.start_producing_func(), t.stop_producing_func()); - auto process2 = - MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1, - 
t.start_producing_func(), t.stop_producing_func()); - ExecNode::Bind(source1, process1); - ExecNode::Bind(process1, process2); - ExecNode::Bind(source2, process2); - auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); - ExecNode::Bind(process1, process3); - ExecNode::Bind(source1, process3); - ExecNode::Bind(process2, process3); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0, + + auto process1 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + + auto process2 = MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, + t.start_producing_func(), t.stop_producing_func()); + + auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, + t.start_producing_func(), t.stop_producing_func()); + + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, t.start_producing_func(), t.stop_producing_func()); - ExecNode::Bind(process3, sink); + + process1->AddInput(source1); + process2->AddInput(process1); + process2->AddInput(source2); + process3->AddInput(process1); + process3->AddInput(source1); + process3->AddInput(process2); + sink->AddInput(process3); ASSERT_OK(plan->Validate()); ASSERT_EQ(t.started.size(), 0); @@ -156,34 +162,38 @@ TEST(ExecPlan, DummyStartProducing) { ASSERT_OK(plan->StartProducing()); // Note that any correct reverse topological order may do ASSERT_THAT(t.started, ::testing::ElementsAre("sink", "process3", "process2", - "process1", "source1", "source2")); + "process1", "source2", "source1")); ASSERT_EQ(t.stopped.size(), 0); } TEST(ExecPlan, DummyStartProducingCycle) { // A trivial cycle ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/1); - ExecNode::Bind(node, node); + auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1); + node->AddInput(node); ASSERT_OK(plan->Validate()); ASSERT_RAISES(Invalid, plan->StartProducing()); // A less trivial one ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); - auto process1 = - MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/2); - auto process2 = - MakeDummyNode(plan.get(), "process2", /*num_inputs=*/1, /*num_outputs=*/1); - auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/2, /*num_outputs=*/2); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); - ExecNode::Bind(source, process1); - ExecNode::Bind(process1, process2); - ExecNode::Bind(process2, process3); - ExecNode::Bind(process1, process3); - ExecNode::Bind(process3, process1); - ExecNode::Bind(process3, sink); + + auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0); + + auto process1 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2); + + auto process2 = MakeDummyNode(plan.get(), "process2", /*num_inputs=*/1); + + auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/2); + + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1); + + process1->AddInput(source); + process2->AddInput(process1); + process3->AddInput(process2); + process3->AddInput(process1); + process1->AddInput(process3); + sink->AddInput(process3); + ASSERT_OK(plan->Validate()); ASSERT_RAISES(Invalid, plan->StartProducing()); } @@ -192,29 +202,27 @@ TEST(ExecPlan, DummyStartProducingError) { StartStopTracker 
t; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2, + auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, t.start_producing_func(Status::NotImplemented("zzz")), t.stop_producing_func()); - auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1, + auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, t.start_producing_func(), t.stop_producing_func()); - auto process1 = MakeDummyNode( - plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2, - t.start_producing_func(Status::IOError("xxx")), t.stop_producing_func()); - auto process2 = - MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); - ExecNode::Bind(source1, process1); - ExecNode::Bind(process1, process2); - ExecNode::Bind(source2, process2); - auto process3 = - MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1, - t.start_producing_func(), t.stop_producing_func()); - ExecNode::Bind(process1, process3); - ExecNode::Bind(source1, process3); - ExecNode::Bind(process2, process3); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0, + auto process1 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, + t.start_producing_func(Status::IOError("xxx")), + t.stop_producing_func()); + auto process2 = MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, + t.start_producing_func(), t.stop_producing_func()); + process1->AddInput(source1); + process2->AddInput(process1); + process2->AddInput(source2); + auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, + t.start_producing_func(), t.stop_producing_func()); + process3->AddInput(process1); + process3->AddInput(source1); + process3->AddInput(process2); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, t.start_producing_func(), t.stop_producing_func()); - ExecNode::Bind(process3, sink); + sink->AddInput(process3); ASSERT_OK(plan->Validate()); ASSERT_EQ(t.started.size(), 0); @@ -299,7 +307,7 @@ class TestExecPlanExecution : public ::testing::Test { auto source = MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", schema); - ExecNode::Bind(source, sink); + sink->AddInput(source); return CollectorPlan{plan, sink}; } @@ -309,7 +317,7 @@ class TestExecPlanExecution : public ::testing::Test { auto source = MakeRecordBatchReaderNode(plan.get(), "source", schema, generator, io_executor_.get()); auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", schema); - ExecNode::Bind(source, sink); + sink->AddInput(source); return CollectorPlan{plan, sink}; } @@ -329,15 +337,13 @@ class TestExecPlanExecution : public ::testing::Test { template void TestSourceSink(RecordBatchReaderFactory batch_factory) { auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); - // clang-format off RecordBatchVector batches{ RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, {"a": 4, "b": false}])"), RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, {"a": 6, "b": false}, - {"a": 7, "b": false}])") + {"a": 7, "b": false}])"), }; - // clang-format on ASSERT_OK_AND_ASSIGN(auto reader, batch_factory(batches, schema)); ASSERT_OK_AND_ASSIGN(auto cp, MakeSourceSink(reader, schema)); diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 
d099feb5255..b2e04ddea45 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -23,11 +23,12 @@ #include #include #include +#include #include #include "arrow/compute/exec.h" -#include "arrow/datum.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/datum.h" #include "arrow/record_batch.h" #include "arrow/type.h" #include "arrow/util/async_generator.h" @@ -44,38 +45,31 @@ namespace { // TODO expose this as `static ValueDescr::FromSchemaColumns`? std::vector DescrFromSchemaColumns(const Schema& schema) { - std::vector descr; - descr.reserve(schema.num_fields()); - std::transform(schema.fields().begin(), schema.fields().end(), - std::back_inserter(descr), [](const std::shared_ptr& field) { + std::vector descr(schema.num_fields()); + std::transform(schema.fields().begin(), schema.fields().end(), descr.begin(), + [](const std::shared_ptr& field) { return ValueDescr::Array(field->type()); }); return descr; } struct DummyNode : ExecNode { - DummyNode(ExecPlan* plan, std::string label, int num_inputs, int num_outputs, + DummyNode(ExecPlan* plan, std::string label, int num_inputs, StartProducingFunc start_producing, StopProducingFunc stop_producing) : ExecNode(plan, std::move(label)), - num_inputs_(num_inputs), - num_outputs_(num_outputs), start_producing_(std::move(start_producing)), stop_producing_(std::move(stop_producing)) { input_descrs_.assign(num_inputs, descr()); - output_descrs_.assign(num_outputs, descr()); + output_descr_ = descr(); } const char* kind_name() override { return "RecordBatchReader"; } - int num_inputs() const override { return num_inputs_; } - - int num_outputs() const override { return num_outputs_; } - - void InputReceived(int input_index, int seq_num, compute::ExecBatch batch) override {} + void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} - void ErrorReceived(int input_index, Status error) override {} + void ErrorReceived(ExecNode* input, Status error) override {} - void InputFinished(int input_index, int seq_stop) override {} + void InputFinished(ExecNode* input, int seq_stop) override {} Status StartProducing() override { if (start_producing_) { @@ -100,8 +94,6 @@ struct DummyNode : ExecNode { private: BatchDescr descr() const { return std::vector{ValueDescr(null())}; } - int num_inputs_; - int num_outputs_; StartProducingFunc start_producing_; StopProducingFunc stop_producing_; bool started_ = false; @@ -114,7 +106,7 @@ struct RecordBatchReaderNode : ExecNode { schema_(reader->schema()), reader_(std::move(reader)), io_executor_(io_executor) { - output_descrs_.push_back(DescrFromSchemaColumns(*schema_)); + output_descr_ = DescrFromSchemaColumns(*schema_); } RecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr schema, @@ -123,20 +115,16 @@ struct RecordBatchReaderNode : ExecNode { schema_(std::move(schema)), io_executor_(io_executor), generator_(std::move(generator)) { - output_descrs_.push_back(DescrFromSchemaColumns(*schema_)); + output_descr_ = DescrFromSchemaColumns(*schema_); } const char* kind_name() override { return "RecordBatchReader"; } - int num_inputs() const override { return 0; } + void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} - int num_outputs() const override { return 1; } + void ErrorReceived(ExecNode* input, Status error) override {} - void InputReceived(int input_index, int seq_num, compute::ExecBatch batch) override {} - - void ErrorReceived(int input_index, Status error) override {} - - void 
InputFinished(int input_index, int seq_stop) override {} + void InputFinished(ExecNode* input, int seq_stop) override {} Status StartProducing() override { next_batch_index_ = 0; @@ -162,7 +150,7 @@ struct RecordBatchReaderNode : ExecNode { // Stopped return; } - auto plan = plan_ref(); + auto plan = this->plan()->shared_from_this(); auto fut = generator_(); const auto batch_index = next_batch_index_++; @@ -172,20 +160,17 @@ struct RecordBatchReaderNode : ExecNode { .AddCallback( [plan, batch_index, this](const Result>& res) { std::unique_lock lock(mutex_); - DCHECK_EQ(outputs_.size(), 1); - OutputNode* out = &outputs_[0]; if (!res.ok()) { - out->output->ErrorReceived(out->input_index, res.status()); + output_->ErrorReceived(output_, res.status()); return; } const auto& batch = *res; if (IsIterationEnd(batch)) { lock.unlock(); - out->output->InputFinished(out->input_index, batch_index); + output_->InputFinished(output_, batch_index); } else { lock.unlock(); - out->output->InputReceived(out->input_index, batch_index, - compute::ExecBatch(*batch)); + output_->InputReceived(output_, batch_index, compute::ExecBatch(*batch)); lock.lock(); GenerateOne(std::move(lock)); } @@ -203,8 +188,8 @@ struct RecordBatchReaderNode : ExecNode { struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { RecordBatchCollectNodeImpl(ExecPlan* plan, std::string label, - const std::shared_ptr& schema) - : RecordBatchCollectNode(plan, std::move(label)), schema_(schema) { + std::shared_ptr schema) + : RecordBatchCollectNode(plan, std::move(label)), schema_(std::move(schema)) { input_descrs_.push_back(DescrFromSchemaColumns(*schema_)); } @@ -212,10 +197,6 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { const char* kind_name() override { return "RecordBatchReader"; } - int num_inputs() const override { return 1; } - - int num_outputs() const override { return 0; } - Status StartProducing() override { num_received_ = 0; num_emitted_ = 0; @@ -230,7 +211,7 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { StopProducing(&lock); } - void InputReceived(int input_index, int seq_num, + void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch exec_batch) override { std::unique_lock lock(mutex_); if (stopped_) { @@ -286,13 +267,13 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { num_emitted_ = seq_num; } - void ErrorReceived(int input_index, Status error) override { + void ErrorReceived(ExecNode* input, Status error) override { // XXX do we care about properly sequencing the error? 
producer_->Push(std::move(error)); StopProducing(); } - void InputFinished(int input_index, int seq_stop) override { + void InputFinished(ExecNode* input, int seq_stop) override { std::unique_lock lock(mutex_); DCHECK_GE(seq_stop, static_cast(received_batches_.size())); received_batches_.reserve(seq_stop); @@ -343,36 +324,30 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr reader, Executor* io_executor) { - auto ptr = - new RecordBatchReaderNode(plan, std::move(label), std::move(reader), io_executor); - plan->AddNode(std::unique_ptr{ptr}); - return ptr; + return plan->EmplaceNode(plan, std::move(label), + std::move(reader), io_executor); } ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr schema, RecordBatchGenerator generator, ::arrow::internal::Executor* io_executor) { - auto ptr = new RecordBatchReaderNode(plan, std::move(label), std::move(schema), - std::move(generator), io_executor); - plan->AddNode(std::unique_ptr{ptr}); - return ptr; + return plan->EmplaceNode( + plan, std::move(label), std::move(schema), std::move(generator), io_executor); } ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, - int num_outputs, StartProducingFunc start_producing, + StartProducingFunc start_producing, StopProducingFunc stop_producing) { - auto ptr = new DummyNode(plan, std::move(label), num_inputs, num_outputs, - std::move(start_producing), std::move(stop_producing)); - plan->AddNode(std::unique_ptr{ptr}); - return ptr; + return plan->EmplaceNode(plan, std::move(label), num_inputs, + std::move(start_producing), + std::move(stop_producing)); } RecordBatchCollectNode* MakeRecordBatchCollectNode( ExecPlan* plan, std::string label, const std::shared_ptr& schema) { - auto ptr = new RecordBatchCollectNodeImpl(plan, std::move(label), schema); - plan->AddNode(std::unique_ptr{ptr}); - return ptr; + return internal::checked_cast( + plan->EmplaceNode(plan, std::move(label), schema)); } } // namespace compute diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index c4b4e5d79e6..37be6b80cff 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -37,14 +37,12 @@ using StopProducingFunc = std::function; // Make a dummy node that has no execution behaviour ARROW_TESTING_EXPORT ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, - int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); + StartProducingFunc = {}, StopProducingFunc = {}); using RecordBatchGenerator = AsyncGenerator>; -// Make a source node that produces record batches by reading in the background -// from a RecordBatchReader. -// 0 input -// 1 output (N columns) +// Make a source node (no inputs) that produces record batches by reading in the +// background from a RecordBatchReader. 
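As an aside (not part of the patch), the factory functions above show the intended pattern for adding custom nodes: ExecPlan::EmplaceNode constructs the node, transfers ownership to the plan, and returns a raw pointer that stays valid for the plan's lifetime. A minimal sketch, assuming a hypothetical MyCustomNode subclass of ExecNode:

// Sketch only; "MyCustomNode" is hypothetical and stands in for any ExecNode
// subclass. The plan owns the node after EmplaceNode returns.
ExecNode* MakeMyCustomNode(ExecPlan* plan, std::string label,
                           std::shared_ptr<Schema> schema) {
  return plan->EmplaceNode<MyCustomNode>(plan, std::move(label), std::move(schema));
}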
ARROW_TESTING_EXPORT ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr reader, From 6f5c7dd2c16b6ed51a4f44c7bc173bd6243ae3f4 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Thu, 6 May 2021 16:42:25 -0400 Subject: [PATCH 5/9] fix input_labels() --- cpp/src/arrow/compute/exec/exec_plan.cc | 2 ++ cpp/src/arrow/compute/exec/exec_plan.h | 4 ++-- cpp/src/arrow/compute/exec/test_util.cc | 6 +++++- cpp/src/arrow/util/iterator.h | 6 ------ cpp/src/arrow/util/iterator_test.cc | 2 -- 5 files changed, 9 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 360065227c1..893ed46f639 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -184,6 +184,8 @@ Status ExecNode::Validate() const { num_inputs(), ", actual ", inputs_.size(), ")"); } + DCHECK_EQ(input_descrs_.size(), input_labels_.size()); + if (output_) { auto input_index = GetNodeIndex(output_->inputs(), this); if (!input_index) { diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 1b113008639..af3d0070358 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -102,7 +102,7 @@ class ARROW_EXPORT ExecNode { /// \brief Labels identifying the function of each input. /// /// For example, FilterNode accepts "target" and "filter" inputs. - const std::string& input_labels() const { return input_labels_; } + const std::vector& input_labels() const { return input_labels_; } /// This node's successor in the exec plan ExecNode* output() const { return output_; } @@ -228,7 +228,7 @@ class ARROW_EXPORT ExecNode { std::string label_; std::vector input_descrs_; - std::string input_labels_; + std::vector input_labels_; NodeVector inputs_; BatchDescr output_descr_; diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index b2e04ddea45..84067b0bd95 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -59,6 +59,9 @@ struct DummyNode : ExecNode { : ExecNode(plan, std::move(label)), start_producing_(std::move(start_producing)), stop_producing_(std::move(stop_producing)) { + for (int i = 0; i < num_inputs; ++i) { + input_labels_.push_back(std::to_string(i)); + } input_descrs_.assign(num_inputs, descr()); output_descr_ = descr(); } @@ -129,7 +132,7 @@ struct RecordBatchReaderNode : ExecNode { Status StartProducing() override { next_batch_index_ = 0; if (!generator_) { - auto it = MakePointerIterator(reader_.get()); + auto it = MakeIteratorFromReader(reader_); ARROW_ASSIGN_OR_RAISE(generator_, MakeBackgroundGenerator(std::move(it), io_executor_)); } @@ -191,6 +194,7 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { std::shared_ptr schema) : RecordBatchCollectNode(plan, std::move(label)), schema_(std::move(schema)) { input_descrs_.push_back(DescrFromSchemaColumns(*schema_)); + input_labels_.emplace_back("batches_to_collect"); } RecordBatchGenerator generator() override { return generator_; } diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h index 97ad0d73c35..b82021e4b21 100644 --- a/cpp/src/arrow/util/iterator.h +++ b/cpp/src/arrow/util/iterator.h @@ -370,12 +370,6 @@ Iterator MakeErrorIterator(Status s) { }); } -template ().Next())::ValueType> -Iterator MakePointerIterator(It* it) { - return MakeFunctionIterator([it]() -> Result { return it->Next(); }); -} - /// \brief Simple 
iterator which yields the elements of a std::vector template class VectorIterator { diff --git a/cpp/src/arrow/util/iterator_test.cc b/cpp/src/arrow/util/iterator_test.cc index dc2a2398729..ab62fcb7034 100644 --- a/cpp/src/arrow/util/iterator_test.cc +++ b/cpp/src/arrow/util/iterator_test.cc @@ -32,8 +32,6 @@ #include "arrow/util/test_common.h" #include "arrow/util/vector.h" -// TODO add test for MakePointerIterator - namespace arrow { template From 28742139a72e7a9a381534b4e99222ef0079bbe4 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 10 May 2021 12:39:17 -0400 Subject: [PATCH 6/9] revert to multiple outputs --- cpp/src/arrow/compute/exec/exec_plan.cc | 12 ++++++------ cpp/src/arrow/compute/exec/exec_plan.h | 10 +++++----- cpp/src/arrow/compute/exec/test_util.cc | 12 +++++++++--- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 893ed46f639..91671c33c07 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -164,7 +164,7 @@ const ExecPlan::NodeVector& ExecPlan::sources() const { ExecPlan::NodeVector ExecPlan::sinks() const { NodeVector sinks; for (const auto& node : ToDerived(this)->nodes_) { - if (node->output() == nullptr) { + if (node->outputs().empty()) { sinks.push_back(node.get()); } } @@ -186,19 +186,19 @@ Status ExecNode::Validate() const { DCHECK_EQ(input_descrs_.size(), input_labels_.size()); - if (output_) { - auto input_index = GetNodeIndex(output_->inputs(), this); + for (auto out : outputs_) { + auto input_index = GetNodeIndex(out->inputs(), this); if (!input_index) { - return Status::Invalid("Node '", label(), "' outputs to node '", output_->label(), + return Status::Invalid("Node '", label(), "' outputs to node '", out->label(), "' but is not listed as an input."); } - const auto& in_descr = output_->input_descrs_[*input_index]; + const auto& in_descr = out->input_descrs_[*input_index]; if (in_descr != output_descr_) { return Status::Invalid( "Node '", label(), "' (bound to input ", input_labels_[*input_index], ") produces batches with type '", ValueDescr::ToString(output_descr_), - "' inconsistent with consumer '", output_->label(), "' which accepts '", + "' inconsistent with consumer '", out->label(), "' which accepts '", ValueDescr::ToString(in_descr), "'"); } } diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index af3d0070358..21fe9aa541b 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -60,7 +60,7 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { // XXX API question: // There are clearly two phases in the ExecPlan lifecycle: - // - one construction phase where AddNode() and ExecNode::Bind() is called + // - one construction phase where AddNode() and ExecNode::AddInput() is called // (with optional validation at the end) // - one execution phase where the nodes are topo-sorted and then started // @@ -104,8 +104,8 @@ class ARROW_EXPORT ExecNode { /// For example, FilterNode accepts "target" and "filter" inputs. 
const std::vector& input_labels() const { return input_labels_; } - /// This node's successor in the exec plan - ExecNode* output() const { return output_; } + /// This node's successors in the exec plan + const NodeVector& outputs() const { return outputs_; } /// The datatypes for batches produced by this node const BatchDescr& output_descr() const { return output_descr_; } @@ -120,7 +120,7 @@ class ARROW_EXPORT ExecNode { void AddInput(ExecNode* input) { inputs_.push_back(input); - input->output_ = this; + input->outputs_.push_back(this); } Status Validate() const; @@ -232,7 +232,7 @@ class ARROW_EXPORT ExecNode { NodeVector inputs_; BatchDescr output_descr_; - ExecNode* output_ = NULLPTR; + NodeVector outputs_; }; } // namespace compute diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 84067b0bd95..05d53276c03 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -164,16 +164,22 @@ struct RecordBatchReaderNode : ExecNode { [plan, batch_index, this](const Result>& res) { std::unique_lock lock(mutex_); if (!res.ok()) { - output_->ErrorReceived(output_, res.status()); + for (auto out : outputs_) { + out->ErrorReceived(this, res.status()); + } return; } const auto& batch = *res; if (IsIterationEnd(batch)) { lock.unlock(); - output_->InputFinished(output_, batch_index); + for (auto out : outputs_) { + out->InputFinished(this, batch_index); + } } else { lock.unlock(); - output_->InputReceived(output_, batch_index, compute::ExecBatch(*batch)); + for (auto out : outputs_) { + out->InputReceived(this, batch_index, compute::ExecBatch(*batch)); + } lock.lock(); GenerateOne(std::move(lock)); } From cf294ce91474568b5d1c722ba5dce77a67b51a65 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 10 May 2021 16:21:35 -0400 Subject: [PATCH 7/9] reintroduce explicit num_outputs() --- cpp/src/arrow/compute/exec/exec_plan.cc | 44 +++++----- cpp/src/arrow/compute/exec/exec_plan.h | 23 +++-- cpp/src/arrow/compute/exec/plan_test.cc | 109 +++++++++++++----------- cpp/src/arrow/compute/exec/test_util.cc | 100 +++++++++++++++------- cpp/src/arrow/compute/exec/test_util.h | 2 +- 5 files changed, 161 insertions(+), 117 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 91671c33c07..f765ceccf0c 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -42,6 +42,9 @@ struct ExecPlanImpl : public ExecPlan { if (node->num_inputs() == 0) { sources_.push_back(node.get()); } + if (node->num_outputs() == 0) { + sinks_.push_back(node.get()); + } nodes_.push_back(std::move(node)); return nodes_.back().get(); } @@ -130,7 +133,7 @@ struct ExecPlanImpl : public ExecPlan { } std::vector> nodes_; - NodeVector sources_; + NodeVector sources_, sinks_; }; ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast(ptr); } @@ -161,22 +164,22 @@ const ExecPlan::NodeVector& ExecPlan::sources() const { return ToDerived(this)->sources_; } -ExecPlan::NodeVector ExecPlan::sinks() const { - NodeVector sinks; - for (const auto& node : ToDerived(this)->nodes_) { - if (node->outputs().empty()) { - sinks.push_back(node.get()); - } - } - return sinks; -} +const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; } Status ExecPlan::Validate() { return ToDerived(this)->Validate(); } Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); } -ExecNode::ExecNode(ExecPlan* plan, std::string 
label) - : plan_(plan), label_(std::move(label)) {} +ExecNode::ExecNode(ExecPlan* plan, std::string label, + std::vector input_descrs, + std::vector input_labels, BatchDescr output_descr, + int num_outputs) + : plan_(plan), + label_(std::move(label)), + input_descrs_(std::move(input_descrs)), + input_labels_(std::move(input_labels)), + output_descr_(std::move(output_descr)), + num_outputs_(num_outputs) {} Status ExecNode::Validate() const { if (inputs_.size() != input_descrs_.size()) { @@ -184,6 +187,11 @@ Status ExecNode::Validate() const { num_inputs(), ", actual ", inputs_.size(), ")"); } + if (static_cast(outputs_.size()) != num_outputs_) { + return Status::Invalid("Invalid number of outputs for '", label(), "' (expected ", + num_outputs(), ", actual ", outputs_.size(), ")"); + } + DCHECK_EQ(input_descrs_.size(), input_labels_.size()); for (auto out : outputs_) { @@ -206,17 +214,5 @@ Status ExecNode::Validate() const { return Status::OK(); } -void ExecNode::PauseProducing() { - for (const auto& node : inputs_) { - node->PauseProducing(); - } -} - -void ExecNode::ResumeProducing() { - for (const auto& node : inputs_) { - node->ResumeProducing(); - } -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 21fe9aa541b..0d2faea0ddc 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -56,7 +56,7 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { const NodeVector& sources() const; /// The final outputs - NodeVector sinks() const; + const NodeVector& sinks() const; // XXX API question: // There are clearly two phases in the ExecPlan lifecycle: @@ -90,8 +90,9 @@ class ARROW_EXPORT ExecNode { virtual const char* kind_name() = 0; - // The number of inputs expected by this node + // The number of inputs/outputs expected by this node int num_inputs() const { return static_cast(input_descrs_.size()); } + int num_outputs() const { return num_outputs_; } /// This node's predecessors in the exec plan const NodeVector& inputs() const { return inputs_; } @@ -187,9 +188,6 @@ class ARROW_EXPORT ExecNode { // - A method allows passing a ProductionHint asynchronously from an output node // (replacing PauseProducing(), ResumeProducing(), StopProducing()) - // TODO PauseProducing() etc. should probably take the index of the output which calls - // them? - /// \brief Start producing /// /// This must only be called once. If this fails, then other lifecycle @@ -206,7 +204,7 @@ class ARROW_EXPORT ExecNode { /// This may be called any number of times after StartProducing() succeeds. /// However, the node is still free to produce data (which may be difficult /// to prevent anyway if data is produced using multiple threads). - virtual void PauseProducing(); + virtual void PauseProducing(ExecNode* output) = 0; /// \brief Resume producing after a temporary pause /// @@ -215,13 +213,21 @@ class ARROW_EXPORT ExecNode { /// This may be called any number of times after StartProducing() succeeds. /// This may also be called concurrently with PauseProducing(), which suggests /// the implementation may use an atomic counter. - virtual void ResumeProducing(); + virtual void ResumeProducing(ExecNode* output) = 0; + + /// \brief Stop producing definitively to a single output + /// + /// This call is a hint that an output node has completed and is not willing + /// to not receive any further data. 
+ virtual void StopProducing(ExecNode* output) = 0; /// \brief Stop producing definitively virtual void StopProducing() = 0; protected: - ExecNode(ExecPlan* plan, std::string label); + ExecNode(ExecPlan* plan, std::string label, std::vector input_descrs, + std::vector input_labels, BatchDescr output_descr, + int num_outputs); ExecPlan* plan_; @@ -232,6 +238,7 @@ class ARROW_EXPORT ExecNode { NodeVector inputs_; BatchDescr output_descr_; + int num_outputs_; NodeVector outputs_; }; diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index ef039fb8070..d809409b28d 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -51,25 +51,27 @@ TEST(ExecPlanConstruction, Empty) { TEST(ExecPlanConstruction, SingleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0); + auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/0); ASSERT_OK(plan->Validate()); ASSERT_THAT(plan->sources(), ::testing::ElementsAre(node)); ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(node)); ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1); + node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/0); // Input not bound ASSERT_RAISES(Invalid, plan->Validate()); + + ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); + node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/1); + // Output not bound + ASSERT_RAISES(Invalid, plan->Validate()); } TEST(ExecPlanConstruction, SourceSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1); - EXPECT_EQ(source->num_inputs(), 0); - EXPECT_EQ(sink->num_inputs(), 1); - EXPECT_EQ(sink->inputs().size(), 0); - // Sink's input not bound + auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); + // Input / output not bound ASSERT_RAISES(Invalid, plan->Validate()); sink->AddInput(source); @@ -81,17 +83,22 @@ TEST(ExecPlanConstruction, SourceSink) { TEST(ExecPlanConstruction, MultipleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0); + auto source1 = + MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2); - auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0); + auto source2 = + MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1); - auto process1 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1); + auto process1 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2); - auto process2 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2); + auto process2 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/1); - auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); sink->AddInput(process3); @@ -128,23 +135,21 @@ TEST(ExecPlan, DummyStartProducing) { StartStopTracker t; ASSERT_OK_AND_ASSIGN(auto plan, 
ExecPlan::Make()); - - auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, + auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2, t.start_producing_func(), t.stop_producing_func()); - - auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, + auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); - - auto process1 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, - t.start_producing_func(), t.stop_producing_func()); - - auto process2 = MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, - t.start_producing_func(), t.stop_producing_func()); - - auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, - t.start_producing_func(), t.stop_producing_func()); - - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, + auto process1 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2, + t.start_producing_func(), t.stop_producing_func()); + auto process2 = + MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); + + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0, t.start_producing_func(), t.stop_producing_func()); process1->AddInput(source1); @@ -169,23 +174,21 @@ TEST(ExecPlan, DummyStartProducing) { TEST(ExecPlan, DummyStartProducingCycle) { // A trivial cycle ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1); + auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/1); node->AddInput(node); ASSERT_OK(plan->Validate()); ASSERT_RAISES(Invalid, plan->StartProducing()); // A less trivial one ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - - auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0); - - auto process1 = MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2); - - auto process2 = MakeDummyNode(plan.get(), "process2", /*num_inputs=*/1); - - auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/2); - - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1); + auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); + auto process1 = + MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/2); + auto process2 = + MakeDummyNode(plan.get(), "process2", /*num_inputs=*/1, /*num_outputs=*/1); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/2, /*num_outputs=*/2); + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); process1->AddInput(source); process2->AddInput(process1); @@ -202,25 +205,27 @@ TEST(ExecPlan, DummyStartProducingError) { StartStopTracker t; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, + auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2, t.start_producing_func(Status::NotImplemented("zzz")), t.stop_producing_func()); - auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, + auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func()); - auto process1 = MakeDummyNode(plan.get(), 
"process1", /*num_inputs=*/1, - t.start_producing_func(Status::IOError("xxx")), - t.stop_producing_func()); - auto process2 = MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, - t.start_producing_func(), t.stop_producing_func()); + auto process1 = MakeDummyNode( + plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2, + t.start_producing_func(Status::IOError("xxx")), t.stop_producing_func()); + auto process2 = + MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); process1->AddInput(source1); process2->AddInput(process1); process2->AddInput(source2); - auto process3 = MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, - t.start_producing_func(), t.stop_producing_func()); + auto process3 = + MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1, + t.start_producing_func(), t.stop_producing_func()); process3->AddInput(process1); process3->AddInput(source1); process3->AddInput(process2); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, + auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0, t.start_producing_func(), t.stop_producing_func()); sink->AddInput(process3); @@ -335,7 +340,7 @@ class TestExecPlanExecution : public ::testing::Test { } template - void TestSourceSink(RecordBatchReaderFactory batch_factory) { + void TestSourceSink(RecordBatchReaderFactory reader_factory) { auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); RecordBatchVector batches{ RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, @@ -345,7 +350,7 @@ class TestExecPlanExecution : public ::testing::Test { {"a": 7, "b": false}])"), }; - ASSERT_OK_AND_ASSIGN(auto reader, batch_factory(batches, schema)); + ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(batches, schema)); ASSERT_OK_AND_ASSIGN(auto cp, MakeSourceSink(reader, schema)); ASSERT_OK(cp.plan->Validate()); diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 05d53276c03..f2cd7d2a740 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -26,6 +26,9 @@ #include #include +#include +#include + #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/datum.h" @@ -54,19 +57,18 @@ std::vector DescrFromSchemaColumns(const Schema& schema) { } struct DummyNode : ExecNode { - DummyNode(ExecPlan* plan, std::string label, int num_inputs, + DummyNode(ExecPlan* plan, std::string label, int num_inputs, int num_outputs, StartProducingFunc start_producing, StopProducingFunc stop_producing) - : ExecNode(plan, std::move(label)), + : ExecNode(plan, std::move(label), std::vector(num_inputs, descr()), {}, + descr(), num_outputs), start_producing_(std::move(start_producing)), stop_producing_(std::move(stop_producing)) { for (int i = 0; i < num_inputs; ++i) { input_labels_.push_back(std::to_string(i)); } - input_descrs_.assign(num_inputs, descr()); - output_descr_ = descr(); } - const char* kind_name() override { return "RecordBatchReader"; } + const char* kind_name() override { return "Dummy"; } void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} @@ -82,11 +84,27 @@ struct DummyNode : ExecNode { return Status::OK(); } + void PauseProducing(ExecNode* output) override { + ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; + AssertIsOutput(output); + } + + void ResumeProducing(ExecNode* output) override { + 
ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; + AssertIsOutput(output); + } + + void StopProducing(ExecNode* output) override { + ASSERT_GE(num_outputs(), 0) << "Sink nodes should not experience backpressure"; + AssertIsOutput(output); + StopProducing(); + } + void StopProducing() override { if (started_) { started_ = false; for (const auto& input : inputs_) { - input->StopProducing(); + input->StopProducing(this); } if (stop_producing_) { stop_producing_(this); @@ -95,6 +113,10 @@ struct DummyNode : ExecNode { } private: + void AssertIsOutput(ExecNode* output) { + ASSERT_NE(std::find(outputs_.begin(), outputs_.end(), output), outputs_.end()); + } + BatchDescr descr() const { return std::vector{ValueDescr(null())}; } StartProducingFunc start_producing_; @@ -105,21 +127,19 @@ struct DummyNode : ExecNode { struct RecordBatchReaderNode : ExecNode { RecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr reader, Executor* io_executor) - : ExecNode(plan, std::move(label)), + : ExecNode(plan, std::move(label), {}, {}, + DescrFromSchemaColumns(*reader->schema()), /*num_outputs=*/1), schema_(reader->schema()), reader_(std::move(reader)), - io_executor_(io_executor) { - output_descr_ = DescrFromSchemaColumns(*schema_); - } + io_executor_(io_executor) {} RecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr schema, RecordBatchGenerator generator, Executor* io_executor) - : ExecNode(plan, std::move(label)), + : ExecNode(plan, std::move(label), {}, {}, DescrFromSchemaColumns(*schema), + /*num_outputs=*/1), schema_(std::move(schema)), - io_executor_(io_executor), - generator_(std::move(generator)) { - output_descr_ = DescrFromSchemaColumns(*schema_); - } + generator_(std::move(generator)), + io_executor_(io_executor) {} const char* kind_name() override { return "RecordBatchReader"; } @@ -140,12 +160,17 @@ struct RecordBatchReaderNode : ExecNode { return Status::OK(); } - void StopProducing() override { + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + ASSERT_EQ(output, outputs_[0]); std::unique_lock lock(mutex_); generator_ = nullptr; // null function } - // TODO implement PauseProducing / ResumeProducing + void StopProducing() override { StopProducing(outputs_[0]); } private: void GenerateOne(std::unique_lock&& lock) { @@ -186,22 +211,21 @@ struct RecordBatchReaderNode : ExecNode { }); } + std::mutex mutex_; const std::shared_ptr schema_; const std::shared_ptr reader_; - Executor* const io_executor_; - - std::mutex mutex_; RecordBatchGenerator generator_; int next_batch_index_; + + Executor* const io_executor_; }; struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { RecordBatchCollectNodeImpl(ExecPlan* plan, std::string label, std::shared_ptr schema) - : RecordBatchCollectNode(plan, std::move(label)), schema_(std::move(schema)) { - input_descrs_.push_back(DescrFromSchemaColumns(*schema_)); - input_labels_.emplace_back("batches_to_collect"); - } + : RecordBatchCollectNode(plan, std::move(label), {DescrFromSchemaColumns(*schema)}, + {"batches_to_collect"}, {}, 0), + schema_(std::move(schema)) {} RecordBatchGenerator generator() override { return generator_; } @@ -216,9 +240,20 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { return Status::OK(); } + // sink nodes have no outputs from which to feel backpressure + void ResumeProducing(ExecNode* output) override { + FAIL() << "no outputs; 
this should never be called"; + } + void PauseProducing(ExecNode* output) override { + FAIL() << "no outputs; this should never be called"; + } + void StopProducing(ExecNode* output) override { + FAIL() << "no outputs; this should never be called"; + } + void StopProducing() override { std::unique_lock lock(mutex_); - StopProducing(&lock); + StopProducingUnlocked(); } void InputReceived(ExecNode* input, int seq_num, @@ -250,7 +285,7 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { return; } if (num_received_ == emit_stop_) { - StopProducing(&lock); + StopProducingUnlocked(); } // Emit batches in order as far as possible @@ -280,7 +315,8 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { void ErrorReceived(ExecNode* input, Status error) override { // XXX do we care about properly sequencing the error? producer_->Push(std::move(error)); - StopProducing(); + std::unique_lock lock(mutex_); + StopProducingUnlocked(); } void InputFinished(ExecNode* input, int seq_stop) override { @@ -290,16 +326,16 @@ struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { emit_stop_ = seq_stop; if (emit_stop_ == num_received_) { DCHECK_EQ(emit_stop_, num_emitted_); - StopProducing(&lock); + StopProducingUnlocked(); } } private: - void StopProducing(std::unique_lock* lock) { + void StopProducingUnlocked() { if (!stopped_) { stopped_ = true; producer_->Close(); - inputs_[0]->StopProducing(); + inputs_[0]->StopProducing(this); } } @@ -347,9 +383,9 @@ ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, } ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, - StartProducingFunc start_producing, + int num_outputs, StartProducingFunc start_producing, StopProducingFunc stop_producing) { - return plan->EmplaceNode(plan, std::move(label), num_inputs, + return plan->EmplaceNode(plan, std::move(label), num_inputs, num_outputs, std::move(start_producing), std::move(stop_producing)); } diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 37be6b80cff..c2dc785a501 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -37,7 +37,7 @@ using StopProducingFunc = std::function; // Make a dummy node that has no execution behaviour ARROW_TESTING_EXPORT ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, - StartProducingFunc = {}, StopProducingFunc = {}); + int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); using RecordBatchGenerator = AsyncGenerator>; From 7f1d5334e01325a74f3189a2f978d3b0b4faf996 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 10 May 2021 17:34:01 -0400 Subject: [PATCH 8/9] Use Loop in ExecPlanTest --- cpp/src/arrow/compute/exec/exec_plan.h | 2 +- cpp/src/arrow/compute/exec/test_util.cc | 79 ++++++++++++------------- cpp/src/arrow/util/future.h | 5 +- 3 files changed, 41 insertions(+), 45 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 0d2faea0ddc..4b70d920fcb 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -139,7 +139,7 @@ class ARROW_EXPORT ExecNode { /// and StopProducing() /// Transfer input batch to ExecNode - virtual void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) = 0; + virtual void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) = 0; /// Signal error to ExecNode virtual void ErrorReceived(ExecNode* input, Status error) = 0; 
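[Note, not part of the patch series: the test_util.cc hunk below rewrites RecordBatchReaderNode::StartProducing() on top of the Loop()/ControlFlow helpers from arrow/util/future.h, whose docstring is also touched further down, and the whole change is reverted again in PATCH 9/9. What follows is a minimal standalone sketch of that idiom, assuming only the arrow::Loop, ControlFlow, Break and Continue utilities referenced in this patch; the counter is a hypothetical stand-in for the node's asynchronous batch generator and is not taken from the patch itself.

#include <iostream>

#include "arrow/util/future.h"

int main() {
  int i = 0;
  // Loop() keeps re-invoking the callable: Continue() requests another
  // iteration, Break(value) ends the loop and completes the returned
  // future with `value`.
  arrow::Future<int> fut =
      arrow::Loop([&]() -> arrow::Future<arrow::ControlFlow<int>> {
        if (i < 3) {
          ++i;
          return arrow::Future<arrow::ControlFlow<int>>::MakeFinished(
              arrow::Continue());
        }
        return arrow::Future<arrow::ControlFlow<int>>::MakeFinished(
            arrow::Break(i));
      });
  // Every step above was already finished, so the result is available now.
  std::cout << "loop broke after " << fut.result().ValueOrDie()
            << " iterations" << std::endl;
  return 0;
}

In the node itself each iteration instead awaits the next record batch from the generator, forwards it to the outputs via InputReceived(), and breaks on end-of-stream or once StopProducing() has set the stopped_ flag.]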
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index f2cd7d2a740..83185063c45 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -33,6 +33,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/datum.h" #include "arrow/record_batch.h" +#include "arrow/testing/gtest_util.h" #include "arrow/type.h" #include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" @@ -143,20 +144,50 @@ struct RecordBatchReaderNode : ExecNode { const char* kind_name() override { return "RecordBatchReader"; } - void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} + void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) override {} void ErrorReceived(ExecNode* input, Status error) override {} void InputFinished(ExecNode* input, int seq_stop) override {} Status StartProducing() override { - next_batch_index_ = 0; + if (!stopped_) return Status::OK(); + if (!generator_) { auto it = MakeIteratorFromReader(reader_); ARROW_ASSIGN_OR_RAISE(generator_, MakeBackgroundGenerator(std::move(it), io_executor_)); } - GenerateOne(std::unique_lock{mutex_}); + + next_batch_index_ = 0; + stopped_ = false; + + (void)Loop([&] { + return io_executor_->Transfer(generator_()) + .Then( + [&](const std::shared_ptr& batch) -> ControlFlow { + std::unique_lock lock(mutex_); + int batch_index = next_batch_index_++; + if (stopped_) return Break(batch_index); + if (IsIterationEnd(batch)) return Break(batch_index); + lock.unlock(); + + for (auto out : outputs_) { + out->InputReceived(this, batch_index, ExecBatch(*batch)); + } + return Continue(); + }, + [&](const Status& err) { + for (auto out : outputs_) { + out->ErrorReceived(this, err); + } + return Break(0); + }); + }).Then([&](const util::optional& batch_index) { + for (auto out : outputs_) { + out->InputFinished(this, *batch_index); + } + }); return Status::OK(); } @@ -167,54 +198,18 @@ struct RecordBatchReaderNode : ExecNode { void StopProducing(ExecNode* output) override { ASSERT_EQ(output, outputs_[0]); std::unique_lock lock(mutex_); - generator_ = nullptr; // null function + stopped_ = true; + lock.unlock(); } void StopProducing() override { StopProducing(outputs_[0]); } private: - void GenerateOne(std::unique_lock&& lock) { - if (!generator_) { - // Stopped - return; - } - auto plan = this->plan()->shared_from_this(); - auto fut = generator_(); - const auto batch_index = next_batch_index_++; - - lock.unlock(); - // TODO we want to transfer always here - io_executor_->Transfer(std::move(fut)) - .AddCallback( - [plan, batch_index, this](const Result>& res) { - std::unique_lock lock(mutex_); - if (!res.ok()) { - for (auto out : outputs_) { - out->ErrorReceived(this, res.status()); - } - return; - } - const auto& batch = *res; - if (IsIterationEnd(batch)) { - lock.unlock(); - for (auto out : outputs_) { - out->InputFinished(this, batch_index); - } - } else { - lock.unlock(); - for (auto out : outputs_) { - out->InputReceived(this, batch_index, compute::ExecBatch(*batch)); - } - lock.lock(); - GenerateOne(std::move(lock)); - } - }); - } - std::mutex mutex_; const std::shared_ptr schema_; const std::shared_ptr reader_; RecordBatchGenerator generator_; + bool stopped_ = true; int next_batch_index_; Executor* const io_executor_; diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index 4c8de912f81..a3c7e5054c7 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -703,8 +703,9 @@ using 
ControlFlow = util::optional; /// /// \param[in] iterate A generator of Future>. On completion of /// each yielded future the resulting ControlFlow will be examined. A Break will terminate -/// the loop, while a Continue will re-invoke `iterate`. \return A future which will -/// complete when a Future returned by iterate completes with a Break +/// the loop, while a Continue will re-invoke `iterate`. +/// \return A future which will complete when a Future returned by iterate completes with +/// a Break template ::ValueType, typename BreakValueType = typename Control::value_type> From 46dcc4899395247a4519ef13c4009f28df5e2ce5 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Tue, 11 May 2021 12:08:54 -0400 Subject: [PATCH 9/9] Revert "Use Loop in ExecPlanTest" This reverts commit 7f1d5334e01325a74f3189a2f978d3b0b4faf996. --- cpp/src/arrow/compute/exec/exec_plan.h | 2 +- cpp/src/arrow/compute/exec/test_util.cc | 79 +++++++++++++------------ cpp/src/arrow/util/future.h | 5 +- 3 files changed, 45 insertions(+), 41 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 4b70d920fcb..0d2faea0ddc 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -139,7 +139,7 @@ class ARROW_EXPORT ExecNode { /// and StopProducing() /// Transfer input batch to ExecNode - virtual void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) = 0; + virtual void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) = 0; /// Signal error to ExecNode virtual void ErrorReceived(ExecNode* input, Status error) = 0; diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 83185063c45..f2cd7d2a740 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -33,7 +33,6 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/datum.h" #include "arrow/record_batch.h" -#include "arrow/testing/gtest_util.h" #include "arrow/type.h" #include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" @@ -144,50 +143,20 @@ struct RecordBatchReaderNode : ExecNode { const char* kind_name() override { return "RecordBatchReader"; } - void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) override {} + void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} void ErrorReceived(ExecNode* input, Status error) override {} void InputFinished(ExecNode* input, int seq_stop) override {} Status StartProducing() override { - if (!stopped_) return Status::OK(); - + next_batch_index_ = 0; if (!generator_) { auto it = MakeIteratorFromReader(reader_); ARROW_ASSIGN_OR_RAISE(generator_, MakeBackgroundGenerator(std::move(it), io_executor_)); } - - next_batch_index_ = 0; - stopped_ = false; - - (void)Loop([&] { - return io_executor_->Transfer(generator_()) - .Then( - [&](const std::shared_ptr& batch) -> ControlFlow { - std::unique_lock lock(mutex_); - int batch_index = next_batch_index_++; - if (stopped_) return Break(batch_index); - if (IsIterationEnd(batch)) return Break(batch_index); - lock.unlock(); - - for (auto out : outputs_) { - out->InputReceived(this, batch_index, ExecBatch(*batch)); - } - return Continue(); - }, - [&](const Status& err) { - for (auto out : outputs_) { - out->ErrorReceived(this, err); - } - return Break(0); - }); - }).Then([&](const util::optional& batch_index) { - for (auto out : outputs_) { - out->InputFinished(this, *batch_index); - } - }); + 
GenerateOne(std::unique_lock{mutex_}); return Status::OK(); } @@ -198,18 +167,54 @@ struct RecordBatchReaderNode : ExecNode { void StopProducing(ExecNode* output) override { ASSERT_EQ(output, outputs_[0]); std::unique_lock lock(mutex_); - stopped_ = true; - lock.unlock(); + generator_ = nullptr; // null function } void StopProducing() override { StopProducing(outputs_[0]); } private: + void GenerateOne(std::unique_lock&& lock) { + if (!generator_) { + // Stopped + return; + } + auto plan = this->plan()->shared_from_this(); + auto fut = generator_(); + const auto batch_index = next_batch_index_++; + + lock.unlock(); + // TODO we want to transfer always here + io_executor_->Transfer(std::move(fut)) + .AddCallback( + [plan, batch_index, this](const Result>& res) { + std::unique_lock lock(mutex_); + if (!res.ok()) { + for (auto out : outputs_) { + out->ErrorReceived(this, res.status()); + } + return; + } + const auto& batch = *res; + if (IsIterationEnd(batch)) { + lock.unlock(); + for (auto out : outputs_) { + out->InputFinished(this, batch_index); + } + } else { + lock.unlock(); + for (auto out : outputs_) { + out->InputReceived(this, batch_index, compute::ExecBatch(*batch)); + } + lock.lock(); + GenerateOne(std::move(lock)); + } + }); + } + std::mutex mutex_; const std::shared_ptr schema_; const std::shared_ptr reader_; RecordBatchGenerator generator_; - bool stopped_ = true; int next_batch_index_; Executor* const io_executor_; diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index a3c7e5054c7..4c8de912f81 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -703,9 +703,8 @@ using ControlFlow = util::optional; /// /// \param[in] iterate A generator of Future>. On completion of /// each yielded future the resulting ControlFlow will be examined. A Break will terminate -/// the loop, while a Continue will re-invoke `iterate`. -/// \return A future which will complete when a Future returned by iterate completes with -/// a Break +/// the loop, while a Continue will re-invoke `iterate`. \return A future which will +/// complete when a Future returned by iterate completes with a Break template ::ValueType, typename BreakValueType = typename Control::value_type>