diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index 73cb82ef026..78f3d753711 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -36,6 +36,7 @@ #include "arrow/compute/registry.h" #include "arrow/compute/util_internal.h" #include "arrow/datum.h" +#include "arrow/pretty_print.h" #include "arrow/record_batch.h" #include "arrow/scalar.h" #include "arrow/status.h" @@ -69,6 +70,48 @@ ExecBatch::ExecBatch(const RecordBatch& batch) std::move(columns.begin(), columns.end(), values.begin()); } +bool ExecBatch::Equals(const ExecBatch& other) const { + return guarantee == other.guarantee && values == other.values; +} + +void PrintTo(const ExecBatch& batch, std::ostream* os) { + *os << "ExecBatch\n"; + + static const std::string indent = " "; + + *os << indent << "# Rows: " << batch.length << "\n"; + if (batch.guarantee != literal(true)) { + *os << indent << "Guarantee: " << batch.guarantee.ToString() << "\n"; + } + + int i = 0; + for (const Datum& value : batch.values) { + *os << indent << "" << i++ << ": "; + + if (value.is_scalar()) { + *os << "Scalar[" << value.scalar()->ToString() << "]\n"; + continue; + } + + auto array = value.make_array(); + PrettyPrintOptions options; + options.skip_new_lines = true; + *os << "Array"; + ARROW_CHECK_OK(PrettyPrint(*array, options, os)); + *os << "\n"; + } +} + +ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const { + ExecBatch out = *this; + for (auto& value : out.values) { + if (value.is_scalar()) continue; + value = value.array()->Slice(offset, length); + } + out.length = length; + return out; +} + Result ExecBatch::Make(std::vector values) { if (values.empty()) { return Status::Invalid("Cannot infer ExecBatch length without at least one value"); @@ -77,9 +120,6 @@ Result ExecBatch::Make(std::vector values) { int64_t length = -1; for (const auto& value : values) { if (value.is_scalar()) { - if (length == -1) { - length = 1; - } continue; } @@ -94,8 +134,13 @@ Result ExecBatch::Make(std::vector values) { } } + if (length == -1) { + length = 1; + } + return ExecBatch(std::move(values), length); } + namespace { Result> AllocateDataBuffer(KernelContext* ctx, int64_t length, diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index cd95db2fd8c..e7015814d2a 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -28,6 +28,7 @@ #include #include "arrow/array/data.h" +#include "arrow/compute/exec/expression.h" #include "arrow/datum.h" #include "arrow/memory_pool.h" #include "arrow/result.h" @@ -186,6 +187,9 @@ struct ARROW_EXPORT ExecBatch { /// ExecBatch::length is equal to the length of this array. std::shared_ptr selection_vector; + /// A predicate Expression guaranteed to evaluate to true for all rows in this batch. + Expression guarantee = literal(true); + /// The semantic length of the ExecBatch. When the values are all scalars, /// the length should be set to 1, otherwise the length is taken from the /// array values, except when there is a selection vector. When there is a @@ -203,9 +207,13 @@ struct ARROW_EXPORT ExecBatch { return values[i]; } + bool Equals(const ExecBatch& other) const; + /// \brief A convenience for the number of values / arguments. int num_values() const { return static_cast(values.size()); } + ExecBatch Slice(int64_t offset, int64_t length) const; + /// \brief A convenience for returning the ValueDescr objects (types and /// shapes) from the batch. 
std::vector GetDescriptors() const { @@ -215,8 +223,13 @@ struct ARROW_EXPORT ExecBatch { } return result; } + + ARROW_EXPORT friend void PrintTo(const ExecBatch&, std::ostream*); }; +inline bool operator==(const ExecBatch& l, const ExecBatch& r) { return l.Equals(r); } +inline bool operator!=(const ExecBatch& l, const ExecBatch& r) { return !l.Equals(r); } + /// \defgroup compute-call-function One-shot calls to compute functions /// /// @{ diff --git a/cpp/src/arrow/compute/exec/doc/exec_node.md b/cpp/src/arrow/compute/exec/doc/exec_node.md new file mode 100644 index 00000000000..797cc87d90a --- /dev/null +++ b/cpp/src/arrow/compute/exec/doc/exec_node.md @@ -0,0 +1,147 @@ + + +# ExecNodes and logical operators + +`ExecNode`s are intended to implement individual logical operators +in a streaming execution graph. Each node receives batches from +upstream nodes (inputs), processes them in some way, then pushes +results to downstream nodes (outputs). `ExecNode`s are owned and +(to an extent) coordinated by an `ExecPlan`. + +> Terminology: "operator" and "node" are mostly interchangeable, like +> "Interface" and "Abstract Base Class" in C++. The latter is +> a formal and specific bit of code which implements the abstract +> concept. + +## Types of logical operators + +Each of the following will have at least one corresponding concrete +`ExecNode`. Where possible, compatible implementations of a +logical operator will *not* be exposed as independent subclasses +of `ExecNode`. Instead we prefer that they +be encapsulated internally by a single subclass of `ExecNode` +to permit switching between them during a query. + +- Scan: materializes in-memory batches from storage (e.g. Parquet + files, a Flight stream, ...) +- Filter: evaluates an `Expression` on each input batch and outputs + a copy with any rows excluded for which the filter did not return + `true`. +- Project: evaluates `Expression`s on each input batch to produce + the columns of an output batch. +- Grouped Aggregate: identify groups based on one or more key columns + in each input batch, then update aggregates corresponding to those + groups. Note that this is a pipeline breaker; it will wait for its + inputs to complete before outputting any batches. +- Union: merge two or more streams of batches into a single stream + of batches. +- Write: write each batch to storage. +- ToTable: collect batches into a `Table` with stable row ordering where + possible. + +#### Not in scope for Arrow 5.0: + +- Join: perform an inner, left, outer, semi, or anti join given some + join predicates. +- Sort: accumulate all input batches into a single table, reorder its + rows by some sorting condition, then stream the sorted table out as + batches. +- Top-K: retrieve a limited subset of rows from a table as though it + were in sorted order. + +For example: a dataset scan with only a filter and a +projection will correspond to a fairly trivial graph: + +``` +ScanNode -> FilterNode -> ProjectNode -> ToTableNode +``` + +A scan node loads batches from disk and pushes to a filter node. +The filter node excludes some rows based on an `Expression`, then +pushes filtered batches to a project node. The project node +materializes new columns based on `Expression`s, then pushes those +batches to a table collection node. The table collection node +assembles these batches into a `Table` which is handed off as the +result of the `ExecPlan`. + +## Parallelism, pipelines + +The execution graph is orthogonal to parallelism; any +node may push to any other node from any thread.
A scan node causes +each batch to arrive on a thread after which it will pass through +each node in the example graph above, never leaving that thread +(memory/other resource pressure permitting). + +The example graph above happens to be simple enough that processing +of any batch by any node is independent of other nodes and other +batches; it is a pipeline. Note that there is no explicit `Pipeline` +class; pipelined execution is an emergent property of some +subgraphs. + +Nodes which do not share this property (pipeline breakers) are +responsible for deciding when they have received sufficient input, +when they can start emitting output, etc. For example, a `GroupByNode` +will wait for its input to be exhausted before it begins pushing +batches to its own outputs. + +Parallelism is "seeded" by `ScanNode` (or other source nodes); it +owns a reference to the thread pool on which the graph is executing +and fans out pushing to its outputs across that pool. A subsequent +`ProjectNode` will process the batch immediately after it is handed +off by the `ScanNode`; no explicit scheduling is required. +Eventually, individual nodes may internally +parallelize processing of individual batches (for example, if a +`FilterNode`'s filter expression is slow). This decision is also left +up to each `ExecNode` implementation. + +# ExecNode interface and usage + +`ExecNode`s are constructed using one of the available factory +functions, such as `arrow::compute::MakeFilterNode` +or `arrow::dataset::MakeScanNode`. Any inputs to an `ExecNode` +must be provided when the node is constructed, so the first +nodes to be constructed are source nodes with no inputs, +such as `ScanNode`. + +The batches yielded by an `ExecNode` always conform precisely +to its output schema. NB: no by-name field lookups or type +checks are performed during execution. The output schema +is usually derived from the output schemas of inputs. For +example, a `FilterNode`'s output schema is always identical to +that of its input since batches are only modified by the exclusion +of some rows. + +An `ExecNode` will begin producing batches when +`node->StartProducing()` is invoked and will proceed until stopped +with `node->StopProducing()`. Started nodes may not be destroyed +until stopped. `ExecNode`s are not currently restartable. +An `ExecNode` pushes batches to its outputs by passing each batch +to `output->InputReceived()`. It signals exhaustion by invoking +`output->InputFinished()`. + +Error recovery is permitted within a node. For example, if evaluation +of an `Expression` runs out of memory the governing node may +try that evaluation again after some memory has been freed up. +If a node experiences an error from which it cannot recover (for +example, an IO error while parsing a CSV file) then it reports this +with `output->ErrorReceived()`. An error which escapes the scope of +a single node should not be considered recoverable (no `FilterNode` +should `try/catch` the IO error above).
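+
+## Example: assembling and running a plan
+
+To make the above concrete, here is a minimal sketch of building and
+running a plan with the factory functions introduced by this patch
+(`MakeSourceNode`, `MakeFilterNode`, `MakeProjectNode`, `MakeSinkNode`).
+The schema `input_schema` (assumed to hold an int32 field `"i32"`), the
+generator `batch_gen`, and the enclosing `Status`-returning function are
+illustrative assumptions; this is a sketch, not a tested program.
+
+```
+ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make());
+
+// Source: adapt an asynchronous stream of batches into the plan.
+ExecNode* source = MakeSourceNode(plan.get(), "source", input_schema, batch_gen);
+
+// Filter: drop rows where i32 <= 0 ("greater" is a compute function name).
+ARROW_ASSIGN_OR_RAISE(
+    ExecNode* filter,
+    MakeFilterNode(source, "filter",
+                   call("greater", {field_ref("i32"), literal(0)})));
+
+// Project: emit one derived column, i32 * 2.
+ARROW_ASSIGN_OR_RAISE(
+    ExecNode* project,
+    MakeProjectNode(filter, "project",
+                    {call("multiply", {field_ref("i32"), literal(2)})}));
+
+// Sink: expose the plan's output as an AsyncGenerator.
+auto sink_gen = MakeSinkNode(project, "sink");
+
+RETURN_NOT_OK(plan->Validate());
+RETURN_NOT_OK(plan->StartProducing());
+
+// Drain the sink (blocking here for simplicity). Each future resolves
+// to util::optional<ExecBatch>; nullopt signals exhaustion.
+while (true) {
+  ARROW_ASSIGN_OR_RAISE(auto maybe_batch, sink_gen().result());
+  if (!maybe_batch) break;
+  // ... consume *maybe_batch ...
+}
+
+plan->StopProducing();
+return Status::OK();
+```
+
+Note that batches may reach the sink out of order; each is tagged with
+the sequence number at which it was received from the sink's input.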
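+
+## Example: a minimal pass-through node
+
+To illustrate the downstream API, here is a sketch of a node which
+forwards batches unchanged. It is modeled on the `FilterNode` added by
+this patch; `PassthroughNode` is a hypothetical name, not part of the
+patch itself.
+
+```
+struct PassthroughNode : ExecNode {
+  explicit PassthroughNode(ExecNode* input)
+      : ExecNode(input->plan(), "passthrough", {input}, {"target"},
+                 /*output_schema=*/input->output_schema(),
+                 /*num_outputs=*/1) {}
+
+  const char* kind_name() override { return "PassthroughNode"; }
+
+  // Downstream API: forward batches, errors, and exhaustion unchanged.
+  void InputReceived(ExecNode* input, int seq, ExecBatch batch) override {
+    DCHECK_EQ(input, inputs_[0]);
+    outputs_[0]->InputReceived(this, seq, std::move(batch));
+  }
+  void ErrorReceived(ExecNode* input, Status error) override {
+    outputs_[0]->ErrorReceived(this, std::move(error));
+  }
+  void InputFinished(ExecNode* input, int seq) override {
+    outputs_[0]->InputFinished(this, seq);
+  }
+
+  // Upstream API: this node does no work of its own, so starting is a
+  // no-op and stopping is delegated to the input.
+  Status StartProducing() override { return Status::OK(); }
+  void PauseProducing(ExecNode* output) override {}
+  void ResumeProducing(ExecNode* output) override {}
+  void StopProducing(ExecNode* output) override {
+    DCHECK_EQ(output, outputs_[0]);
+    inputs_[0]->StopProducing(this);
+  }
+  void StopProducing() override { StopProducing(outputs_[0]); }
+};
+```
+
+As with `FilterNode`, construction wires the node to its input, and the
+`ExecNode` constructor registers it as an output of that input.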
+ diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index f765ceccf0c..2dcbfb24724 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -17,10 +17,15 @@ #include "arrow/compute/exec/exec_plan.h" +#include #include +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/expression.h" #include "arrow/datum.h" #include "arrow/result.h" +#include "arrow/util/async_generator.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" @@ -36,7 +41,11 @@ namespace { struct ExecPlanImpl : public ExecPlan { ExecPlanImpl() = default; - ~ExecPlanImpl() override = default; + ~ExecPlanImpl() override { + if (started_ && !stopped_) { + StopProducing(); + } + } ExecNode* AddNode(std::unique_ptr node) { if (node->num_inputs() == 0) { @@ -60,79 +69,73 @@ struct ExecPlanImpl : public ExecPlan { } Status StartProducing() { - ARROW_ASSIGN_OR_RAISE(auto sorted_nodes, ReverseTopoSort()); - Status st; - auto it = sorted_nodes.begin(); - while (it != sorted_nodes.end() && st.ok()) { - st &= (*it++)->StartProducing(); + if (started_) { + return Status::Invalid("restarted ExecPlan"); } - if (!st.ok()) { + started_ = true; + + // producers precede consumers + sorted_nodes_ = TopoSort(); + + for (size_t i = 0, rev_i = sorted_nodes_.size() - 1; i < sorted_nodes_.size(); + ++i, --rev_i) { + auto st = sorted_nodes_[rev_i]->StartProducing(); + if (st.ok()) continue; + // Stop nodes that successfully started, in reverse order - // (`it` now points after the node that failed starting, so need to rewind) - --it; - while (it != sorted_nodes.begin()) { - (*--it)->StopProducing(); + for (; rev_i < sorted_nodes_.size(); ++rev_i) { + sorted_nodes_[rev_i]->StopProducing(); } + return st; } - return st; + return Status::OK(); } - Result ReverseTopoSort() { - struct TopoSort { + void StopProducing() { + DCHECK(started_) << "stopped an ExecPlan which never started"; + stopped_ = true; + + for (const auto& node : sorted_nodes_) { + node->StopProducing(); + } + } + + NodeVector TopoSort() { + struct Impl { const std::vector>& nodes; std::unordered_set visited; - std::unordered_set visiting; NodeVector sorted; - explicit TopoSort(const std::vector>& nodes) - : nodes(nodes) { + explicit Impl(const std::vector>& nodes) : nodes(nodes) { visited.reserve(nodes.size()); - sorted.reserve(nodes.size()); - } + sorted.resize(nodes.size()); - Status Sort() { for (const auto& node : nodes) { - RETURN_NOT_OK(Visit(node.get())); + Visit(node.get()); } - DCHECK_EQ(sorted.size(), nodes.size()); + DCHECK_EQ(visited.size(), nodes.size()); - DCHECK_EQ(visiting.size(), 0); - return Status::OK(); } - Status Visit(ExecNode* node) { - if (visited.count(node) != 0) { - return Status::OK(); - } - - auto it_success = visiting.insert(node); - if (!it_success.second) { - // Insertion failed => node is already being visited - return Status::Invalid("Cycle detected in execution plan"); - } + void Visit(ExecNode* node) { + if (visited.count(node) != 0) return; for (auto input : node->inputs()) { // Ensure that producers are inserted before this consumer - RETURN_NOT_OK(Visit(input)); + Visit(input); } - visiting.erase(it_success.first); + sorted[visited.size()] = node; visited.insert(node); - sorted.push_back(node); - return Status::OK(); - } - - NodeVector Reverse() { - std::reverse(sorted.begin(), sorted.end()); - return std::move(sorted); } - } topo_sort(nodes_); + }; - 
RETURN_NOT_OK(topo_sort.Sort()); - return topo_sort.Reverse(); + return std::move(Impl{nodes_}.sorted); } + bool started_ = false, stopped_ = false; std::vector> nodes_; + NodeVector sorted_nodes_; NodeVector sources_, sinks_; }; @@ -170,21 +173,26 @@ Status ExecPlan::Validate() { return ToDerived(this)->Validate(); } Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); } -ExecNode::ExecNode(ExecPlan* plan, std::string label, - std::vector input_descrs, - std::vector input_labels, BatchDescr output_descr, - int num_outputs) +void ExecPlan::StopProducing() { ToDerived(this)->StopProducing(); } + +ExecNode::ExecNode(ExecPlan* plan, std::string label, NodeVector inputs, + std::vector input_labels, + std::shared_ptr output_schema, int num_outputs) : plan_(plan), label_(std::move(label)), - input_descrs_(std::move(input_descrs)), + inputs_(std::move(inputs)), input_labels_(std::move(input_labels)), - output_descr_(std::move(output_descr)), - num_outputs_(num_outputs) {} + output_schema_(std::move(output_schema)), + num_outputs_(num_outputs) { + for (auto input : inputs_) { + input->outputs_.push_back(this); + } +} Status ExecNode::Validate() const { - if (inputs_.size() != input_descrs_.size()) { + if (inputs_.size() != input_labels_.size()) { return Status::Invalid("Invalid number of inputs for '", label(), "' (expected ", - num_inputs(), ", actual ", inputs_.size(), ")"); + num_inputs(), ", actual ", input_labels_.size(), ")"); } if (static_cast(outputs_.size()) != num_outputs_) { @@ -192,26 +200,369 @@ Status ExecNode::Validate() const { num_outputs(), ", actual ", outputs_.size(), ")"); } - DCHECK_EQ(input_descrs_.size(), input_labels_.size()); - for (auto out : outputs_) { auto input_index = GetNodeIndex(out->inputs(), this); if (!input_index) { return Status::Invalid("Node '", label(), "' outputs to node '", out->label(), "' but is not listed as an input."); } + } + + return Status::OK(); +} - const auto& in_descr = out->input_descrs_[*input_index]; - if (in_descr != output_descr_) { - return Status::Invalid( - "Node '", label(), "' (bound to input ", input_labels_[*input_index], - ") produces batches with type '", ValueDescr::ToString(output_descr_), - "' inconsistent with consumer '", out->label(), "' which accepts '", - ValueDescr::ToString(in_descr), "'"); +struct SourceNode : ExecNode { + SourceNode(ExecPlan* plan, std::string label, std::shared_ptr output_schema, + AsyncGenerator> generator) + : ExecNode(plan, std::move(label), {}, {}, std::move(output_schema), + /*num_outputs=*/1), + generator_(std::move(generator)) {} + + const char* kind_name() override { return "SourceNode"; } + + static void NoInputs() { DCHECK(false) << "no inputs; this should never be called"; } + void InputReceived(ExecNode*, int, ExecBatch) override { NoInputs(); } + void ErrorReceived(ExecNode*, Status) override { NoInputs(); } + void InputFinished(ExecNode*, int) override { NoInputs(); } + + Status StartProducing() override { + if (finished_) { + return Status::Invalid("Restarted SourceNode '", label(), "'"); } + + finished_fut_ = + Loop([this] { + std::unique_lock lock(mutex_); + int seq = next_batch_index_++; + if (finished_) { + return Future>::MakeFinished(Break(seq)); + } + lock.unlock(); + + return generator_().Then( + [=](const util::optional& batch) -> ControlFlow { + std::unique_lock lock(mutex_); + if (!batch || finished_) { + finished_ = true; + return Break(seq); + } + lock.unlock(); + + // TODO check if we are on the desired Executor and transfer if not. 
+ // This can happen for in-memory scans where batches didn't require + // any CPU work to decode. Otherwise, parsing etc. should already + // have placed us on the thread pool + outputs_[0]->InputReceived(this, seq, *batch); + return Continue(); + }, + [=](const Status& error) -> ControlFlow { + std::unique_lock lock(mutex_); + if (!finished_) { + finished_ = true; + lock.unlock(); + // unless we were already finished, push the error to our output + // XXX is this correct? Is it reasonable for a consumer to + // ignore errors from a finished producer? + outputs_[0]->ErrorReceived(this, error); + } + return Break(seq); + }); + }).Then([&](int seq) { + /// XXX this is probably redundant: do we always call InputFinished after + /// ErrorReceived or will ErrorReceived be sufficient? + outputs_[0]->InputFinished(this, seq); + }); + + return Status::OK(); } - return Status::OK(); + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + { + std::unique_lock lock(mutex_); + finished_ = true; + } + finished_fut_.Wait(); + } + + void StopProducing() override { StopProducing(outputs_[0]); } + + private: + std::mutex mutex_; + bool finished_{false}; + int next_batch_index_{0}; + Future<> finished_fut_ = Future<>::MakeFinished(); + AsyncGenerator> generator_; }; + +ExecNode* MakeSourceNode(ExecPlan* plan, std::string label, + std::shared_ptr output_schema, + AsyncGenerator> generator) { + return plan->EmplaceNode(plan, std::move(label), std::move(output_schema), + std::move(generator)); } + +struct FilterNode : ExecNode { + FilterNode(ExecNode* input, std::string label, Expression filter) + : ExecNode(input->plan(), std::move(label), {input}, {"target"}, + /*output_schema=*/input->output_schema(), + /*num_outputs=*/1), + filter_(std::move(filter)) {} + + const char* kind_name() override { return "FilterNode"; } + + Result DoFilter(const ExecBatch& target) { + ARROW_ASSIGN_OR_RAISE(Expression simplified_filter, + SimplifyWithGuarantee(filter_, target.guarantee)); + + // XXX get a non-default exec context + ARROW_ASSIGN_OR_RAISE(Datum mask, ExecuteScalarExpression(simplified_filter, target)); + + if (mask.is_scalar()) { + const auto& mask_scalar = mask.scalar_as(); + if (mask_scalar.is_valid && mask_scalar.value) { + return target; + } + + return target.Slice(0, 0); + } + + auto values = target.values; + for (auto& value : values) { + if (value.is_scalar()) continue; + ARROW_ASSIGN_OR_RAISE(value, Filter(value, mask, FilterOptions::Defaults())); + } + return ExecBatch::Make(std::move(values)); + } + + void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + + auto maybe_filtered = DoFilter(std::move(batch)); + if (!maybe_filtered.ok()) { + outputs_[0]->ErrorReceived(this, maybe_filtered.status()); + inputs_[0]->StopProducing(this); + return; + } + + maybe_filtered->guarantee = batch.guarantee; + outputs_[0]->InputReceived(this, seq, maybe_filtered.MoveValueUnsafe()); + } + + void ErrorReceived(ExecNode* input, Status error) override { + DCHECK_EQ(input, inputs_[0]); + outputs_[0]->ErrorReceived(this, std::move(error)); + inputs_[0]->StopProducing(this); + } + + void InputFinished(ExecNode* input, int seq) override { + DCHECK_EQ(input, inputs_[0]); + outputs_[0]->InputFinished(this, seq); + } + + Status StartProducing() override { return Status::OK(); } + + void PauseProducing(ExecNode* output) override {} + + 
void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + inputs_[0]->StopProducing(this); + } + + void StopProducing() override { StopProducing(outputs_[0]); } + + private: + Expression filter_; +}; + +Result MakeFilterNode(ExecNode* input, std::string label, Expression filter) { + if (!filter.IsBound()) { + ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(*input->output_schema())); + } + + if (filter.type()->id() != Type::BOOL) { + return Status::TypeError("Filter expression must evaluate to bool, but ", + filter.ToString(), " evaluates to ", + filter.type()->ToString()); + } + + return input->plan()->EmplaceNode(input, std::move(label), + std::move(filter)); +} + +struct ProjectNode : ExecNode { + ProjectNode(ExecNode* input, std::string label, std::shared_ptr output_schema, + std::vector exprs) + : ExecNode(input->plan(), std::move(label), {input}, {"target"}, + /*output_schema=*/std::move(output_schema), + /*num_outputs=*/1), + exprs_(std::move(exprs)) {} + + const char* kind_name() override { return "ProjectNode"; } + + Result DoProject(const ExecBatch& target) { + // XXX get a non-default exec context + std::vector values{exprs_.size()}; + for (size_t i = 0; i < exprs_.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(Expression simplified_expr, + SimplifyWithGuarantee(exprs_[i], target.guarantee)); + + ARROW_ASSIGN_OR_RAISE(values[i], ExecuteScalarExpression(simplified_expr, target)); + } + return ExecBatch::Make(std::move(values)); + } + + void InputReceived(ExecNode* input, int seq, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + + auto maybe_projected = DoProject(std::move(batch)); + if (!maybe_projected.ok()) { + outputs_[0]->ErrorReceived(this, maybe_projected.status()); + inputs_[0]->StopProducing(this); + return; + } + + maybe_projected->guarantee = batch.guarantee; + outputs_[0]->InputReceived(this, seq, maybe_projected.MoveValueUnsafe()); + } + + void ErrorReceived(ExecNode* input, Status error) override { + DCHECK_EQ(input, inputs_[0]); + outputs_[0]->ErrorReceived(this, std::move(error)); + inputs_[0]->StopProducing(this); + } + + void InputFinished(ExecNode* input, int seq) override { + DCHECK_EQ(input, inputs_[0]); + outputs_[0]->InputFinished(this, seq); + } + + Status StartProducing() override { return Status::OK(); } + + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + inputs_[0]->StopProducing(this); + } + + void StopProducing() override { StopProducing(outputs_[0]); } + + private: + std::vector exprs_; +}; + +Result MakeProjectNode(ExecNode* input, std::string label, + std::vector exprs) { + FieldVector fields(exprs.size()); + + int i = 0; + for (auto& expr : exprs) { + if (!expr.IsBound()) { + ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(*input->output_schema())); + } + fields[i] = field(expr.ToString(), expr.type()); + ++i; + } + + return input->plan()->EmplaceNode( + input, std::move(label), schema(std::move(fields)), std::move(exprs)); +} + +struct SinkNode : ExecNode { + SinkNode(ExecNode* input, std::string label, + AsyncGenerator>* generator) + : ExecNode(input->plan(), std::move(label), {input}, {"collected"}, {}, + /*num_outputs=*/0), + producer_(MakeProducer(generator)) {} + + static PushGenerator>::Producer MakeProducer( + AsyncGenerator>* out_gen) { + PushGenerator> gen; + auto out = gen.producer(); + *out_gen = std::move(gen); + return out; 
+ } + + const char* kind_name() override { return "SinkNode"; } + + Status StartProducing() override { return Status::OK(); } + + // sink nodes have no outputs from which to feel backpressure + static void NoOutputs() { DCHECK(false) << "no outputs; this should never be called"; } + void ResumeProducing(ExecNode* output) override { NoOutputs(); } + void PauseProducing(ExecNode* output) override { NoOutputs(); } + void StopProducing(ExecNode* output) override { NoOutputs(); } + + void StopProducing() override { + std::unique_lock lock(mutex_); + InputFinishedUnlocked(); + } + + void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) override { + DCHECK_EQ(input, inputs_[0]); + + std::unique_lock lock(mutex_); + if (stopped_) return; + + ++num_received_; + if (num_received_ == emit_stop_) { + InputFinishedUnlocked(); + } + + if (emit_stop_ != -1) { + DCHECK_LE(seq_num, emit_stop_); + } + lock.unlock(); + + producer_.Push(std::move(batch)); + } + + void ErrorReceived(ExecNode* input, Status error) override { + DCHECK_EQ(input, inputs_[0]); + producer_.Push(std::move(error)); + std::unique_lock lock(mutex_); + InputFinishedUnlocked(); + } + + void InputFinished(ExecNode* input, int seq_stop) override { + std::unique_lock lock(mutex_); + emit_stop_ = seq_stop; + if (emit_stop_ == num_received_) { + InputFinishedUnlocked(); + } + } + + private: + void InputFinishedUnlocked() { + if (!stopped_) { + stopped_ = true; + producer_.Close(); + } + } + + std::mutex mutex_; + + int num_received_ = 0; + int emit_stop_ = -1; + bool stopped_ = false; + + PushGenerator>::Producer producer_; +}; + +AsyncGenerator> MakeSinkNode(ExecNode* input, + std::string label) { + AsyncGenerator> out; + (void)input->plan()->EmplaceNode(input, std::move(label), &out); + return out; } } // namespace compute diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 0d2faea0ddc..21a757af5a1 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -24,6 +25,7 @@ #include "arrow/compute/type_fwd.h" #include "arrow/type_fwd.h" #include "arrow/util/macros.h" +#include "arrow/util/optional.h" #include "arrow/util/visibility.h" // NOTES: @@ -48,8 +50,11 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { ExecNode* AddNode(std::unique_ptr node); template - ExecNode* EmplaceNode(Args&&... args) { - return AddNode(std::unique_ptr(new Node{std::forward(args)...})); + Node* EmplaceNode(Args&&... args) { + auto node = std::unique_ptr(new Node{std::forward(args)...}); + auto out = node.get(); + AddNode(std::move(node)); + return out; } /// The initial inputs @@ -58,15 +63,6 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { /// The final outputs const NodeVector& sinks() const; - // XXX API question: - // There are clearly two phases in the ExecPlan lifecycle: - // - one construction phase where AddNode() and ExecNode::AddInput() is called - // (with optional validation at the end) - // - one execution phase where the nodes are topo-sorted and then started - // - // => Should we separate out those APIs? e.g. have a ExecPlanBuilder - // for the first phase. - Status Validate(); /// Start producing on all nodes @@ -75,7 +71,7 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { /// is started before all of its inputs. Status StartProducing(); - // XXX should we also have `void StopProducing()`? 
+ void StopProducing(); protected: ExecPlan() = default; @@ -84,32 +80,26 @@ class ARROW_EXPORT ExecPlan : public std::enable_shared_from_this { class ARROW_EXPORT ExecNode { public: using NodeVector = std::vector; - using BatchDescr = std::vector; virtual ~ExecNode() = default; virtual const char* kind_name() = 0; // The number of inputs/outputs expected by this node - int num_inputs() const { return static_cast(input_descrs_.size()); } + int num_inputs() const { return static_cast(inputs_.size()); } int num_outputs() const { return num_outputs_; } /// This node's predecessors in the exec plan const NodeVector& inputs() const { return inputs_; } - /// The datatypes accepted by this node for each input - const std::vector& input_descrs() const { return input_descrs_; } - /// \brief Labels identifying the function of each input. - /// - /// For example, FilterNode accepts "target" and "filter" inputs. const std::vector& input_labels() const { return input_labels_; } /// This node's successors in the exec plan const NodeVector& outputs() const { return outputs_; } /// The datatypes for batches produced by this node - const BatchDescr& output_descr() const { return output_descr_; } + const std::shared_ptr& output_schema() const { return output_schema_; } /// This node's exec plan ExecPlan* plan() { return plan_; } @@ -119,11 +109,6 @@ class ARROW_EXPORT ExecNode { /// There is no guarantee that this value is non-empty or unique. const std::string& label() const { return label_; } - void AddInput(ExecNode* input) { - inputs_.push_back(input); - input->outputs_.push_back(this); - } - Status Validate() const; /// Upstream API: @@ -139,7 +124,7 @@ class ARROW_EXPORT ExecNode { /// and StopProducing() /// Transfer input batch to ExecNode - virtual void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) = 0; + virtual void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) = 0; /// Signal error to ExecNode virtual void ErrorReceived(ExecNode* input, Status error) = 0; @@ -222,25 +207,62 @@ class ARROW_EXPORT ExecNode { virtual void StopProducing(ExecNode* output) = 0; /// \brief Stop producing definitively + /// + /// XXX maybe this should return a Future<>? virtual void StopProducing() = 0; protected: - ExecNode(ExecPlan* plan, std::string label, std::vector input_descrs, - std::vector input_labels, BatchDescr output_descr, + ExecNode(ExecPlan* plan, std::string label, NodeVector inputs, + std::vector input_labels, std::shared_ptr output_schema, int num_outputs); ExecPlan* plan_; - std::string label_; - std::vector input_descrs_; - std::vector input_labels_; NodeVector inputs_; + std::vector input_labels_; - BatchDescr output_descr_; + std::shared_ptr output_schema_; int num_outputs_; NodeVector outputs_; }; +/// \brief Adapt an AsyncGenerator as a source node +/// +/// TODO this should accept an Executor and explicitly handle batches +/// as they are generated on each of the Executor's threads. +ARROW_EXPORT +ExecNode* MakeSourceNode(ExecPlan*, std::string label, + std::shared_ptr output_schema, + std::function>()>); + +/// \brief Add a sink node which forwards to an AsyncGenerator +/// +/// Emitted batches will not be ordered; instead they will be tagged with the `seq` at +/// which they were received. 
+ARROW_EXPORT +std::function>()> MakeSinkNode(ExecNode* input, + std::string label); + +/// \brief Make a node which excludes some rows from batches passed through it +/// +/// The filter Expression will be evaluated against each batch which is pushed to +/// this node. Any rows for which the filter does not evaluate to `true` will be excluded +/// in the batch emitted by this node. +/// +/// If the filter is not already bound, it will be bound against the input's schema. +ARROW_EXPORT +Result MakeFilterNode(ExecNode* input, std::string label, Expression filter); + +/// \brief Make a node which executes expressions on input batches, producing new batches. +/// +/// Each expression will be evaluated against each batch which is pushed to +/// this node to produce a corresponding output column. +/// +/// If exprs are not already bound, they will be bound against the input's schema. +ARROW_EXPORT +Result MakeProjectNode(ExecNode* input, std::string label, + std::vector exprs); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index aeabbf7bc5b..022584d5b39 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -29,6 +29,7 @@ #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/util/atomic_shared_ptr.h" +#include "arrow/util/hash_util.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" @@ -42,7 +43,13 @@ using internal::checked_pointer_cast; namespace compute { -Expression::Expression(Call call) : impl_(std::make_shared(std::move(call))) {} +Expression::Expression(Call call) { + call.hash = std::hash{}(call.function_name); + for (const auto& arg : call.arguments) { + arrow::internal::hash_combine(call.hash, arg.hash()); + } + impl_ = std::make_shared(std::move(call)); +} Expression::Expression(Datum literal) : impl_(std::make_shared(std::move(literal))) {} @@ -53,7 +60,7 @@ Expression::Expression(Parameter parameter) Expression literal(Datum lit) { return Expression(std::move(lit)); } Expression field_ref(FieldRef ref) { - return Expression(Expression::Parameter{std::move(ref), {}}); + return Expression(Expression::Parameter{std::move(ref), ValueDescr{}, -1}); } Expression call(std::string function, std::vector arguments, @@ -67,8 +74,12 @@ Expression call(std::string function, std::vector arguments, const Datum* Expression::literal() const { return util::get_if(impl_.get()); } +const Expression::Parameter* Expression::parameter() const { + return util::get_if(impl_.get()); +} + const FieldRef* Expression::field_ref() const { - if (auto parameter = util::get_if(impl_.get())) { + if (auto parameter = this->parameter()) { return ¶meter->ref; } return nullptr; @@ -85,7 +96,7 @@ ValueDescr Expression::descr() const { return lit->descr(); } - if (auto parameter = util::get_if(impl_.get())) { + if (auto parameter = this->parameter()) { return parameter->descr; } @@ -235,21 +246,7 @@ size_t Expression::hash() const { return ref->hash(); } - auto call = CallNotNull(*this); - if (call->hash != nullptr) { - return call->hash->load(); - } - - size_t out = std::hash{}(call->function_name); - for (const auto& arg : call->arguments) { - out ^= arg.hash(); - } - - std::shared_ptr> expected = nullptr; - ::arrow::internal::atomic_compare_exchange_strong( - &const_cast(call)->hash, &expected, - std::make_shared>(out)); - return out; + return CallNotNull(*this)->hash; } bool 
Expression::IsBound() const { @@ -383,76 +380,113 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ return Expression(std::move(call)); } -struct FieldPathGetDatumImpl { - template ()))> - Result operator()(const std::shared_ptr& ptr) { - return path_.Get(*ptr).template As(); - } - - template - Result operator()(const T&) { - return Status::NotImplemented("FieldPath::Get() into Datum ", datum_.ToString()); +template +Result BindImpl(Expression expr, const TypeOrSchema& in, + ValueDescr::Shape shape, compute::ExecContext* exec_context) { + if (exec_context == nullptr) { + compute::ExecContext exec_context; + return BindImpl(std::move(expr), in, shape, &exec_context); } - const Datum& datum_; - const FieldPath& path_; -}; + if (expr.literal()) return expr; -inline Result GetDatumField(const FieldRef& ref, const Datum& input) { - Datum field; + if (auto ref = expr.field_ref()) { + if (ref->IsNested()) { + return Status::NotImplemented("nested field references"); + } - FieldPath match; - if (auto type = input.type()) { - ARROW_ASSIGN_OR_RAISE(match, ref.FindOneOrNone(*type)); - } else if (auto schema = input.schema()) { - ARROW_ASSIGN_OR_RAISE(match, ref.FindOneOrNone(*schema)); - } else { - return Status::NotImplemented("retrieving fields from datum ", input.ToString()); - } + ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(in)); - if (!match.empty()) { - ARROW_ASSIGN_OR_RAISE(field, - util::visit(FieldPathGetDatumImpl{input, match}, input.value)); + auto bound = *expr.parameter(); + bound.index = path[0]; + ARROW_ASSIGN_OR_RAISE(auto field, path.Get(in)); + bound.descr.type = field->type(); + bound.descr.shape = shape; + return Expression{std::move(bound)}; } - if (field == Datum{}) { - return Datum(std::make_shared()); + auto call = *CallNotNull(expr); + for (auto& argument : call.arguments) { + ARROW_ASSIGN_OR_RAISE(argument, + BindImpl(std::move(argument), in, shape, exec_context)); } - - return field; + return BindNonRecursive(std::move(call), + /*insert_implicit_casts=*/true, exec_context); } } // namespace -Result Expression::Bind(ValueDescr in, +Result Expression::Bind(const ValueDescr& in, compute::ExecContext* exec_context) const { - if (exec_context == nullptr) { - compute::ExecContext exec_context; - return Bind(std::move(in), &exec_context); - } + return BindImpl(*this, *in.type, in.shape, exec_context); +} - if (literal()) return *this; +Result Expression::Bind(const Schema& in_schema, + compute::ExecContext* exec_context) const { + return BindImpl(*this, in_schema, ValueDescr::ARRAY, exec_context); +} - if (auto ref = field_ref()) { - ARROW_ASSIGN_OR_RAISE(auto field, ref->GetOneOrNone(*in.type)); - auto descr = field ? ValueDescr{field->type(), in.shape} : ValueDescr::Scalar(null()); - return Expression{Parameter{*ref, std::move(descr)}}; +Result MakeExecBatch(const Schema& full_schema, const Datum& partial) { + ExecBatch out; + + if (partial.kind() == Datum::RECORD_BATCH) { + const auto& partial_batch = *partial.record_batch(); + out.length = partial_batch.num_rows(); + + for (const auto& field : full_schema.fields()) { + ARROW_ASSIGN_OR_RAISE(auto column, + FieldRef(field->name()).GetOneOrNone(partial_batch)); + + if (column) { + if (!column->type()->Equals(field->type())) { + // Referenced field was present but didn't have the expected type. + // This *should* be handled by readers, and will just be an error in the future. 
+ ARROW_ASSIGN_OR_RAISE( + auto converted, + compute::Cast(column, field->type(), compute::CastOptions::Safe())); + column = converted.make_array(); + } + out.values.emplace_back(std::move(column)); + } else { + out.values.emplace_back(MakeNullScalar(field->type())); + } + } + return out; } - auto call = *CallNotNull(*this); - for (auto& argument : call.arguments) { - ARROW_ASSIGN_OR_RAISE(argument, argument.Bind(in, exec_context)); + // wasteful but useful for testing: + if (partial.type()->id() == Type::STRUCT) { + if (partial.is_array()) { + ARROW_ASSIGN_OR_RAISE(auto partial_batch, + RecordBatch::FromStructArray(partial.make_array())); + + return MakeExecBatch(full_schema, partial_batch); + } + + if (partial.is_scalar()) { + ARROW_ASSIGN_OR_RAISE(auto partial_array, + MakeArrayFromScalar(*partial.scalar(), 1)); + ARROW_ASSIGN_OR_RAISE(auto out, MakeExecBatch(full_schema, partial_array)); + + for (Datum& value : out.values) { + if (value.is_scalar()) continue; + ARROW_ASSIGN_OR_RAISE(value, value.make_array()->GetScalar(0)); + } + return out; + } } - return BindNonRecursive(std::move(call), - /*insert_implicit_casts=*/true, exec_context); + + return Status::NotImplemented("MakeExecBatch from ", PrintDatum(partial)); } -Result Expression::Bind(const Schema& in_schema, - compute::ExecContext* exec_context) const { - return Bind(ValueDescr::Array(struct_(in_schema.fields())), exec_context); +Result ExecuteScalarExpression(const Expression& expr, const Schema& full_schema, + const Datum& partial_input, + compute::ExecContext* exec_context) { + ARROW_ASSIGN_OR_RAISE(auto input, MakeExecBatch(full_schema, partial_input)); + return ExecuteScalarExpression(expr, input, exec_context); } -Result ExecuteScalarExpression(const Expression& expr, const Datum& input, +Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& input, compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; @@ -470,15 +504,16 @@ Result ExecuteScalarExpression(const Expression& expr, const Datum& input if (auto lit = expr.literal()) return *lit; - if (auto ref = expr.field_ref()) { - ARROW_ASSIGN_OR_RAISE(Datum field, GetDatumField(*ref, input)); + if (auto param = expr.parameter()) { + if (param->descr.type->id() == Type::NA) { + return MakeNullScalar(null()); + } - if (field.descr() != expr.descr()) { - // Refernced field was present but didn't have the expected type. - // Should we just error here? For now, pay dispatch cost and just cast. 
- ARROW_ASSIGN_OR_RAISE( - field, - compute::Cast(field, expr.type(), compute::CastOptions::Safe(), exec_context)); + const Datum& field = input[param->index]; + if (!field.type()->Equals(param->descr.type)) { + return Status::Invalid("Referenced field ", expr.ToString(), " was ", + field.type()->ToString(), " but should have been ", + param->descr.type->ToString()); } return field; @@ -574,7 +609,7 @@ Result FoldConstants(Expression expr) { if (std::all_of(call->arguments.begin(), call->arguments.end(), [](const Expression& argument) { return argument.literal(); })) { // all arguments are literal; we can evaluate this subexpression *now* - static const Datum ignored_input = Datum{}; + static const ExecBatch ignored_input = ExecBatch{}; ARROW_ASSIGN_OR_RAISE(Datum constant, ExecuteScalarExpression(expr, ignored_input)); @@ -683,17 +718,16 @@ Status ExtractKnownFieldValuesImpl( } // namespace -Result> ExtractKnownFieldValues( +Result ExtractKnownFieldValues( const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); - std::unordered_map known_values; - RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); + KnownFieldValues known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map)); return known_values; } -Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, - Expression expr) { +Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values, + Expression expr) { if (!expr.IsBound()) { return Status::Invalid( "ReplaceFieldsWithKnownValues called on an unbound Expression"); @@ -703,8 +737,8 @@ Result ReplaceFieldsWithKnownValues( std::move(expr), [&known_values](Expression expr) -> Result { if (auto ref = expr.field_ref()) { - auto it = known_values.find(*ref); - if (it != known_values.end()) { + auto it = known_values.map.find(*ref); + if (it != known_values.map.end()) { Datum lit = it->second; if (lit.descr() == expr.descr()) return literal(std::move(lit)); // type mismatch, try casting the known value to the correct type @@ -906,8 +940,8 @@ Result SimplifyWithGuarantee(Expression expr, const Expression& guaranteed_true_predicate) { auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); - std::unordered_map known_values; - RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values)); + KnownFieldValues known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map)); ARROW_ASSIGN_OR_RAISE(expr, ReplaceFieldsWithKnownValues(known_values, std::move(expr))); @@ -1144,13 +1178,5 @@ Expression or_(const std::vector& operands) { Expression not_(Expression operand) { return call("invert", {std::move(operand)}); } -Expression operator&&(Expression lhs, Expression rhs) { - return and_(std::move(lhs), std::move(rhs)); -} - -Expression operator||(Expression lhs, Expression rhs) { - return or_(std::move(lhs), std::move(rhs)); -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index f5ca2c2118d..d06a923bb32 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -19,10 +19,8 @@ #pragma once -#include #include #include -#include #include #include @@ -44,13 +42,13 @@ class ARROW_EXPORT Expression { struct Call { std::string function_name; std::vector arguments; - std::shared_ptr options; - std::shared_ptr> hash; + 
std::shared_ptr options; + size_t hash; // post-Bind properties: - std::shared_ptr function; - const compute::Kernel* kernel = NULLPTR; - std::shared_ptr kernel_state; + std::shared_ptr function; + const Kernel* kernel = NULLPTR; + std::shared_ptr kernel_state; ValueDescr descr; }; @@ -64,8 +62,8 @@ class ARROW_EXPORT Expression { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - Result Bind(ValueDescr in, compute::ExecContext* = NULLPTR) const; - Result Bind(const Schema& in_schema, compute::ExecContext* = NULLPTR) const; + Result Bind(const ValueDescr& in, ExecContext* = NULLPTR) const; + Result Bind(const Schema& in_schema, ExecContext* = NULLPTR) const; // XXX someday // Clone all KernelState in this bound expression. If any function referenced by this @@ -108,8 +106,12 @@ class ARROW_EXPORT Expression { struct Parameter { FieldRef ref; + + // post-bind properties ValueDescr descr; + int index; }; + const Parameter* parameter() const; Expression() = default; explicit Expression(Call call); @@ -143,10 +145,10 @@ Expression field_ref(FieldRef ref); ARROW_EXPORT Expression call(std::string function, std::vector arguments, - std::shared_ptr options = NULLPTR); + std::shared_ptr options = NULLPTR); -template ::value>::type> +template ::value>::type> Expression call(std::string function, std::vector arguments, Options options) { return call(std::move(function), std::move(arguments), @@ -162,8 +164,9 @@ ARROW_EXPORT bool ExpressionHasFieldRefs(const Expression&); /// Assemble a mapping from field references to known values. +struct ARROW_EXPORT KnownFieldValues; ARROW_EXPORT -Result> ExtractKnownFieldValues( +Result ExtractKnownFieldValues( const Expression& guaranteed_true_predicate); /// \defgroup expression-passes Functions for modification of Expressions @@ -182,7 +185,7 @@ Result> ExtractKnownFieldVal /// equivalent Expressions may result in different canonicalized expressions. /// TODO this could be a strong canonicalization ARROW_EXPORT -Result Canonicalize(Expression, compute::ExecContext* = NULLPTR); +Result Canonicalize(Expression, ExecContext* = NULLPTR); /// Simplify Expressions based on literal arguments (for example, add(null, x) will always /// be null so replace the call with a null literal). Includes early evaluation of all @@ -192,8 +195,8 @@ Result FoldConstants(Expression); /// Simplify Expressions by replacing with known values of the fields which it references. ARROW_EXPORT -Result ReplaceFieldsWithKnownValues( - const std::unordered_map& known_values, Expression); +Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_values, + Expression); /// Simplify an expression by replacing subexpressions based on a guarantee: /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is @@ -207,11 +210,22 @@ Result SimplifyWithGuarantee(Expression, // Execution -/// Execute a scalar expression against the provided state and input Datum. This +/// Create an ExecBatch suitable for passing to ExecuteScalarExpression() from a +/// RecordBatch which may have missing or incorrectly ordered columns. +/// Missing fields will be replaced with null scalars. +ARROW_EXPORT Result MakeExecBatch(const Schema& full_schema, + const Datum& partial); + +/// Execute a scalar expression against the provided state and input ExecBatch. 
This /// expression must be bound. ARROW_EXPORT -Result ExecuteScalarExpression(const Expression&, const Datum& input, - compute::ExecContext* = NULLPTR); +Result ExecuteScalarExpression(const Expression&, const ExecBatch& input, + ExecContext* = NULLPTR); + +/// Convenience function for invoking against a RecordBatch +ARROW_EXPORT +Result ExecuteScalarExpression(const Expression&, const Schema& full_schema, + const Datum& partial_input, ExecContext* = NULLPTR); // Serialization diff --git a/cpp/src/arrow/compute/exec/expression_internal.h b/cpp/src/arrow/compute/exec/expression_internal.h index b9165a5f0c2..51d242e8d66 100644 --- a/cpp/src/arrow/compute/exec/expression_internal.h +++ b/cpp/src/arrow/compute/exec/expression_internal.h @@ -34,6 +34,10 @@ using internal::checked_cast; namespace compute { +struct KnownFieldValues { + std::unordered_map map; +}; + inline const Expression::Call* CallNotNull(const Expression& expr) { auto call = expr.call(); DCHECK_NE(call, nullptr); diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 908e8962e43..86909f4eb64 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -166,6 +166,56 @@ TEST(ExpressionUtils, StripOrderPreservingCasts) { Expect(cast(field_ref("i32"), uint64()), no_change); } +TEST(ExpressionUtils, MakeExecBatch) { + auto Expect = [](std::shared_ptr partial_batch) { + SCOPED_TRACE(partial_batch->ToString()); + ASSERT_OK_AND_ASSIGN(auto batch, MakeExecBatch(*kBoringSchema, partial_batch)); + + ASSERT_EQ(batch.num_values(), kBoringSchema->num_fields()); + for (int i = 0; i < kBoringSchema->num_fields(); ++i) { + const auto& field = *kBoringSchema->field(i); + + SCOPED_TRACE("Field#" + std::to_string(i) + " " + field.ToString()); + + EXPECT_TRUE(batch[i].type()->Equals(field.type())) + << "Incorrect type " << batch[i].type()->ToString(); + + ASSERT_OK_AND_ASSIGN(auto col, FieldRef(field.name()).GetOneOrNone(*partial_batch)); + + if (batch[i].is_scalar()) { + EXPECT_FALSE(batch[i].scalar()->is_valid) + << "Non-null placeholder scalar was injected"; + + EXPECT_EQ(col, nullptr) + << "Placeholder scalar overwrote column " << col->ToString(); + } else { + AssertDatumsEqual(col, batch[i]); + } + } + }; + + auto GetField = [](std::string name) { return kBoringSchema->GetFieldByName(name); }; + + constexpr int64_t kNumRows = 3; + auto i32 = ArrayFromJSON(int32(), "[1, 2, 3]"); + auto f32 = ArrayFromJSON(float32(), "[1.5, 2.25, 3.125]"); + + // empty + Expect(RecordBatchFromJSON(kBoringSchema, "[]")); + + // subset + Expect(RecordBatch::Make(schema({GetField("i32"), GetField("f32")}), kNumRows, + {i32, f32})); + + // flipped subset + Expect(RecordBatch::Make(schema({GetField("f32"), GetField("i32")}), kNumRows, + {f32, i32})); + + auto duplicated_names = + RecordBatch::Make(schema({GetField("i32"), GetField("i32")}), kNumRows, {i32, i32}); + ASSERT_RAISES(Invalid, MakeExecBatch(*kBoringSchema, duplicated_names)); +} + TEST(Expression, ToString) { EXPECT_EQ(field_ref("alpha").ToString(), "alpha"); @@ -445,21 +495,18 @@ TEST(Expression, BindFieldRef) { ExpectBindsTo(field_ref("i32"), no_change, &expr); EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); - // if the field is not found, a null scalar will be emitted - ExpectBindsTo(field_ref("no such field"), no_change, &expr); - EXPECT_EQ(expr.descr(), ValueDescr::Scalar(null())); + // if the field is not found, an error will be raised + ASSERT_RAISES(Invalid, field_ref("no such 
field").Bind(*kBoringSchema)); // referencing a field by name is not supported if that name is not unique // in the input schema ASSERT_RAISES(Invalid, field_ref("alpha").Bind(Schema( {field("alpha", int32()), field("alpha", float32())}))); - // referencing nested fields is supported - ASSERT_OK_AND_ASSIGN(expr, - field_ref(FieldRef("a", "b")) - .Bind(Schema({field("a", struct_({field("b", int32())}))}))); - EXPECT_TRUE(expr.IsBound()); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + // referencing nested fields is not supported + ASSERT_RAISES(NotImplemented, + field_ref(FieldRef("a", "b")) + .Bind(Schema({field("a", struct_({field("b", int32())}))}))); } TEST(Expression, BindCall) { @@ -525,7 +572,8 @@ TEST(Expression, ExecuteFieldRef) { auto expr = field_ref(ref); ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); - ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); + ASSERT_OK_AND_ASSIGN(Datum actual, + ExecuteScalarExpression(expr, Schema(in.type()->fields()), in)); AssertDatumsEqual(actual, expected, /*verbose=*/true); }; @@ -537,39 +585,45 @@ TEST(Expression, ExecuteFieldRef) { ])"), ArrayFromJSON(float64(), R"([6.125, 0.0, -1])")); - // more nested: - ExpectRefIs(FieldRef{"a", "a"}, - ArrayFromJSON(struct_({field("a", struct_({field("a", float64())}))}), R"([ - {"a": {"a": 6.125}}, - {"a": {"a": 0.0}}, - {"a": {"a": -1}} + ExpectRefIs("a", + ArrayFromJSON(struct_({ + field("a", float64()), + field("b", float64()), + }), + R"([ + {"a": 6.125, "b": 7.5}, + {"a": 0.0, "b": 2.125}, + {"a": -1, "b": 4.0} ])"), ArrayFromJSON(float64(), R"([6.125, 0.0, -1])")); - // absent fields are resolved as a null scalar: - ExpectRefIs(FieldRef{"b"}, ArrayFromJSON(struct_({field("a", float64())}), R"([ - {"a": 6.125}, - {"a": 0.0}, - {"a": -1} + ExpectRefIs("b", + ArrayFromJSON(struct_({ + field("a", float64()), + field("b", float64()), + }), + R"([ + {"a": 6.125, "b": 7.5}, + {"a": 0.0, "b": 2.125}, + {"a": -1, "b": 4.0} ])"), - MakeNullScalar(null())); - - // XXX this *should* fail in Bind but for now it will just error in - // ExecuteScalarExpression - ASSERT_OK_AND_ASSIGN(auto list_item, field_ref("item").Bind(list(int32()))); - EXPECT_RAISES_WITH_MESSAGE_THAT( - NotImplemented, HasSubstr("non-struct array"), - ExecuteScalarExpression(list_item, - ArrayFromJSON(list(int32()), "[[1,2], [], null, [5]]"))); + ArrayFromJSON(float64(), R"([7.5, 2.125, 4.0])")); } Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& input) { - auto call = expr.call(); - if (call == nullptr) { - // already tested execution of field_ref, execution of literal is trivial - return ExecuteScalarExpression(expr, input); + if (auto lit = expr.literal()) { + return *lit; } + if (auto ref = expr.field_ref()) { + if (input.type()) { + return ref->GetOneOrNone(*input.make_array()); + } + return ref->GetOneOrNone(*input.record_batch()); + } + + auto call = CallNotNull(expr); + std::vector arguments(call->arguments.size()); for (size_t i = 0; i < arguments.size(); ++i) { ARROW_ASSIGN_OR_RAISE(arguments[i], @@ -587,13 +641,16 @@ Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& } void ExpectExecute(Expression expr, Datum in, Datum* actual_out = NULLPTR) { + std::shared_ptr schm; if (in.is_value()) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + schm = schema(in.type()->fields()); } else { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*in.schema())); + schm = in.schema(); } - ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, in)); + 
ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, *schm, in)); ASSERT_OK_AND_ASSIGN(Datum expected, NaiveExecuteScalarExpression(expr, in)); @@ -653,9 +710,9 @@ TEST(Expression, ExecuteDictionaryTransparent) { ASSERT_OK_AND_ASSIGN( expr, SimplifyWithGuarantee(expr, equal(field_ref("dict_str"), literal("eh")))); - ASSERT_OK_AND_ASSIGN( - auto res, - ExecuteScalarExpression(expr, ArrayFromJSON(struct_({field("i32", int32())}), R"([ + ASSERT_OK_AND_ASSIGN(auto res, ExecuteScalarExpression( + expr, *kBoringSchema, + ArrayFromJSON(struct_({field("i32", int32())}), R"([ {"i32": 0}, {"i32": 1}, {"i32": 2} @@ -773,7 +830,7 @@ TEST(Expression, ExtractKnownFieldValues) { void operator()(Expression guarantee, std::unordered_map expected) { ASSERT_OK_AND_ASSIGN(auto actual, ExtractKnownFieldValues(guarantee)); - EXPECT_THAT(actual, UnorderedElementsAreArray(expected)) + EXPECT_THAT(actual.map, UnorderedElementsAreArray(expected)) << " guarantee: " << guarantee.ToString(); } } ExpectKnown; @@ -825,8 +882,8 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto replaced, - ReplaceFieldsWithKnownValues(known_values, expr)); + ASSERT_OK_AND_ASSIGN(auto replaced, ReplaceFieldsWithKnownValues( + KnownFieldValues{known_values}, expr)); EXPECT_EQ(replaced, expected); ExpectIdenticalIfUnchanged(replaced, expr); @@ -841,7 +898,7 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { // NB: known_values will be cast ExpectReplacesTo(field_ref("i32"), {{"i32", Datum("3")}}, literal(3)); - ExpectReplacesTo(field_ref("b"), i32_is_3, field_ref("b")); + ExpectReplacesTo(field_ref("f32"), i32_is_3, field_ref("f32")); ExpectReplacesTo(equal(field_ref("i32"), literal(1)), i32_is_3, equal(literal(3), literal(1))); @@ -886,13 +943,13 @@ TEST(Expression, ReplaceFieldsWithKnownValues) { Datum dict_i32{ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(int32(), R"([3])"))}; // Unsupported cast dictionary(int32(), int32()) -> dictionary(int32(), utf8()) - ASSERT_RAISES(NotImplemented, - ReplaceFieldsWithKnownValues({{"dict_str", dict_i32}}, expr)); + ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues( + KnownFieldValues{{{"dict_str", dict_i32}}}, expr)); // Unsupported cast dictionary(int8(), utf8()) -> dictionary(int32(), utf8()) dict_str = Datum{ DictionaryScalar::Make(MakeScalar(0), ArrayFromJSON(utf8(), R"(["a"])"))}; - ASSERT_RAISES(NotImplemented, - ReplaceFieldsWithKnownValues({{"dict_str", dict_str}}, expr)); + ASSERT_RAISES(NotImplemented, ReplaceFieldsWithKnownValues( + KnownFieldValues{{{"dict_str", dict_str}}}, expr)); } struct { @@ -1082,7 +1139,8 @@ TEST(Expression, SingleComparisonGuarantees) { {"i32"})); ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(Datum evaluated, ExecuteScalarExpression(filter, input)); + ASSERT_OK_AND_ASSIGN(Datum evaluated, + ExecuteScalarExpression(filter, *kBoringSchema, input)); // ensure that the simplified filter is as simplified as it could be // (this is always possible for single comparisons) @@ -1193,7 +1251,8 @@ TEST(Expression, Filter) { auto expected_mask = batch->column(0); ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(Datum mask, ExecuteScalarExpression(filter, batch)); + ASSERT_OK_AND_ASSIGN(Datum mask, + ExecuteScalarExpression(filter, *kBoringSchema, batch)); AssertDatumsEqual(expected_mask, 
mask); }; @@ -1286,7 +1345,8 @@ TEST(Projection, AugmentWithNull) { auto ExpectProject = [&](Expression proj, Datum expected) { ASSERT_OK_AND_ASSIGN(proj, proj.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto actual, ExecuteScalarExpression(proj, input)); + ASSERT_OK_AND_ASSIGN(auto actual, + ExecuteScalarExpression(proj, *kBoringSchema, input)); AssertDatumsEqual(Datum(expected), actual); }; @@ -1316,7 +1376,8 @@ TEST(Projection, AugmentWithKnownValues) { Expression guarantee) { ASSERT_OK_AND_ASSIGN(proj, proj.Bind(*kBoringSchema)); ASSERT_OK_AND_ASSIGN(proj, SimplifyWithGuarantee(proj, guarantee)); - ASSERT_OK_AND_ASSIGN(auto actual, ExecuteScalarExpression(proj, input)); + ASSERT_OK_AND_ASSIGN(auto actual, + ExecuteScalarExpression(proj, *kBoringSchema, input)); AssertDatumsEqual(Datum(expected), actual); }; diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 86f1879cbe9..75b71f97535 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -15,34 +15,33 @@ // specific language governing permissions and limitations // under the License. -#include - #include #include +#include + +#include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/test_util.h" #include "arrow/record_batch.h" #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/testing/random.h" +#include "arrow/util/async_generator.h" #include "arrow/util/logging.h" #include "arrow/util/thread_pool.h" +#include "arrow/util/vector.h" -namespace arrow { +using testing::ElementsAre; +using testing::HasSubstr; +using testing::UnorderedElementsAreArray; -using internal::Executor; +namespace arrow { namespace compute { -void AssertBatchesEqual(const RecordBatchVector& expected, - const RecordBatchVector& actual) { - ASSERT_EQ(expected.size(), actual.size()); - for (size_t i = 0; i < expected.size(); ++i) { - AssertBatchesEqual(*expected[i], *actual[i]); - } -} - TEST(ExecPlanConstruction, Empty) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); @@ -51,69 +50,49 @@ TEST(ExecPlanConstruction, Empty) { TEST(ExecPlanConstruction, SingleNode) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/0); + auto node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/0); ASSERT_OK(plan->Validate()); - ASSERT_THAT(plan->sources(), ::testing::ElementsAre(node)); - ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(node)); - - ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/0); - // Input not bound - ASSERT_RAISES(Invalid, plan->Validate()); + ASSERT_THAT(plan->sources(), ElementsAre(node)); + ASSERT_THAT(plan->sinks(), ElementsAre(node)); ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make()); - node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/0, /*num_outputs=*/1); + node = MakeDummyNode(plan.get(), "dummy", /*inputs=*/{}, /*num_outputs=*/1); // Output not bound ASSERT_RAISES(Invalid, plan->Validate()); } TEST(ExecPlanConstruction, SourceSink) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1); - auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0); - // Input / output not bound - ASSERT_RAISES(Invalid, 
plan->Validate());
+  auto source = MakeDummyNode(plan.get(), "source", /*inputs=*/{}, /*num_outputs=*/1);
+  auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{source}, /*num_outputs=*/0);
 
-  sink->AddInput(source);
   ASSERT_OK(plan->Validate());
-  EXPECT_THAT(plan->sources(), ::testing::ElementsAre(source));
-  EXPECT_THAT(plan->sinks(), ::testing::ElementsAre(sink));
+  EXPECT_THAT(plan->sources(), ElementsAre(source));
+  EXPECT_THAT(plan->sinks(), ElementsAre(sink));
 }
 
 TEST(ExecPlanConstruction, MultipleNode) {
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
 
-  auto source1 =
-      MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2);
+  auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2);
 
-  auto source2 =
-      MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1);
+  auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1);
 
   auto process1 =
-      MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2);
+      MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2);
 
-  auto process2 =
-      MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/1);
+  auto process2 = MakeDummyNode(plan.get(), "process2", /*inputs=*/{source1, source2},
+                                /*num_outputs=*/1);
 
   auto process3 =
-      MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1);
-
-  auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0);
-
-  sink->AddInput(process3);
-
-  process3->AddInput(process1);
-  process3->AddInput(process2);
-  process3->AddInput(process1);
-
-  process2->AddInput(source1);
-  process2->AddInput(source2);
+      MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, process2, process1},
+                    /*num_outputs=*/1);
 
-  process1->AddInput(source1);
+  auto sink = MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0);
 
   ASSERT_OK(plan->Validate());
-  ASSERT_THAT(plan->sources(), ::testing::ElementsAre(source1, source2));
-  ASSERT_THAT(plan->sinks(), ::testing::ElementsAre(sink));
+  ASSERT_THAT(plan->sources(), ElementsAre(source1, source2));
+  ASSERT_THAT(plan->sinks(), ElementsAre(sink));
 }
 
 struct StartStopTracker {
@@ -135,30 +114,27 @@ TEST(ExecPlan, DummyStartProducing) {
   StartStopTracker t;
 
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2,
+
+  auto source1 = MakeDummyNode(plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2,
                                t.start_producing_func(), t.stop_producing_func());
-  auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1,
+
+  auto source2 = MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1,
                                t.start_producing_func(), t.stop_producing_func());
+
   auto process1 =
-      MakeDummyNode(plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2,
+      MakeDummyNode(plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2,
                     t.start_producing_func(), t.stop_producing_func());
+
   auto process2 =
-      MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1,
-                    t.start_producing_func(), t.stop_producing_func());
-  auto process3 =
-      MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1,
-                    t.start_producing_func(), t.stop_producing_func());
+      MakeDummyNode(plan.get(), "process2", /*inputs=*/{process1, source2},
+                    /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func());
 
-  auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0,
-                            t.start_producing_func(), t.stop_producing_func());
+  auto process3 =
+      MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, source1, process2},
+                    /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func());
 
-  process1->AddInput(source1);
-  process2->AddInput(process1);
-  process2->AddInput(source2);
-  process3->AddInput(process1);
-  process3->AddInput(source1);
-  process3->AddInput(process2);
-  sink->AddInput(process3);
+  MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0,
+                t.start_producing_func(), t.stop_producing_func());
 
   ASSERT_OK(plan->Validate());
   ASSERT_EQ(t.started.size(), 0);
@@ -166,68 +142,37 @@ TEST(ExecPlan, DummyStartProducing) {
   ASSERT_OK(plan->StartProducing());
 
   // Note that any correct reverse topological order may do
-  ASSERT_THAT(t.started, ::testing::ElementsAre("sink", "process3", "process2",
-                                                "process1", "source2", "source1"));
+  ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1",
+                                     "source2", "source1"));
 
   ASSERT_EQ(t.stopped.size(), 0);
 }
 
-TEST(ExecPlan, DummyStartProducingCycle) {
-  // A trivial cycle
-  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  auto node = MakeDummyNode(plan.get(), "dummy", /*num_inputs=*/1, /*num_outputs=*/1);
-  node->AddInput(node);
-  ASSERT_OK(plan->Validate());
-  ASSERT_RAISES(Invalid, plan->StartProducing());
-
-  // A less trivial one
-  ASSERT_OK_AND_ASSIGN(plan, ExecPlan::Make());
-  auto source = MakeDummyNode(plan.get(), "source", /*num_inputs=*/0, /*num_outputs=*/1);
-  auto process1 =
-      MakeDummyNode(plan.get(), "process1", /*num_inputs=*/2, /*num_outputs=*/2);
-  auto process2 =
-      MakeDummyNode(plan.get(), "process2", /*num_inputs=*/1, /*num_outputs=*/1);
-  auto process3 =
-      MakeDummyNode(plan.get(), "process3", /*num_inputs=*/2, /*num_outputs=*/2);
-  auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0);
-
-  process1->AddInput(source);
-  process2->AddInput(process1);
-  process3->AddInput(process2);
-  process3->AddInput(process1);
-  process1->AddInput(process3);
-  sink->AddInput(process3);
-
-  ASSERT_OK(plan->Validate());
-  ASSERT_RAISES(Invalid, plan->StartProducing());
-}
-
 TEST(ExecPlan, DummyStartProducingError) {
   StartStopTracker t;
 
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  auto source1 = MakeDummyNode(plan.get(), "source1", /*num_inputs=*/0, /*num_outputs=*/2,
-                               t.start_producing_func(Status::NotImplemented("zzz")),
-                               t.stop_producing_func());
-  auto source2 = MakeDummyNode(plan.get(), "source2", /*num_inputs=*/0, /*num_outputs=*/1,
-                               t.start_producing_func(), t.stop_producing_func());
+  auto source1 = MakeDummyNode(
+      plan.get(), "source1", /*inputs=*/{}, /*num_outputs=*/2,
+      t.start_producing_func(Status::NotImplemented("zzz")), t.stop_producing_func());
+
+  auto source2 =
+      MakeDummyNode(plan.get(), "source2", /*inputs=*/{}, /*num_outputs=*/1,
                    t.start_producing_func(), t.stop_producing_func());
+
   auto process1 = MakeDummyNode(
-      plan.get(), "process1", /*num_inputs=*/1, /*num_outputs=*/2,
+      plan.get(), "process1", /*inputs=*/{source1}, /*num_outputs=*/2,
       t.start_producing_func(Status::IOError("xxx")), t.stop_producing_func());
+
   auto process2 =
-      MakeDummyNode(plan.get(), "process2", /*num_inputs=*/2, /*num_outputs=*/1,
-                    t.start_producing_func(), t.stop_producing_func());
-  process1->AddInput(source1);
-  process2->AddInput(process1);
-  process2->AddInput(source2);
+      MakeDummyNode(plan.get(), "process2", /*inputs=*/{process1, source2},
+                    /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func());
+
   auto process3 =
-      MakeDummyNode(plan.get(), "process3", /*num_inputs=*/3, /*num_outputs=*/1,
-                    t.start_producing_func(), t.stop_producing_func());
-  process3->AddInput(process1);
-  process3->AddInput(source1);
-  process3->AddInput(process2);
-  auto sink = MakeDummyNode(plan.get(), "sink", /*num_inputs=*/1, /*num_outputs=*/0,
-                            t.start_producing_func(), t.stop_producing_func());
-  sink->AddInput(process3);
+      MakeDummyNode(plan.get(), "process3", /*inputs=*/{process1, source1, process2},
+                    /*num_outputs=*/1, t.start_producing_func(), t.stop_producing_func());
+
+  MakeDummyNode(plan.get(), "sink", /*inputs=*/{process3}, /*num_outputs=*/0,
+                t.start_producing_func(), t.stop_producing_func());
 
   ASSERT_OK(plan->Validate());
   ASSERT_EQ(t.started.size(), 0);
@@ -235,165 +180,206 @@ TEST(ExecPlan, DummyStartProducingError) {
   // `process1` raises IOError
   ASSERT_RAISES(IOError, plan->StartProducing());
-  ASSERT_THAT(t.started,
-              ::testing::ElementsAre("sink", "process3", "process2", "process1"));
+  ASSERT_THAT(t.started, ElementsAre("sink", "process3", "process2", "process1"));
 
   // Nodes that started successfully were stopped in reverse order
-  ASSERT_THAT(t.stopped, ::testing::ElementsAre("process2", "process3", "sink"));
+  ASSERT_THAT(t.stopped, ElementsAre("process2", "process3", "sink"));
 }
 
-// TODO move this to gtest_util.h?
+namespace {
 
-class SlowRecordBatchReader : public RecordBatchReader {
- public:
-  explicit SlowRecordBatchReader(std::shared_ptr<RecordBatchReader> reader)
-      : reader_(std::move(reader)) {}
+struct BatchesWithSchema {
+  std::vector<ExecBatch> batches;
+  std::shared_ptr<Schema> schema;
+};
+
+Result<ExecNode*> MakeTestSourceNode(ExecPlan* plan, std::string label,
+                                     BatchesWithSchema batches_with_schema, bool parallel,
+                                     bool slow) {
+  DCHECK_GT(batches_with_schema.batches.size(), 0);
+
+  auto opt_batches = internal::MapVector(
+      [](ExecBatch batch) { return util::make_optional(std::move(batch)); },
+      std::move(batches_with_schema.batches));
 
-  std::shared_ptr<Schema> schema() const override { return reader_->schema(); }
+  AsyncGenerator<util::optional<ExecBatch>> gen;
 
-  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
-    SleepABit();
-    return reader_->ReadNext(batch);
+  if (parallel) {
+    // emulate batches completing initial decode-after-scan on a cpu thread
+    ARROW_ASSIGN_OR_RAISE(
+        gen, MakeBackgroundGenerator(MakeVectorIterator(std::move(opt_batches)),
+                                     internal::GetCpuThreadPool()));
+
+    // ensure that callbacks are not executed immediately on a background thread
+    gen = MakeTransferredGenerator(std::move(gen), internal::GetCpuThreadPool());
+  } else {
+    gen = MakeVectorGenerator(std::move(opt_batches));
   }
 
-  static Result<std::shared_ptr<RecordBatchReader>> Make(
-      RecordBatchVector batches, std::shared_ptr<Schema> schema = nullptr) {
-    ARROW_ASSIGN_OR_RAISE(auto reader,
-                          RecordBatchReader::Make(std::move(batches), std::move(schema)));
-    return std::make_shared<SlowRecordBatchReader>(std::move(reader));
+  if (slow) {
+    gen = MakeMappedGenerator(std::move(gen), [](const util::optional<ExecBatch>& batch) {
+      SleepABit();
+      return batch;
+    });
   }
 
- protected:
-  std::shared_ptr<RecordBatchReader> reader_;
-};
+  return MakeSourceNode(plan, label, std::move(batches_with_schema.schema),
+                        std::move(gen));
+}
 
-static Result<RecordBatchGenerator> MakeSlowRecordBatchGenerator(
-    RecordBatchVector batches, std::shared_ptr<Schema> schema) {
-  auto gen = MakeVectorGenerator(batches);
-  // TODO move this into testing/async_generator_util.h?
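The TODO above envisions a reusable delay wrapper; a minimal sketch of such a utility (hypothetical, not part of this patch, generalizing the removed code that follows) could be:

```
// Hypothetical testing helper envisioned by the TODO: wrap any AsyncGenerator<T>
// so that each item is delivered only after a small asynchronous delay, using the
// same SleepABitAsync()/AddCallback pattern as the code being removed below.
template <typename T>
AsyncGenerator<T> MakeDelayedGenerator(AsyncGenerator<T> gen) {
  return MakeMappedGenerator(std::move(gen), [](const T& item) {
    auto fut = Future<T>::Make();
    SleepABitAsync().AddCallback(
        [fut, item](const Status&) mutable { fut.MarkFinished(item); });
    return fut;
  });
}
```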
- auto delayed_gen = MakeMappedGenerator>( - std::move(gen), [](const std::shared_ptr& batch) { - auto fut = Future>::Make(); - SleepABitAsync().AddCallback( - [fut, batch](const Status& status) mutable { fut.MarkFinished(batch); }); - return fut; - }); - // Adding readahead implicitly adds parallelism by pulling reentrantly from - // the delayed generator - return MakeReadaheadGenerator(std::move(delayed_gen), /*max_readahead=*/64); +Result> StartAndCollect( + ExecPlan* plan, AsyncGenerator> gen) { + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + + auto maybe_collected = CollectAsyncGenerator(gen).result(); + ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected); + + plan->StopProducing(); + + return internal::MapVector( + [](util::optional batch) { return std::move(*batch); }, collected); } -class TestExecPlanExecution : public ::testing::Test { - public: - void SetUp() override { - ASSERT_OK_AND_ASSIGN(io_executor_, internal::ThreadPool::Make(8)); +BatchesWithSchema MakeBasicBatches() { + BatchesWithSchema out; + out.batches = { + ExecBatchFromJSON({int32(), boolean()}, "[[null, true], [4, false]]"), + ExecBatchFromJSON({int32(), boolean()}, "[[5, null], [6, false], [7, false]]")}; + out.schema = schema({field("i32", int32()), field("bool", boolean())}); + return out; +} + +BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, + int num_batches = 10, int batch_size = 4) { + BatchesWithSchema out; + + random::RandomArrayGenerator rng(42); + out.batches.resize(num_batches); + + for (int i = 0; i < num_batches; ++i) { + out.batches[i] = ExecBatch(*rng.BatchOf(schema->fields(), batch_size)); + // add a tag scalar to ensure the batches are unique + out.batches[i].values.emplace_back(i); } + return out; +} +} // namespace + +TEST(ExecPlanExecution, SourceSink) { + for (bool slow : {false, true}) { + SCOPED_TRACE(slow ? "slowed" : "unslowed"); + + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? 
"parallel" : "single threaded"); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + auto basic_data = MakeBasicBatches(); - RecordBatchVector MakeRandomBatches(const std::shared_ptr& schema, - int num_batches = 10, int batch_size = 4) { - random::RandomArrayGenerator rng(42); - RecordBatchVector batches; - batches.reserve(num_batches); - for (int i = 0; i < num_batches; ++i) { - batches.push_back(rng.BatchOf(schema->fields(), batch_size)); + ASSERT_OK_AND_ASSIGN(auto source, MakeTestSourceNode(plan.get(), "source", + basic_data, parallel, slow)); + + auto sink_gen = MakeSinkNode(source, "sink"); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(basic_data.batches))); } - return batches; } +} - struct CollectorPlan { - std::shared_ptr plan; - RecordBatchCollectNode* sink; +TEST(ExecPlanExecution, SourceSinkError) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + auto basic_data = MakeBasicBatches(); + auto it = basic_data.batches.begin(); + AsyncGenerator> gen = + [&]() -> Result> { + if (it == basic_data.batches.end()) { + return Status::Invalid("Artificial error"); + } + return util::make_optional(*it++); }; - Result MakeSourceSink(std::shared_ptr reader, - const std::shared_ptr& schema) { - ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make()); - auto source = - MakeRecordBatchReaderNode(plan.get(), "source", reader, io_executor_.get()); - auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", schema); - sink->AddInput(source); - return CollectorPlan{plan, sink}; - } + auto source = MakeSourceNode(plan.get(), "source", {}, gen); + auto sink_gen = MakeSinkNode(source, "sink"); - Result MakeSourceSink(RecordBatchGenerator generator, - const std::shared_ptr& schema) { - ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make()); - auto source = MakeRecordBatchReaderNode(plan.get(), "source", schema, generator, - io_executor_.get()); - auto sink = MakeRecordBatchCollectNode(plan.get(), "sink", schema); - sink->AddInput(source); - return CollectorPlan{plan, sink}; - } + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Raises(StatusCode::Invalid, HasSubstr("Artificial"))); +} - Result MakeSourceSink(const RecordBatchVector& batches, - const std::shared_ptr& schema) { - ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make(batches, schema)); - return MakeSourceSink(std::move(reader), schema); - } +TEST(ExecPlanExecution, StressSourceSink) { + for (bool slow : {false, true}) { + SCOPED_TRACE(slow ? "slowed" : "unslowed"); - Result StartAndCollect(ExecPlan* plan, - RecordBatchCollectNode* sink) { - RETURN_NOT_OK(plan->StartProducing()); - auto fut = CollectAsyncGenerator(sink->generator()); - return fut.result(); - } + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel" : "single threaded"); - template - void TestSourceSink(RecordBatchReaderFactory reader_factory) { - auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); - RecordBatchVector batches{ - RecordBatchFromJSON(schema, R"([{"a": null, "b": true}, - {"a": 4, "b": false}])"), - RecordBatchFromJSON(schema, R"([{"a": 5, "b": null}, - {"a": 6, "b": false}, - {"a": 7, "b": false}])"), - }; + int num_batches = slow && !parallel ? 
30 : 300; - ASSERT_OK_AND_ASSIGN(auto reader, reader_factory(batches, schema)); - ASSERT_OK_AND_ASSIGN(auto cp, MakeSourceSink(reader, schema)); - ASSERT_OK(cp.plan->Validate()); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(cp.plan.get(), cp.sink)); - AssertBatchesEqual(batches, got_batches); - } + auto random_data = MakeRandomBatches( + schema({field("a", int32()), field("b", boolean())}), num_batches); - template - void TestStressSourceSink(int num_batches, RecordBatchReaderFactory batch_factory) { - auto schema = ::arrow::schema({field("a", int32()), field("b", boolean())}); - auto batches = MakeRandomBatches(schema, num_batches); + ASSERT_OK_AND_ASSIGN(auto source, MakeTestSourceNode(plan.get(), "source", + random_data, parallel, slow)); - ASSERT_OK_AND_ASSIGN(auto reader, batch_factory(batches, schema)); - ASSERT_OK_AND_ASSIGN(auto cp, MakeSourceSink(reader, schema)); - ASSERT_OK(cp.plan->Validate()); + auto sink_gen = MakeSinkNode(source, "sink"); - ASSERT_OK_AND_ASSIGN(auto got_batches, StartAndCollect(cp.plan.get(), cp.sink)); - AssertBatchesEqual(batches, got_batches); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(random_data.batches))); + } } +} - protected: - std::shared_ptr io_executor_; -}; +TEST(ExecPlanExecution, SourceFilterSink) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); -TEST_F(TestExecPlanExecution, SourceSink) { TestSourceSink(RecordBatchReader::Make); } + auto basic_data = MakeBasicBatches(); -TEST_F(TestExecPlanExecution, SlowSourceSink) { - TestSourceSink(SlowRecordBatchReader::Make); -} + ASSERT_OK_AND_ASSIGN(auto source, + MakeTestSourceNode(plan.get(), "source", basic_data, + /*parallel=*/false, /*slow=*/false)); -TEST_F(TestExecPlanExecution, SlowSourceSinkParallel) { - TestSourceSink(MakeSlowRecordBatchGenerator); -} + ASSERT_OK_AND_ASSIGN(auto predicate, + equal(field_ref("i32"), literal(6)).Bind(*basic_data.schema)); -TEST_F(TestExecPlanExecution, StressSourceSink) { - TestStressSourceSink(/*num_batches=*/200, RecordBatchReader::Make); -} + ASSERT_OK_AND_ASSIGN(auto filter, MakeFilterNode(source, "filter", predicate)); + + auto sink_gen = MakeSinkNode(filter, "sink"); -TEST_F(TestExecPlanExecution, StressSlowSourceSink) { - // This doesn't create parallelism as the RecordBatchReader is iterated serially. 
- TestStressSourceSink(/*num_batches=*/30, SlowRecordBatchReader::Make); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray( + {ExecBatchFromJSON({int32(), boolean()}, "[]"), + ExecBatchFromJSON({int32(), boolean()}, "[[6, false]]")}))); } -TEST_F(TestExecPlanExecution, StressSlowSourceSinkParallel) { - TestStressSourceSink(/*num_batches=*/300, MakeSlowRecordBatchGenerator); +TEST(ExecPlanExecution, SourceProjectSink) { + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); + + auto basic_data = MakeBasicBatches(); + + ASSERT_OK_AND_ASSIGN(auto source, + MakeTestSourceNode(plan.get(), "source", basic_data, + /*parallel=*/false, /*slow=*/false)); + + std::vector exprs{ + not_(field_ref("bool")), + call("add", {field_ref("i32"), literal(1)}), + }; + for (auto& expr : exprs) { + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*basic_data.schema)); + } + + ASSERT_OK_AND_ASSIGN(auto projection, MakeProjectNode(source, "project", exprs)); + + auto sink_gen = MakeSinkNode(projection, "sink"); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray( + {ExecBatchFromJSON({boolean(), int32()}, "[[false, null], [true, 5]]"), + ExecBatchFromJSON({boolean(), int32()}, + "[[null, 6], [true, 7], [true, 8]]")}))); } } // namespace compute diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index ae2c9446aa9..6fbfa2a430c 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -33,11 +33,13 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/datum.h" #include "arrow/record_batch.h" +#include "arrow/testing/gtest_util.h" #include "arrow/type.h" #include "arrow/util/async_generator.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/optional.h" +#include "arrow/util/vector.h" namespace arrow { @@ -46,31 +48,22 @@ using internal::Executor; namespace compute { namespace { -// TODO expose this as `static ValueDescr::FromSchemaColumns`? 
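For orientation on that TODO: a `ValueDescr::FromSchemaColumns` static factory would be roughly the helper removed below, e.g. (a sketch only; no such member exists in this patch):

```
// Hypothetical static factory suggested by the TODO above; equivalent to the
// free function DescrFromSchemaColumns being deleted below: map every schema
// field to an array-shaped ValueDescr.
static std::vector<ValueDescr> FromSchemaColumns(const Schema& schema) {
  std::vector<ValueDescr> descr;
  descr.reserve(schema.num_fields());
  for (const std::shared_ptr<Field>& field : schema.fields()) {
    descr.push_back(ValueDescr::Array(field->type()));
  }
  return descr;
}
```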
-std::vector DescrFromSchemaColumns(const Schema& schema) { - std::vector descr(schema.num_fields()); - std::transform(schema.fields().begin(), schema.fields().end(), descr.begin(), - [](const std::shared_ptr& field) { - return ValueDescr::Array(field->type()); - }); - return descr; -} - struct DummyNode : ExecNode { - DummyNode(ExecPlan* plan, std::string label, int num_inputs, int num_outputs, + DummyNode(ExecPlan* plan, std::string label, NodeVector inputs, int num_outputs, StartProducingFunc start_producing, StopProducingFunc stop_producing) - : ExecNode(plan, std::move(label), std::vector(num_inputs, descr()), {}, - descr(), num_outputs), + : ExecNode(plan, std::move(label), std::move(inputs), {}, dummy_schema(), + num_outputs), start_producing_(std::move(start_producing)), stop_producing_(std::move(stop_producing)) { - for (int i = 0; i < num_inputs; ++i) { - input_labels_.push_back(std::to_string(i)); + input_labels_.resize(inputs_.size()); + for (size_t i = 0; i < input_labels_.size(); ++i) { + input_labels_[i] = std::to_string(i); } } const char* kind_name() override { return "Dummy"; } - void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} + void InputReceived(ExecNode* input, int seq_num, ExecBatch batch) override {} void ErrorReceived(ExecNode* input, Status error) override {} @@ -117,283 +110,45 @@ struct DummyNode : ExecNode { ASSERT_NE(std::find(outputs_.begin(), outputs_.end(), output), outputs_.end()); } - BatchDescr descr() const { return std::vector{ValueDescr(null())}; } + std::shared_ptr dummy_schema() const { + return schema({field("dummy", null())}); + } StartProducingFunc start_producing_; StopProducingFunc stop_producing_; bool started_ = false; }; -struct RecordBatchReaderNode : ExecNode { - RecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr reader, Executor* io_executor) - : ExecNode(plan, std::move(label), {}, {}, - DescrFromSchemaColumns(*reader->schema()), /*num_outputs=*/1), - schema_(reader->schema()), - reader_(std::move(reader)), - io_executor_(io_executor) {} - - RecordBatchReaderNode(ExecPlan* plan, std::string label, std::shared_ptr schema, - RecordBatchGenerator generator, Executor* io_executor) - : ExecNode(plan, std::move(label), {}, {}, DescrFromSchemaColumns(*schema), - /*num_outputs=*/1), - schema_(std::move(schema)), - generator_(std::move(generator)), - io_executor_(io_executor) {} - - const char* kind_name() override { return "RecordBatchReader"; } - - void InputReceived(ExecNode* input, int seq_num, compute::ExecBatch batch) override {} - - void ErrorReceived(ExecNode* input, Status error) override {} - - void InputFinished(ExecNode* input, int seq_stop) override {} - - Status StartProducing() override { - next_batch_index_ = 0; - if (!generator_) { - auto it = MakeIteratorFromReader(reader_); - ARROW_ASSIGN_OR_RAISE(generator_, - MakeBackgroundGenerator(std::move(it), io_executor_)); - } - GenerateOne(std::unique_lock{mutex_}); - return Status::OK(); - } - - void PauseProducing(ExecNode* output) override {} - - void ResumeProducing(ExecNode* output) override {} - - void StopProducing(ExecNode* output) override { - ASSERT_EQ(output, outputs_[0]); - std::unique_lock lock(mutex_); - generator_ = nullptr; // null function - } - - void StopProducing() override { StopProducing(outputs_[0]); } - - private: - void GenerateOne(std::unique_lock&& lock) { - if (!generator_) { - // Stopped - return; - } - auto plan = this->plan()->shared_from_this(); - auto fut = generator_(); - const auto 
batch_index = next_batch_index_++; - - lock.unlock(); - // TODO we want to transfer always here - io_executor_->Transfer(std::move(fut)) - .AddCallback( - [plan, batch_index, this](const Result>& res) { - std::unique_lock lock(mutex_); - if (!res.ok()) { - for (auto out : outputs_) { - out->ErrorReceived(this, res.status()); - } - return; - } - const auto& batch = *res; - if (IsIterationEnd(batch)) { - lock.unlock(); - for (auto out : outputs_) { - out->InputFinished(this, batch_index); - } - } else { - lock.unlock(); - for (auto out : outputs_) { - out->InputReceived(this, batch_index, compute::ExecBatch(*batch)); - } - lock.lock(); - GenerateOne(std::move(lock)); - } - }); - } - - std::mutex mutex_; - const std::shared_ptr schema_; - const std::shared_ptr reader_; - RecordBatchGenerator generator_; - int next_batch_index_; - - Executor* const io_executor_; -}; - -struct RecordBatchCollectNodeImpl : public RecordBatchCollectNode { - RecordBatchCollectNodeImpl(ExecPlan* plan, std::string label, - std::shared_ptr schema) - : RecordBatchCollectNode(plan, std::move(label), {DescrFromSchemaColumns(*schema)}, - {"batches_to_collect"}, {}, 0), - schema_(std::move(schema)) {} - - RecordBatchGenerator generator() override { return generator_; } - - const char* kind_name() override { return "RecordBatchReader"; } - - Status StartProducing() override { - num_received_ = 0; - num_emitted_ = 0; - emit_stop_ = -1; - stopped_ = false; - producer_.emplace(generator_.producer()); - return Status::OK(); - } - - // sink nodes have no outputs from which to feel backpressure - void ResumeProducing(ExecNode* output) override { - FAIL() << "no outputs; this should never be called"; - } - void PauseProducing(ExecNode* output) override { - FAIL() << "no outputs; this should never be called"; - } - void StopProducing(ExecNode* output) override { - FAIL() << "no outputs; this should never be called"; - } - - void StopProducing() override { - std::unique_lock lock(mutex_); - StopProducingUnlocked(); - } - - void InputReceived(ExecNode* input, int seq_num, - compute::ExecBatch exec_batch) override { - std::unique_lock lock(mutex_); - if (stopped_) { - return; - } - auto maybe_batch = MakeBatch(std::move(exec_batch)); - if (!maybe_batch.ok()) { - lock.unlock(); - producer_->Push(std::move(maybe_batch)); - return; - } - - // TODO would be nice to factor this out in a ReorderQueue - auto batch = *std::move(maybe_batch); - if (seq_num <= static_cast(received_batches_.size())) { - received_batches_.resize(seq_num + 1, nullptr); - } - DCHECK_EQ(received_batches_[seq_num], nullptr); - received_batches_[seq_num] = std::move(batch); - ++num_received_; - - if (seq_num != num_emitted_) { - // Cannot emit yet as there is a hole at `num_emitted_` - DCHECK_GT(seq_num, num_emitted_); - DCHECK_EQ(received_batches_[num_emitted_], nullptr); - return; - } - if (num_received_ == emit_stop_) { - StopProducingUnlocked(); - } - - // Emit batches in order as far as possible - // First collect these batches, then unlock before producing. 
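The sequencing logic deleted below is what the ReorderQueue TODO above refers to; distilled into a standalone sketch (hypothetical class; assumes the caller holds the node's lock, as the deleted code did):

```
// Hypothetical ReorderQueue: batches arrive tagged with a sequence number and
// are released only as a contiguous run starting at the emission cursor.
class ReorderQueue {
 public:
  void Insert(int seq_num, std::shared_ptr<RecordBatch> batch) {
    if (seq_num >= static_cast<int>(batches_.size())) {
      batches_.resize(seq_num + 1, nullptr);
    }
    batches_[seq_num] = std::move(batch);
  }

  // Pop the longest contiguous run of received batches, advancing the cursor.
  RecordBatchVector PopReady() {
    RecordBatchVector ready;
    while (next_ < static_cast<int>(batches_.size()) && batches_[next_] != nullptr) {
      ready.push_back(std::move(batches_[next_++]));
    }
    return ready;
  }

 private:
  std::vector<std::shared_ptr<RecordBatch>> batches_;
  int next_ = 0;
};
```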
- const auto seq_start = seq_num; - while (seq_num < static_cast(received_batches_.size()) && - received_batches_[seq_num] != nullptr) { - ++seq_num; - } - DCHECK_GT(seq_num, seq_start); - // By moving the values now, we make sure another thread won't emit the same values - // below - RecordBatchVector to_emit( - std::make_move_iterator(received_batches_.begin() + seq_start), - std::make_move_iterator(received_batches_.begin() + seq_num)); - - lock.unlock(); - for (auto&& batch : to_emit) { - producer_->Push(std::move(batch)); - } - lock.lock(); - - DCHECK_EQ(seq_start, num_emitted_); // num_emitted_ wasn't bumped in the meantime - num_emitted_ = seq_num; - } +} // namespace - void ErrorReceived(ExecNode* input, Status error) override { - // XXX do we care about properly sequencing the error? - producer_->Push(std::move(error)); - std::unique_lock lock(mutex_); - StopProducingUnlocked(); - } +ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, + int num_outputs, StartProducingFunc start_producing, + StopProducingFunc stop_producing) { + return plan->EmplaceNode(plan, std::move(label), std::move(inputs), + num_outputs, std::move(start_producing), + std::move(stop_producing)); +} - void InputFinished(ExecNode* input, int seq_stop) override { - std::unique_lock lock(mutex_); - DCHECK_GE(seq_stop, static_cast(received_batches_.size())); - received_batches_.reserve(seq_stop); - emit_stop_ = seq_stop; - if (emit_stop_ == num_received_) { - DCHECK_EQ(emit_stop_, num_emitted_); - StopProducingUnlocked(); - } - } +ExecBatch ExecBatchFromJSON(const std::vector& descrs, + util::string_view json) { + auto fields = internal::MapVector( + [](const ValueDescr& descr) { return field("", descr.type); }, descrs); - private: - void StopProducingUnlocked() { - if (!stopped_) { - stopped_ = true; - producer_->Close(); - inputs_[0]->StopProducing(this); - } - } + ExecBatch batch{*RecordBatchFromJSON(schema(std::move(fields)), json)}; - // TODO factor this out as ExecBatch::ToRecordBatch()? 
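What that TODO suggests, written out as a free-function sketch (hypothetical; it mirrors the MakeBatch helper removed just below):

```
// Hypothetical ExecBatch-to-RecordBatch conversion as the TODO proposes: every
// value must already be array-shaped; scalar values would need broadcasting first.
Result<std::shared_ptr<RecordBatch>> ToRecordBatch(ExecBatch batch,
                                                   std::shared_ptr<Schema> schema) {
  ArrayDataVector columns;
  columns.reserve(batch.values.size());
  for (Datum& value : batch.values) {
    if (!value.is_array()) return Status::TypeError("Expected array input");
    columns.push_back(std::move(value).array());
  }
  return RecordBatch::Make(std::move(schema), batch.length, std::move(columns));
}
```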
- Result> MakeBatch(compute::ExecBatch&& exec_batch) { - ArrayDataVector columns; - columns.reserve(exec_batch.values.size()); - for (auto&& value : exec_batch.values) { - if (!value.is_array()) { - return Status::TypeError("Expected array input"); + auto value_it = batch.values.begin(); + for (const auto& descr : descrs) { + if (descr.shape == ValueDescr::SCALAR) { + if (batch.length == 0) { + *value_it = MakeNullScalar(value_it->type()); + } else { + *value_it = value_it->make_array()->GetScalar(0).ValueOrDie(); } - columns.push_back(std::move(value).array()); } - return RecordBatch::Make(schema_, exec_batch.length, std::move(columns)); + ++value_it; } - const std::shared_ptr schema_; - - std::mutex mutex_; - RecordBatchVector received_batches_; - int num_received_; - int num_emitted_; - int emit_stop_; - bool stopped_; - - PushGenerator> generator_; - util::optional>::Producer> producer_; -}; - -} // namespace - -ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr reader, - Executor* io_executor) { - return plan->EmplaceNode(plan, std::move(label), - std::move(reader), io_executor); -} - -ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr schema, - RecordBatchGenerator generator, - ::arrow::internal::Executor* io_executor) { - return plan->EmplaceNode( - plan, std::move(label), std::move(schema), std::move(generator), io_executor); -} - -ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, - int num_outputs, StartProducingFunc start_producing, - StopProducingFunc stop_producing) { - return plan->EmplaceNode(plan, std::move(label), num_inputs, num_outputs, - std::move(start_producing), - std::move(stop_producing)); -} - -RecordBatchCollectNode* MakeRecordBatchCollectNode( - ExecPlan* plan, std::string label, const std::shared_ptr& schema) { - return arrow::internal::checked_cast( - plan->EmplaceNode(plan, std::move(label), schema)); + return batch; } } // namespace compute diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index c2dc785a501..faa395bab78 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -18,15 +18,13 @@ #pragma once #include -#include #include #include +#include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" -#include "arrow/record_batch.h" #include "arrow/testing/visibility.h" -#include "arrow/util/async_generator.h" -#include "arrow/util/type_fwd.h" +#include "arrow/util/string_view.h" namespace arrow { namespace compute { @@ -36,35 +34,12 @@ using StopProducingFunc = std::function; // Make a dummy node that has no execution behaviour ARROW_TESTING_EXPORT -ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, int num_inputs, +ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector inputs, int num_outputs, StartProducingFunc = {}, StopProducingFunc = {}); -using RecordBatchGenerator = AsyncGenerator>; - -// Make a source node (no inputs) that produces record batches by reading in the -// background from a RecordBatchReader. 
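The reader-based helpers removed below are superseded by the `ExecBatchFromJSON` declaration further down; its typical use, taken from the batches constructed elsewhere in this patch:

```
// ExecBatchFromJSON builds an ExecBatch column-wise from a JSON array of rows.
// Descrs with ValueDescr::SCALAR shape are collapsed to their first parsed value.
ExecBatch batch = ExecBatchFromJSON({int32(), boolean()},
                                    "[[null, true], [4, false]]");
// batch.length == 2; values[0] is an int32 array, values[1] a boolean array.
```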
-ARROW_TESTING_EXPORT -ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr reader, - ::arrow::internal::Executor* io_executor); - -ARROW_TESTING_EXPORT -ExecNode* MakeRecordBatchReaderNode(ExecPlan* plan, std::string label, - std::shared_ptr schema, - RecordBatchGenerator generator, - ::arrow::internal::Executor* io_executor); - -class RecordBatchCollectNode : public ExecNode { - public: - virtual RecordBatchGenerator generator() = 0; - - protected: - using ExecNode::ExecNode; -}; - ARROW_TESTING_EXPORT -RecordBatchCollectNode* MakeRecordBatchCollectNode(ExecPlan* plan, std::string label, - const std::shared_ptr& schema); +ExecBatch ExecBatchFromJSON(const std::vector& descrs, + util::string_view json); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index d6a1d4ccbc4..e723bd7838e 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -244,7 +244,7 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { for (const auto& descr : descrs) { auto id = descr.type->id(); - auto max_width = is_signed_integer(id) ? &max_width_signed : &max_width_unsigned; + auto max_width = &(is_signed_integer(id) ? max_width_signed : max_width_unsigned); *max_width = std::max(bit_width(id), *max_width); } diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 8a0d6de7f25..eebc8c1b678 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -41,6 +41,8 @@ struct VectorKernel; struct KernelState; class Expression; +class ExecNode; +class ExecPlan; } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 841b792ee34..fc6b38b37a9 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -165,13 +165,8 @@ Dataset::Dataset(std::shared_ptr schema, compute::Expression partition_e : schema_(std::move(schema)), partition_expression_(std::move(partition_expression)) {} -Result> Dataset::NewScan( - std::shared_ptr options) { - return std::make_shared(this->shared_from_this(), options); -} - Result> Dataset::NewScan() { - return NewScan(std::make_shared()); + return std::make_shared(this->shared_from_this()); } Result Dataset::GetFragments() { diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index d2cba730252..11210fdc27b 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -155,7 +155,6 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment { class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { public: /// \brief Begin to build a new Scan operation against this Dataset - Result> NewScan(std::shared_ptr options); Result> NewScan(); /// \brief GetFragments returns an iterator of Fragments given a predicate. 
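With the `ScanOptions`-taking `NewScan` overload removed above, callers configure scans through the returned builder instead. A migration sketch (assumes an in-scope `std::shared_ptr<Dataset> dataset`; error handling abbreviated):

```
// Before: dataset->NewScan(options). After: configure via the ScannerBuilder.
ARROW_ASSIGN_OR_RAISE(auto builder, dataset->NewScan());
ARROW_RETURN_NOT_OK(builder->Filter(equal(field_ref("i32"), literal(5))));
ARROW_RETURN_NOT_OK(builder->UseAsync(true));
ARROW_ASSIGN_OR_RAISE(auto scanner, builder->Finish());
```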
diff --git a/cpp/src/arrow/dataset/dataset_internal.h b/cpp/src/arrow/dataset/dataset_internal.h index 4336f9c157e..952ad3e83ca 100644 --- a/cpp/src/arrow/dataset/dataset_internal.h +++ b/cpp/src/arrow/dataset/dataset_internal.h @@ -204,5 +204,35 @@ arrow::Result> GetFragmentScanOptions( return internal::checked_pointer_cast(source); } +class FragmentDataset : public Dataset { + public: + FragmentDataset(std::shared_ptr schema, FragmentVector fragments) + : Dataset(std::move(schema)), fragments_(std::move(fragments)) {} + + std::string type_name() const override { return "fragment"; } + + Result> ReplaceSchema( + std::shared_ptr schema) const override { + return std::make_shared(std::move(schema), fragments_); + } + + protected: + Result GetFragmentsImpl(compute::Expression predicate) override { + // TODO(ARROW-12891) Provide subtree pruning for any vector of fragments + FragmentVector fragments; + for (const auto& fragment : fragments_) { + ARROW_ASSIGN_OR_RAISE( + auto simplified_filter, + compute::SimplifyWithGuarantee(predicate, fragment->partition_expression())); + + if (simplified_filter.IsSatisfiable()) { + fragments.push_back(fragment); + } + } + return MakeVectorIterator(std::move(fragments)); + } + FragmentVector fragments_; +}; + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index fd96fe8f50e..3f42ab44a39 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -162,8 +162,8 @@ static inline Future> OpenReaderAsync( })); return reader_fut.Then( // Adds the filename to the error - [](const std::shared_ptr& maybe_reader) - -> Result> { return maybe_reader; }, + [](const std::shared_ptr& reader) + -> Result> { return reader; }, [source](const Status& err) -> Result> { return err.WithMessage("Could not open CSV input source '", source.path(), "': ", err); diff --git a/cpp/src/arrow/dataset/file_ipc_test.cc b/cpp/src/arrow/dataset/file_ipc_test.cc index f0409abe85b..e6192523f53 100644 --- a/cpp/src/arrow/dataset/file_ipc_test.cc +++ b/cpp/src/arrow/dataset/file_ipc_test.cc @@ -100,13 +100,6 @@ class TestIpcFileSystemDataset : public testing::Test, format_ = ipc_format; SetWriteOptions(ipc_format->DefaultWriteOptions()); } - - std::shared_ptr MakeScanner(const std::shared_ptr& dataset, - const std::shared_ptr& scan_options) { - ScannerBuilder builder(dataset, scan_options); - EXPECT_OK_AND_ASSIGN(auto scanner, builder.Finish()); - return scanner; - } }; TEST_F(TestIpcFileSystemDataset, WriteWithIdenticalPartitioningSchema) { @@ -132,7 +125,7 @@ TEST_F(TestIpcFileSystemDataset, WriteExceedsMaxPartitions) { // require that no batch be grouped into more than 2 written batches: write_options_.max_partitions = 2; - auto scanner = MakeScanner(dataset_, scan_options_); + EXPECT_OK_AND_ASSIGN(auto scanner, ScannerBuilder(dataset_, scan_options_).Finish()); EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("This exceeds the maximum"), FileSystemDataset::Write(write_options_, scanner)); } diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index 04c86b1f16f..ffa64e8ec10 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -25,6 +25,7 @@ #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" #include "arrow/io/memory.h" +#include "arrow/io/util_internal.h" #include "arrow/record_batch.h" #include "arrow/table.h" #include 
"arrow/testing/gtest_util.h" @@ -283,6 +284,32 @@ TEST_F(TestParquetFileFormat, CountRowsPredicatePushdown) { } } +TEST_F(TestParquetFileFormat, MultithreadedScan) { + constexpr int64_t kNumRowGroups = 16; + + // See PredicatePushdown test below for a description of the generated data + auto reader = ArithmeticDatasetFixture::GetRecordBatchReader(kNumRowGroups); + auto source = GetFileSource(reader.get()); + auto options = std::make_shared(); + + auto fragment = MakeFragment(*source); + + FragmentDataset dataset(ArithmeticDatasetFixture::schema(), {fragment}); + ScannerBuilder builder({&dataset, [](...) {}}); + + ASSERT_OK(builder.UseAsync(true)); + ASSERT_OK(builder.UseThreads(true)); + ASSERT_OK(builder.Project({call("add", {field_ref("i64"), literal(3)})}, {""})); + ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); + + ASSERT_OK_AND_ASSIGN(auto gen, scanner->ScanBatchesUnorderedAsync()); + + auto collect_fut = CollectAsyncGenerator(gen); + ASSERT_OK_AND_ASSIGN(auto batches, collect_fut.result()); + + ASSERT_EQ(batches.size(), kNumRowGroups); +} + class TestParquetFileSystemDataset : public WriteFileSystemDatasetMixin, public testing::Test { public: diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index b80d1bb57f0..5bf89330429 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -169,7 +169,8 @@ TEST_F(TestFileSystemDataset, ReplaceSchema) { TEST_F(TestFileSystemDataset, RootPartitionPruning) { auto root_partition = equal(field_ref("i32"), literal(5)); - MakeDataset({fs::File("a"), fs::File("b")}, root_partition); + MakeDataset({fs::File("a"), fs::File("b")}, root_partition, {}, + schema({field("i32", int32()), field("f32", float32())})); auto GetFragments = [&](compute::Expression filter) { return *dataset_->GetFragments(*filter.Bind(*dataset_->schema())); @@ -191,8 +192,9 @@ TEST_F(TestFileSystemDataset, RootPartitionPruning) { AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), {"a", "b"}); - // No partition should match - MakeDataset({fs::File("a"), fs::File("b")}); + // No root partition: don't prune any fragments + MakeDataset({fs::File("a"), fs::File("b")}, literal(true), {}, + schema({field("i32", int32()), field("f32", float32())})); AssertFragmentsAreFromPath(GetFragments(equal(field_ref("f32"), literal(3.F))), {"a", "b"}); } diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 5c390b6b487..1ec47e3cee1 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -30,6 +30,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" +#include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/filesystem/path_util.h" #include "arrow/scalar.h" @@ -252,7 +253,7 @@ Result KeyValuePartitioning::Format(const compute::Expression& expr ScalarVector values{static_cast(schema_->num_fields()), nullptr}; ARROW_ASSIGN_OR_RAISE(auto known_values, ExtractKnownFieldValues(expr)); - for (const auto& ref_value : known_values) { + for (const auto& ref_value : known_values.map) { if (!ref_value.second.is_scalar()) { return Status::Invalid("non-scalar partition key ", ref_value.second.ToString()); } diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index 09e05cdbf75..58e96fdc113 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -27,6 +27,7 @@ #include 
"arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/scanner_internal.h" @@ -317,10 +318,6 @@ class ARROW_DS_EXPORT SyncScanner : public Scanner { SyncScanner(std::shared_ptr dataset, std::shared_ptr scan_options) : Scanner(std::move(scan_options)), dataset_(std::move(dataset)) {} - SyncScanner(std::shared_ptr fragment, - std::shared_ptr scan_options) - : Scanner(std::move(scan_options)), fragment_(std::move(fragment)) {} - Result ScanBatches() override; Result Scan() override; Status Scan(std::function visitor) override; @@ -337,8 +334,6 @@ class ARROW_DS_EXPORT SyncScanner : public Scanner { Result ScanInternal(); std::shared_ptr dataset_; - // TODO(ARROW-8065) remove fragment_ after a Dataset is constuctible from fragments - std::shared_ptr fragment_; }; Result SyncScanner::ScanBatches() { @@ -370,10 +365,6 @@ Result SyncScanner::ScanBatchesUnorderedAsync() } Result SyncScanner::GetFragments() { - if (fragment_ != nullptr) { - return MakeVectorIterator(FragmentVector{fragment_}); - } - // Transform Datasets in a flat Iterator. This // iterator is lazily constructed, i.e. Dataset::GetFragments is // not invoked until a Fragment is requested. @@ -411,18 +402,6 @@ Result SyncScanner::ScanInternal() { return GetScanTaskIterator(std::move(fragment_it), scan_options_); } -Result ScanTaskIteratorFromRecordBatch( - std::vector> batches, - std::shared_ptr options) { - if (batches.empty()) { - return MakeVectorIterator(ScanTaskVector()); - } - auto schema = batches[0]->schema(); - auto fragment = - std::make_shared(std::move(schema), std::move(batches)); - return fragment->Scan(std::move(options)); -} - class ARROW_DS_EXPORT AsyncScanner : public Scanner, public std::enable_shared_from_this { public: @@ -454,15 +433,17 @@ class ARROW_DS_EXPORT AsyncScanner : public Scanner, namespace { inline Result DoFilterAndProjectRecordBatchAsync( - const std::shared_ptr& scanner, const EnumeratedRecordBatch& in) { - ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_filter, - SimplifyWithGuarantee(scanner->options()->filter, - in.fragment.value->partition_expression())); - - compute::ExecContext exec_context{scanner->options()->pool}; + const std::shared_ptr& options, const EnumeratedRecordBatch& in) { ARROW_ASSIGN_OR_RAISE( - Datum mask, ExecuteScalarExpression(simplified_filter, Datum(in.record_batch.value), - &exec_context)); + compute::Expression simplified_filter, + SimplifyWithGuarantee(options->filter, in.fragment.value->partition_expression())); + + const auto& schema = *options->dataset_schema; + + compute::ExecContext exec_context{options->pool}; + ARROW_ASSIGN_OR_RAISE(Datum mask, + ExecuteScalarExpression(simplified_filter, schema, + in.record_batch.value, &exec_context)); Datum filtered; if (mask.is_scalar()) { @@ -481,11 +462,12 @@ inline Result DoFilterAndProjectRecordBatchAsync( } ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection, - SimplifyWithGuarantee(scanner->options()->projection, + SimplifyWithGuarantee(options->projection, in.fragment.value->partition_expression())); + ARROW_ASSIGN_OR_RAISE( Datum projected, - ExecuteScalarExpression(simplified_projection, filtered, &exec_context)); + ExecuteScalarExpression(simplified_projection, schema, filtered, &exec_context)); DCHECK_EQ(projected.type()->id(), Type::STRUCT); if (projected.shape() == ValueDescr::SCALAR) { @@ -493,7 +475,7 @@ 
inline Result DoFilterAndProjectRecordBatchAsync( ARROW_ASSIGN_OR_RAISE( projected, MakeArrayFromScalar(*projected.scalar(), filtered.record_batch()->num_rows(), - scanner->options()->pool)); + options->pool)); } ARROW_ASSIGN_OR_RAISE(auto out, RecordBatch::FromStructArray(projected.array_as())); @@ -506,17 +488,16 @@ inline Result DoFilterAndProjectRecordBatchAsync( } inline EnumeratedRecordBatchGenerator FilterAndProjectRecordBatchAsync( - const std::shared_ptr& scanner, EnumeratedRecordBatchGenerator rbs) { - auto mapper = [scanner](const EnumeratedRecordBatch& in) { - return DoFilterAndProjectRecordBatchAsync(scanner, in); + const std::shared_ptr& options, EnumeratedRecordBatchGenerator rbs) { + auto mapper = [options](const EnumeratedRecordBatch& in) { + return DoFilterAndProjectRecordBatchAsync(options, in); }; - return MakeMappedGenerator(std::move(rbs), mapper); + return MakeMappedGenerator(std::move(rbs), mapper); } Result FragmentToBatches( - std::shared_ptr scanner, const Enumerated>& fragment, - const std::shared_ptr& options) { + const std::shared_ptr& options, bool filter_and_project = true) { ARROW_ASSIGN_OR_RAISE(auto batch_gen, fragment.value->ScanBatchesAsync(options)); auto enumerated_batch_gen = MakeEnumeratedGenerator(std::move(batch_gen)); @@ -525,30 +506,37 @@ Result FragmentToBatches( return EnumeratedRecordBatch{record_batch, fragment}; }; - auto combined_gen = MakeMappedGenerator(enumerated_batch_gen, - std::move(combine_fn)); + auto combined_gen = MakeMappedGenerator(enumerated_batch_gen, std::move(combine_fn)); - return FilterAndProjectRecordBatchAsync(scanner, std::move(combined_gen)); + if (filter_and_project) { + return FilterAndProjectRecordBatchAsync(options, std::move(combined_gen)); + } + return combined_gen; } Result> FragmentsToBatches( - std::shared_ptr scanner, FragmentGenerator fragment_gen) { + FragmentGenerator fragment_gen, const std::shared_ptr& options, + bool filter_and_project = true) { auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen)); - return MakeMappedGenerator( - std::move(enumerated_fragment_gen), - [scanner](const Enumerated>& fragment) { - return FragmentToBatches(scanner, fragment, scanner->options()); - }); + return MakeMappedGenerator(std::move(enumerated_fragment_gen), + [=](const Enumerated>& fragment) { + return FragmentToBatches(fragment, options, + filter_and_project); + }); } Result>>> FragmentsToRowCount( - std::shared_ptr scanner, FragmentGenerator fragment_gen) { + FragmentGenerator fragment_gen, + std::shared_ptr options_with_projection) { // Must use optional to avoid breaking the pipeline on empty batches auto enumerated_fragment_gen = MakeEnumeratedGenerator(std::move(fragment_gen)); - auto options = std::make_shared(*scanner->options()); + + // Drop projection since we only need to count rows + auto options = std::make_shared(*options_with_projection); RETURN_NOT_OK(SetProjection(options.get(), std::vector())); + auto count_fragment_fn = - [scanner, options](const Enumerated>& fragment) + [options](const Enumerated>& fragment) -> Result>> { auto count_fut = fragment.value->CountRows(options->filter, options); return MakeFromFuture( @@ -560,18 +548,29 @@ Result>>> FragmentsToRowCo Future>::MakeFinished(val)); } // Slow path - ARROW_ASSIGN_OR_RAISE(auto batch_gen, - FragmentToBatches(scanner, fragment, options)); + ARROW_ASSIGN_OR_RAISE(auto batch_gen, FragmentToBatches(fragment, options)); auto count_fn = [](const EnumeratedRecordBatch& enumerated) -> util::optional { return 
enumerated.record_batch.value->num_rows(); }; - return MakeMappedGenerator>(batch_gen, - std::move(count_fn)); + return MakeMappedGenerator(batch_gen, std::move(count_fn)); })); }; - return MakeMappedGenerator>>( - std::move(enumerated_fragment_gen), std::move(count_fragment_fn)); + return MakeMappedGenerator(std::move(enumerated_fragment_gen), + std::move(count_fragment_fn)); +} + +Result ScanBatchesUnorderedAsyncImpl( + const std::shared_ptr& options, FragmentGenerator fragment_gen, + internal::Executor* cpu_executor, bool filter_and_project = true) { + ARROW_ASSIGN_OR_RAISE( + auto batch_gen_gen, + FragmentsToBatches(std::move(fragment_gen), options, filter_and_project)); + auto batch_gen_gen_readahead = + MakeSerialReadaheadGenerator(std::move(batch_gen_gen), options->fragment_readahead); + auto merged_batch_gen = MakeMergedGenerator(std::move(batch_gen_gen_readahead), + options->fragment_readahead); + return MakeReadaheadGenerator(std::move(merged_batch_gen), options->fragment_readahead); } } // namespace @@ -607,16 +606,9 @@ Result AsyncScanner::ScanBatchesUnorderedAsync() Result AsyncScanner::ScanBatchesUnorderedAsync( internal::Executor* cpu_executor) { - auto self = shared_from_this(); ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments()); - ARROW_ASSIGN_OR_RAISE(auto batch_gen_gen, - FragmentsToBatches(self, std::move(fragment_gen))); - auto batch_gen_gen_readahead = MakeSerialReadaheadGenerator( - std::move(batch_gen_gen), scan_options_->fragment_readahead); - auto merged_batch_gen = MakeMergedGenerator(std::move(batch_gen_gen_readahead), - scan_options_->fragment_readahead); - return MakeReadaheadGenerator(std::move(merged_batch_gen), - scan_options_->fragment_readahead); + return ScanBatchesUnorderedAsyncImpl(scan_options_, std::move(fragment_gen), + cpu_executor); } Result AsyncScanner::ScanBatchesAsync() { @@ -626,13 +618,17 @@ Result AsyncScanner::ScanBatchesAsync() { Result AsyncScanner::ScanBatchesAsync( internal::Executor* cpu_executor) { ARROW_ASSIGN_OR_RAISE(auto unordered, ScanBatchesUnorderedAsync(cpu_executor)); - auto left_after_right = [](const EnumeratedRecordBatch& left, - const EnumeratedRecordBatch& right) { + // We need an initial value sentinel, so we use one with fragment.index < 0 + auto is_before_any = [](const EnumeratedRecordBatch& batch) { + return batch.fragment.index < 0; + }; + auto left_after_right = [&is_before_any](const EnumeratedRecordBatch& left, + const EnumeratedRecordBatch& right) { // Before any comes first - if (left.fragment.value == nullptr) { + if (is_before_any(left)) { return false; } - if (right.fragment.value == nullptr) { + if (is_before_any(right)) { return true; } // Compare batches if fragment is the same @@ -642,10 +638,10 @@ Result AsyncScanner::ScanBatchesAsync( // Otherwise compare fragment return left.fragment.index > right.fragment.index; }; - auto is_next = [](const EnumeratedRecordBatch& prev, - const EnumeratedRecordBatch& next) { + auto is_next = [is_before_any](const EnumeratedRecordBatch& prev, + const EnumeratedRecordBatch& next) { // Only true if next is the first batch - if (prev.fragment.value == nullptr) { + if (is_before_any(prev)) { return next.fragment.index == 0 && next.record_batch.index == 0; } // If same fragment, compare batch index @@ -664,7 +660,7 @@ Result AsyncScanner::ScanBatchesAsync( return TaggedRecordBatch{enumerated_batch.record_batch.value, enumerated_batch.fragment.value}; }; - return MakeMappedGenerator(std::move(sequenced), unenumerate_fn); + return 
MakeMappedGenerator(std::move(sequenced), unenumerate_fn); } struct AsyncTableAssemblyState { @@ -725,8 +721,8 @@ Future> AsyncScanner::ToTableAsync( return batch; }; - auto table_building_gen = MakeMappedGenerator( - positioned_batch_gen, table_building_task); + auto table_building_gen = + MakeMappedGenerator(positioned_batch_gen, table_building_task); return DiscardAllFromAsyncGenerator(table_building_gen).Then([state, scan_options]() { return Table::FromRecordBatches(scan_options->projected_schema, state->Finish()); @@ -734,10 +730,9 @@ Future> AsyncScanner::ToTableAsync( } Result AsyncScanner::CountRows() { - auto self = shared_from_this(); ARROW_ASSIGN_OR_RAISE(auto fragment_gen, GetFragments()); ARROW_ASSIGN_OR_RAISE(auto count_gen_gen, - FragmentsToRowCount(self, std::move(fragment_gen))); + FragmentsToRowCount(std::move(fragment_gen), scan_options_)); auto count_gen = MakeConcatenatedGenerator(std::move(count_gen_gen)); int64_t total = 0; auto sum_fn = [&total](util::optional count) -> Status { @@ -755,9 +750,7 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr dataset) ScannerBuilder::ScannerBuilder(std::shared_ptr dataset, std::shared_ptr scan_options) - : dataset_(std::move(dataset)), - fragment_(nullptr), - scan_options_(std::move(scan_options)) { + : dataset_(std::move(dataset)), scan_options_(std::move(scan_options)) { scan_options_->dataset_schema = dataset_->schema(); DCHECK_OK(Filter(scan_options_->filter)); } @@ -765,12 +758,9 @@ ScannerBuilder::ScannerBuilder(std::shared_ptr dataset, ScannerBuilder::ScannerBuilder(std::shared_ptr schema, std::shared_ptr fragment, std::shared_ptr scan_options) - : dataset_(nullptr), - fragment_(std::move(fragment)), - scan_options_(std::move(scan_options)) { - scan_options_->dataset_schema = std::move(schema); - DCHECK_OK(Filter(scan_options_->filter)); -} + : ScannerBuilder(std::make_shared( + std::move(schema), FragmentVector{std::move(fragment)}), + std::move(scan_options)) {} namespace { class OneShotScanTask : public ScanTask { @@ -898,10 +888,6 @@ Result> ScannerBuilder::Finish() { RETURN_NOT_OK(Project(scan_options_->dataset_schema->field_names())); } - if (dataset_ == nullptr) { - // AsyncScanner does not support this method of running. 
It may in the future
-    return std::make_shared<SyncScanner>(fragment_, scan_options_);
-  }
   if (scan_options_->use_async) {
     return std::make_shared<AsyncScanner>(dataset_, scan_options_);
   } else {
@@ -1119,5 +1105,51 @@ Result<int64_t> SyncScanner::CountRows() {
   return count;
 }
 
+Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan* plan,
+                                        std::shared_ptr<Dataset> dataset,
+                                        std::shared_ptr<ScanOptions> scan_options) {
+  if (!scan_options->use_async) {
+    return Status::NotImplemented("ScanNodes without asynchrony");
+  }
+
+  // using a generator for speculative forward compatibility with async fragment discovery
+  ARROW_ASSIGN_OR_RAISE(scan_options->filter,
+                        scan_options->filter.Bind(*dataset->schema()));
+  ARROW_ASSIGN_OR_RAISE(auto fragments_it, dataset->GetFragments(scan_options->filter));
+  ARROW_ASSIGN_OR_RAISE(auto fragments_vec, fragments_it.ToVector());
+  auto fragments_gen = MakeVectorGenerator(std::move(fragments_vec));
+
+  ARROW_ASSIGN_OR_RAISE(auto batch_gen,
+                        ScanBatchesUnorderedAsyncImpl(
+                            scan_options, std::move(fragments_gen),
+                            internal::GetCpuThreadPool(), /*filter_and_project=*/false));
+
+  auto gen = MakeMappedGenerator(
+      std::move(batch_gen),
+      [dataset](const EnumeratedRecordBatch& partial)
+          -> Result<util::optional<compute::ExecBatch>> {
+        ARROW_ASSIGN_OR_RAISE(
+            util::optional<compute::ExecBatch> batch,
+            compute::MakeExecBatch(*dataset->schema(), partial.record_batch.value));
+
+        // TODO: fragments may be able to attach more guarantees to batches than this,
+        // for example Parquet's row group stats.
+        batch->guarantee = partial.fragment.value->partition_expression();
+
+        // tag rows with fragment- and batch-of-origin
+        batch->values.emplace_back(partial.fragment.index);
+        batch->values.emplace_back(partial.record_batch.index);
+        batch->values.emplace_back(partial.record_batch.last);
+        return batch;
+      });
+
+  auto augmented_fields = dataset->schema()->fields();
+  augmented_fields.push_back(field("__fragment_index", int32()));
+  augmented_fields.push_back(field("__batch_index", int32()));
+  augmented_fields.push_back(field("__last_in_fragment", boolean()));
+  return compute::MakeSourceNode(plan, "dataset_scan",
+                                 schema(std::move(augmented_fields)), std::move(gen));
+}
+
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/scanner.h b/cpp/src/arrow/dataset/scanner.h
index 29fd5aad994..c803cde1978 100644
--- a/cpp/src/arrow/dataset/scanner.h
+++ b/cpp/src/arrow/dataset/scanner.h
@@ -26,6 +26,7 @@
 #include
 
 #include "arrow/compute/exec/expression.h"
+#include "arrow/compute/type_fwd.h"
 #include "arrow/dataset/dataset.h"
 #include "arrow/dataset/projector.h"
 #include "arrow/dataset/type_fwd.h"
@@ -194,20 +195,22 @@ using EnumeratedRecordBatchIterator = Iterator<EnumeratedRecordBatch>;
 
 template <>
 struct IterationTraits<dataset::TaggedRecordBatch> {
   static dataset::TaggedRecordBatch End() {
-    return dataset::TaggedRecordBatch{NULL, NULL};
+    return dataset::TaggedRecordBatch{NULLPTR, NULLPTR};
   }
   static bool IsEnd(const dataset::TaggedRecordBatch& val) {
-    return val.record_batch == NULL;
+    return val.record_batch == NULLPTR;
   }
 };
 
 template <>
 struct IterationTraits<dataset::EnumeratedRecordBatch> {
   static dataset::EnumeratedRecordBatch End() {
-    return dataset::EnumeratedRecordBatch{{NULL, -1, false}, {NULL, -1, false}};
+    return dataset::EnumeratedRecordBatch{
+        IterationEnd<dataset::Enumerated<std::shared_ptr<dataset::Fragment>>>(),
+        IterationEnd<dataset::Enumerated<std::shared_ptr<RecordBatch>>>()};
   }
   static bool IsEnd(const dataset::EnumeratedRecordBatch& val) {
-    return val.fragment.value == NULL;
+    return IsIterationEnd(val.fragment);
   }
 };
 
@@ -401,10 +404,16 @@ class ARROW_DS_EXPORT ScannerBuilder {
  private:
   std::shared_ptr<Dataset> dataset_;
-  std::shared_ptr<Fragment> fragment_;
-  std::shared_ptr<ScanOptions> scan_options_;
+  std::shared_ptr<ScanOptions> scan_options_ = std::make_shared<ScanOptions>();
};
 
+/// \brief Construct a source ExecNode which yields batches from a dataset scan.
+///
+/// Does not construct associated filter or project nodes.
+ARROW_DS_EXPORT Result<compute::ExecNode*> MakeScanNode(compute::ExecPlan*,
+                                                        std::shared_ptr<Dataset>,
+                                                        std::shared_ptr<ScanOptions>);
+
 /// @}
 
 /// \brief A trivial ScanTask that yields the RecordBatch of an array.
@@ -422,9 +431,5 @@ class ARROW_DS_EXPORT InMemoryScanTask : public ScanTask {
   std::vector<std::shared_ptr<RecordBatch>> record_batches_;
 };
 
-ARROW_DS_EXPORT Result<ScanTaskIterator> ScanTaskIteratorFromRecordBatch(
-    std::vector<std::shared_ptr<RecordBatch>> batches,
-    std::shared_ptr<ScanOptions> options);
-
 }  // namespace dataset
 }  // namespace arrow
diff --git a/cpp/src/arrow/dataset/scanner_internal.h b/cpp/src/arrow/dataset/scanner_internal.h
index 30fb4e07cef..27b32aa6f19 100644
--- a/cpp/src/arrow/dataset/scanner_internal.h
+++ b/cpp/src/arrow/dataset/scanner_internal.h
@@ -40,10 +40,11 @@ namespace dataset {
 
 inline Result<std::shared_ptr<RecordBatch>> FilterSingleBatch(
     const std::shared_ptr<RecordBatch>& in, const compute::Expression& filter,
-    MemoryPool* pool) {
-  compute::ExecContext exec_context{pool};
-  ARROW_ASSIGN_OR_RAISE(Datum mask,
-                        ExecuteScalarExpression(filter, Datum(in), &exec_context));
+    const std::shared_ptr<ScanOptions>& options) {
+  compute::ExecContext exec_context{options->pool};
+  ARROW_ASSIGN_OR_RAISE(
+      Datum mask,
+      ExecuteScalarExpression(filter, *options->dataset_schema, in, &exec_context));
 
   if (mask.is_scalar()) {
     const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
@@ -59,28 +60,29 @@ inline Result<std::shared_ptr<RecordBatch>> FilterSingleBatch(
   return filtered.record_batch();
 }
 
-inline RecordBatchIterator FilterRecordBatch(RecordBatchIterator it,
-                                             compute::Expression filter,
-                                             MemoryPool* pool) {
+inline RecordBatchIterator FilterRecordBatch(
+    RecordBatchIterator it, compute::Expression filter,
+    const std::shared_ptr<ScanOptions>& options) {
   return MakeMaybeMapIterator(
       [=](std::shared_ptr<RecordBatch> in) -> Result<std::shared_ptr<RecordBatch>> {
-        return FilterSingleBatch(in, filter, pool);
+        return FilterSingleBatch(in, filter, options);
      },
      std::move(it));
 }
 
 inline Result<std::shared_ptr<RecordBatch>> ProjectSingleBatch(
     const std::shared_ptr<RecordBatch>& in, const compute::Expression& projection,
-    MemoryPool* pool) {
-  compute::ExecContext exec_context{pool};
-  ARROW_ASSIGN_OR_RAISE(Datum projected,
-                        ExecuteScalarExpression(projection, Datum(in), &exec_context));
+    const std::shared_ptr<ScanOptions>& options) {
+  compute::ExecContext exec_context{options->pool};
+  ARROW_ASSIGN_OR_RAISE(
+      Datum projected,
+      ExecuteScalarExpression(projection, *options->dataset_schema, in, &exec_context));
 
   DCHECK_EQ(projected.type()->id(), Type::STRUCT);
   if (projected.shape() == ValueDescr::SCALAR) {
     // Only virtual columns are projected.
Broadcast to an array - ARROW_ASSIGN_OR_RAISE(projected, - MakeArrayFromScalar(*projected.scalar(), in->num_rows(), pool)); + ARROW_ASSIGN_OR_RAISE(projected, MakeArrayFromScalar(*projected.scalar(), + in->num_rows(), options->pool)); } ARROW_ASSIGN_OR_RAISE(auto out, @@ -89,12 +91,12 @@ inline Result> ProjectSingleBatch( return out->ReplaceSchemaMetadata(in->schema()->metadata()); } -inline RecordBatchIterator ProjectRecordBatch(RecordBatchIterator it, - compute::Expression projection, - MemoryPool* pool) { +inline RecordBatchIterator ProjectRecordBatch( + RecordBatchIterator it, compute::Expression projection, + const std::shared_ptr& options) { return MakeMaybeMapIterator( [=](std::shared_ptr in) -> Result> { - return ProjectSingleBatch(in, projection, pool); + return ProjectSingleBatch(in, projection, options); }, std::move(it)); } @@ -117,10 +119,9 @@ class FilterAndProjectScanTask : public ScanTask { SimplifyWithGuarantee(options()->projection, partition_)); RecordBatchIterator filter_it = - FilterRecordBatch(std::move(it), simplified_filter, options_->pool); + FilterRecordBatch(std::move(it), simplified_filter, options_); - return ProjectRecordBatch(std::move(filter_it), simplified_projection, - options_->pool); + return ProjectRecordBatch(std::move(filter_it), simplified_projection, options_); } Result ToFilteredAndProjectedIterator( @@ -133,10 +134,9 @@ class FilterAndProjectScanTask : public ScanTask { SimplifyWithGuarantee(options()->projection, partition_)); RecordBatchIterator filter_it = - FilterRecordBatch(std::move(it), simplified_filter, options_->pool); + FilterRecordBatch(std::move(it), simplified_filter, options_); - return ProjectRecordBatch(std::move(filter_it), simplified_projection, - options_->pool); + return ProjectRecordBatch(std::move(filter_it), simplified_projection, options_); } Result> FilterAndProjectBatch( @@ -147,8 +147,8 @@ class FilterAndProjectScanTask : public ScanTask { ARROW_ASSIGN_OR_RAISE(compute::Expression simplified_projection, SimplifyWithGuarantee(options()->projection, partition_)); ARROW_ASSIGN_OR_RAISE(auto filtered, - FilterSingleBatch(batch, simplified_filter, options_->pool)); - return ProjectSingleBatch(filtered, simplified_projection, options_->pool); + FilterSingleBatch(batch, simplified_filter, options_)); + return ProjectSingleBatch(filtered, simplified_projection, options_); } inline Future SafeExecute(internal::Executor* executor) override { diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 87fc2c902c3..bed276b1bff 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -18,6 +18,7 @@ #include "arrow/dataset/scanner.h" #include +#include #include @@ -25,6 +26,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/cast.h" +#include "arrow/compute/exec/exec_plan.h" #include "arrow/dataset/scanner_internal.h" #include "arrow/dataset/test_util.h" #include "arrow/record_batch.h" @@ -32,11 +34,14 @@ #include "arrow/testing/future_util.h" #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/testing/util.h" #include "arrow/util/range.h" +#include "arrow/util/vector.h" using testing::ElementsAre; using testing::IsEmpty; +using testing::UnorderedElementsAreArray; namespace arrow { namespace dataset { @@ -922,7 +927,7 @@ TEST_F(TestReordering, ScanBatchesUnordered) { struct BatchConsumer { explicit 
BatchConsumer(EnumeratedRecordBatchGenerator generator)
-      : generator(generator), next() {}
+      : generator(std::move(generator)), next() {}
 
   void AssertCanConsume() {
     if (!next.is_valid()) {
@@ -1087,5 +1092,301 @@ TEST(ScanOptions, TestMaterializedFields) {
   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i64", "i32"));
 }
 
+namespace {
+
+static Result<std::vector<compute::ExecBatch>> StartAndCollect(
+    compute::ExecPlan* plan, AsyncGenerator<util::optional<compute::ExecBatch>> gen) {
+  RETURN_NOT_OK(plan->Validate());
+  RETURN_NOT_OK(plan->StartProducing());
+
+  auto maybe_collected = CollectAsyncGenerator(gen).result();
+  ARROW_ASSIGN_OR_RAISE(auto collected, maybe_collected);
+
+  plan->StopProducing();
+
+  return internal::MapVector(
+      [](util::optional<compute::ExecBatch> batch) { return std::move(*batch); },
+      collected);
+}
+
+struct DatasetAndBatches {
+  std::shared_ptr<Dataset> dataset;
+  std::vector<compute::ExecBatch> batches;
+};
+
+DatasetAndBatches MakeBasicDataset() {
+  const auto dataset_schema = ::arrow::schema({
+      field("a", int32()),
+      field("b", boolean()),
+      field("c", int32()),
+  });
+
+  const auto physical_schema = SchemaFromColumnNames(dataset_schema, {"a", "b"});
+
+  RecordBatchVector record_batches{
+      RecordBatchFromJSON(physical_schema, R"([{"a": 1, "b": null},
+                                               {"a": 2, "b": true}])"),
+      RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true},
+                                               {"a": 3, "b": false}])"),
+      RecordBatchFromJSON(physical_schema, R"([{"a": null, "b": true},
+                                               {"a": 4, "b": false}])"),
+      RecordBatchFromJSON(physical_schema, R"([{"a": 5, "b": null},
+                                               {"a": 6, "b": false},
+                                               {"a": 7, "b": false}])"),
+  };
+
+  auto dataset = std::make_shared<FragmentDataset>(
+      dataset_schema,
+      FragmentVector{
+          std::make_shared<InMemoryFragment>(
+              physical_schema, RecordBatchVector{record_batches[0], record_batches[1]},
+              equal(field_ref("c"), literal(23))),
+          std::make_shared<InMemoryFragment>(
+              physical_schema, RecordBatchVector{record_batches[2], record_batches[3]},
+              equal(field_ref("c"), literal(47))),
+      });
+
+  std::vector<compute::ExecBatch> batches;
+
+  auto batch_it = record_batches.begin();
+  for (int fragment_index = 0; fragment_index < 2; ++fragment_index) {
+    for (int batch_index = 0; batch_index < 2; ++batch_index) {
+      const auto& batch = *batch_it++;
+
+      // the scanned ExecBatches will begin with physical columns
+      batches.emplace_back(*batch);
+
+      // a placeholder will be inserted for partition field "c"
+      batches.back().values.emplace_back(std::make_shared<Int32Scalar>());
+
+      // scanned batches will be augmented with fragment and batch indices
+      batches.back().values.emplace_back(fragment_index);
+      batches.back().values.emplace_back(batch_index);
+
+      // ... and with the last-in-fragment flag
+      batches.back().values.emplace_back(batch_index == 1);
+
+      // each batch carries a guarantee inherited from its Fragment's partition expression
+      batches.back().guarantee =
+          equal(field_ref("c"), literal(fragment_index == 0 ?
23 : 47)); + } + } + + return {dataset, batches}; +} +} // namespace + +TEST(ScanNode, Schema) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->use_async = true; + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + + auto fields = basic.dataset->schema()->fields(); + fields.push_back(field("__fragment_index", int32())); + fields.push_back(field("__batch_index", int32())); + fields.push_back(field("__last_in_fragment", boolean())); + AssertSchemaEqual(Schema(fields), *scan->output_schema()); +} + +TEST(ScanNode, Trivial) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->use_async = true; + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + auto sink_gen = MakeSinkNode(scan, "sink"); + + // trivial scan: the batches are returned unmodified + auto expected = basic.batches; + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); +} + +TEST(ScanNode, FilteredOnVirtualColumn) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->use_async = true; + options->filter = less(field_ref("c"), literal(30)); + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + + auto sink_gen = MakeSinkNode(scan, "sink"); + + auto expected = basic.batches; + + // only the first fragment will make it past the filter + expected.pop_back(); + expected.pop_back(); + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); +} + +TEST(ScanNode, DeferredFilterOnPhysicalColumn) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->use_async = true; + options->filter = greater(field_ref("a"), literal(4)); + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + + auto sink_gen = MakeSinkNode(scan, "sink"); + + // No post filtering is performed by ScanNode: all batches will be yielded whole. + // To filter out rows from individual batches, construct a FilterNode. 
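
Since the scan node never removes rows itself, a plan that needs the filter applied downstream must wire in an explicit filter stage. A minimal sketch of that wiring, reusing `MakeScanNode`, `MakeFilterNode`, and `MakeSinkNode` as they are used in `CompareToScanner` later in this file (error handling and plan startup elided):

```cpp
// Sketch: scan -> filter -> sink, so rows failing options->filter are
// actually dropped before collection rather than deferred by the scan.
ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options));
ASSERT_OK_AND_ASSIGN(auto filter,
                     compute::MakeFilterNode(scan, "filter", options->filter));
auto sink_gen = compute::MakeSinkNode(filter, "sink");
```
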
+ auto expected = basic.batches; + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); +} + +TEST(ScanNode, ProjectionPushdown) { + // ensure non-projected columns are dropped +} + +TEST(ScanNode, MaterializationOfVirtualColumn) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + auto options = std::make_shared(); + options->use_async = true; + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + + ASSERT_OK_AND_ASSIGN( + auto project, + compute::MakeProjectNode( + scan, "project", + {field_ref("a"), field_ref("b"), field_ref("c"), field_ref("__fragment_index"), + field_ref("__batch_index"), field_ref("__last_in_fragment")})); + + auto sink_gen = MakeSinkNode(project, "sink"); + + auto expected = basic.batches; + + for (auto& batch : expected) { + // ProjectNode overwrites "c" placeholder with non-null drawn from guarantee + const auto& value = *batch.guarantee.call()->arguments[1].literal(); + batch.values[project->output_schema()->GetFieldIndex("c")] = value; + } + + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + ResultWith(UnorderedElementsAreArray(expected))); +} + +TEST(ScanNode, CompareToScanner) { + ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); + + auto basic = MakeBasicDataset(); + + ScannerBuilder builder(basic.dataset); + ASSERT_OK(builder.UseAsync(true)); + ASSERT_OK(builder.UseThreads(true)); + ASSERT_OK(builder.Filter(greater(field_ref("c"), literal(30)))); + ASSERT_OK(builder.Project( + {field_ref("c"), call("multiply", {field_ref("a"), literal(2)})}, {"c", "a * 2"})); + ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish()); + + ASSERT_OK_AND_ASSIGN(auto fragments_it, + basic.dataset->GetFragments(scanner->options()->filter)); + ASSERT_OK_AND_ASSIGN(auto fragments, fragments_it.ToVector()); + + auto options = scanner->options(); + + ASSERT_OK_AND_ASSIGN(auto scan, MakeScanNode(plan.get(), basic.dataset, options)); + + ASSERT_OK_AND_ASSIGN(auto filter, + compute::MakeFilterNode(scan, "filter", options->filter)); + + auto exprs = options->projection.call()->arguments; + exprs.push_back(compute::field_ref("__fragment_index")); + exprs.push_back(compute::field_ref("__batch_index")); + exprs.push_back(compute::field_ref("__last_in_fragment")); + ASSERT_OK_AND_ASSIGN(auto project, compute::MakeProjectNode(filter, "project", exprs)); + + AsyncGenerator> sink_gen = + compute::MakeSinkNode(project, "sink"); + + ASSERT_OK(plan->StartProducing()); + + auto from_plan = + CollectAsyncGenerator( + MakeMappedGenerator( + sink_gen, + [&](const util::optional& batch) + -> Result { + int num_fields = options->projected_schema->num_fields(); + + ArrayVector columns(num_fields); + for (size_t i = 0; i < columns.size(); ++i) { + const Datum& value = batch->values[i]; + if (value.is_array()) { + columns[i] = value.make_array(); + continue; + } + ARROW_ASSIGN_OR_RAISE( + columns[i], + MakeArrayFromScalar(*value.scalar(), batch->length, options->pool)); + } + + EnumeratedRecordBatch out; + out.fragment.index = + batch->values[num_fields].scalar_as().value; + out.fragment.value = fragments[out.fragment.index]; + out.fragment.last = false; // ignored during reordering + + out.record_batch.index = + batch->values[num_fields + 1].scalar_as().value; + out.record_batch.value = RecordBatch::Make( + options->projected_schema, batch->length, std::move(columns)); + out.record_batch.last = + batch->values[num_fields + 2].scalar_as().value; + + return 
out; + })) + .result(); + + ASSERT_OK_AND_ASSIGN(auto from_scanner_gen, scanner->ScanBatchesUnorderedAsync()); + auto from_scanner = CollectAsyncGenerator(from_scanner_gen).result(); + + auto less = [](const EnumeratedRecordBatch& l, const EnumeratedRecordBatch& r) { + if (l.fragment.index < r.fragment.index) return true; + return l.record_batch.index < r.record_batch.index; + }; + + ASSERT_OK(from_plan); + std::sort(from_plan->begin(), from_plan->end(), less); + + ASSERT_OK(from_scanner); + std::sort(from_scanner->begin(), from_scanner->end(), less); + + ASSERT_EQ(from_plan->size(), from_scanner->size()); + for (size_t i = 0; i < from_plan->size(); ++i) { + const auto& p = from_plan->at(i); + const auto& s = from_scanner->at(i); + SCOPED_TRACE(i); + ASSERT_EQ(p.fragment.index, s.fragment.index); + ASSERT_EQ(p.fragment.value, s.fragment.value); + ASSERT_EQ(p.record_batch.last, s.record_batch.last); + ASSERT_EQ(p.record_batch.index, s.record_batch.index); + AssertBatchesEqual(*p.record_batch.value, *s.record_batch.value); + } +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 42704fea9b5..201fc7e55b2 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -122,25 +122,6 @@ void EnsureRecordBatchReaderDrained(RecordBatchReader* reader) { EXPECT_EQ(batch, nullptr); } -/// Test dataset that returns one or more fragments. -class FragmentDataset : public Dataset { - public: - FragmentDataset(std::shared_ptr schema, FragmentVector fragments) - : Dataset(std::move(schema)), fragments_(std::move(fragments)) {} - - std::string type_name() const override { return "fragment"; } - - Result> ReplaceSchema(std::shared_ptr) const override { - return Status::NotImplemented(""); - } - - protected: - Result GetFragmentsImpl(compute::Expression predicate) override { - return MakeVectorIterator(fragments_); - } - FragmentVector fragments_; -}; - class DatasetFixtureMixin : public ::testing::Test { public: /// \brief Ensure that record batches found in reader are equals to the @@ -547,8 +528,8 @@ class FileFormatScanMixin : public FileFormatFixtureMixin, // Scan the fragment through the scanner. 
RecordBatchIterator Batches(std::shared_ptr<Fragment> fragment) {
-    EXPECT_OK_AND_ASSIGN(auto schema, fragment->ReadPhysicalSchema());
-    auto dataset = std::make_shared<FragmentDataset>(schema, FragmentVector{fragment});
+    auto dataset = std::make_shared<FragmentDataset>(opts_->dataset_schema,
+                                                     FragmentVector{fragment});
     ScannerBuilder builder(dataset, opts_);
     ARROW_EXPECT_OK(builder.UseAsync(GetParam().use_async));
     ARROW_EXPECT_OK(builder.UseThreads(GetParam().use_threads));
@@ -761,12 +742,11 @@ class JSONRecordBatchFileFormat : public FileFormat {
     ARROW_ASSIGN_OR_RAISE(auto file, fragment->source().Open());
     ARROW_ASSIGN_OR_RAISE(int64_t size, file->GetSize());
     ARROW_ASSIGN_OR_RAISE(auto buffer, file->Read(size));
-
-    util::string_view view{*buffer};
-
     ARROW_ASSIGN_OR_RAISE(auto schema, Inspect(fragment->source()));
-    std::shared_ptr<RecordBatch> batch = RecordBatchFromJSON(schema, view);
-    return ScanTaskIteratorFromRecordBatch({batch}, std::move(options));
+
+    RecordBatchVector batches{RecordBatchFromJSON(schema, util::string_view{*buffer})};
+    return std::make_shared<InMemoryFragment>(std::move(schema), std::move(batches))
+        ->Scan(std::move(options));
   }
 
   Result<std::shared_ptr<FileWriter>> MakeWriter(
@@ -910,13 +890,10 @@ struct ArithmeticDatasetFixture {
   static std::shared_ptr<Schema> schema() {
     return ::arrow::schema({
         field("i64", int64()),
-        // ARROW-1644: Parquet can't write complex level
-        // field("struct", struct_({
-        //     // ARROW-2587: Parquet can't write struct with more
-        //     // than one field.
-        //     // field("i32", int32()),
-        //     field("str", utf8()),
-        // })),
+        field("struct", struct_({
+                            field("i32", int32()),
+                            field("str", utf8()),
+                        })),
         field("u8", uint8()),
         field("list", list(int32())),
         field("bool", boolean()),
@@ -933,12 +910,12 @@ struct ArithmeticDatasetFixture {
     ss << "{";
     ss << "\"i64\": " << n << ", ";
-    // ss << "\"struct\": {";
-    // {
-    //   // ss << "\"i32\": " << n_i32 << ", ";
-    //   ss << "\"str\": \"" << std::to_string(n) << "\"";
-    // }
-    // ss << "}, ";
+    ss << "\"struct\": {";
+    {
+      ss << "\"i32\": " << n_i32 << ", ";
+      ss << R"("str": ")" << std::to_string(n) << "\"";
+    }
+    ss << "}, ";
     ss << "\"u8\": " << static_cast(n) << ", ";
     ss << "\"list\": [" << n_i32 << ", " << n_i32 << "], ";
     ss << "\"bool\": " << (static_cast<bool>(n % 2) ?
"true" : "false"); @@ -1052,7 +1029,7 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin { ASSERT_OK_AND_ASSIGN(dataset_, factory->Finish()); scan_options_ = std::make_shared(); - scan_options_->dataset_schema = source_schema_; + scan_options_->dataset_schema = dataset_->schema(); ASSERT_OK(SetProjection(scan_options_.get(), source_schema_->field_names())); } diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc index 8c2ac376d1e..8d1c16e0ed6 100644 --- a/cpp/src/arrow/pretty_print.cc +++ b/cpp/src/arrow/pretty_print.cc @@ -69,10 +69,12 @@ class PrettyPrinter { }; void PrettyPrinter::OpenArray(const Array& array) { - Indent(); + if (!options_.skip_new_lines) { + Indent(); + } (*sink_) << "["; if (array.length() > 0) { - (*sink_) << "\n"; + Newline(); indent_ += options_.indent_size; } } @@ -103,7 +105,6 @@ void PrettyPrinter::Newline() { return; } (*sink_) << "\n"; - Indent(); } void PrettyPrinter::Indent() { @@ -124,11 +125,15 @@ class ArrayPrinter : public PrettyPrinter { if (skip_comma) { skip_comma = false; } else { - (*sink_) << ",\n"; + (*sink_) << ","; + Newline(); + } + if (!options_.skip_new_lines) { + Indent(); } - Indent(); if ((i >= options_.window) && (i < (array.length() - options_.window))) { - (*sink_) << "...\n"; + (*sink_) << "..."; + Newline(); i = array.length() - options_.window - 1; skip_comma = true; } else if (array.IsNull(i)) { @@ -137,7 +142,7 @@ class ArrayPrinter : public PrettyPrinter { func(i); } } - (*sink_) << "\n"; + Newline(); } Status WriteDataValues(const BooleanArray& array) { @@ -239,11 +244,13 @@ class ArrayPrinter : public PrettyPrinter { if (skip_comma) { skip_comma = false; } else { - (*sink_) << ",\n"; + (*sink_) << ","; + Newline(); } if ((i >= options_.window) && (i < (array.length() - options_.window))) { Indent(); - (*sink_) << "...\n"; + (*sink_) << "..."; + Newline(); i = array.length() - options_.window - 1; skip_comma = true; } else if (array.IsNull(i)) { @@ -252,10 +259,11 @@ class ArrayPrinter : public PrettyPrinter { } else { std::shared_ptr slice = array.values()->Slice(array.value_offset(i), array.value_length(i)); - RETURN_NOT_OK(PrettyPrint(*slice, {indent_, options_.window}, sink_)); + RETURN_NOT_OK( + PrettyPrint(*slice, PrettyPrintOptions{indent_, options_.window}, sink_)); } } - (*sink_) << "\n"; + Newline(); return Status::OK(); } @@ -265,28 +273,36 @@ class ArrayPrinter : public PrettyPrinter { if (skip_comma) { skip_comma = false; } else { - (*sink_) << ",\n"; + (*sink_) << ","; + Newline(); } - if ((i >= options_.window) && (i < (array.length() - options_.window))) { + + if (!options_.skip_new_lines) { Indent(); - (*sink_) << "...\n"; + } + + if ((i >= options_.window) && (i < (array.length() - options_.window))) { + (*sink_) << "..."; + Newline(); i = array.length() - options_.window - 1; skip_comma = true; } else if (array.IsNull(i)) { - Indent(); (*sink_) << options_.null_rep; } else { - Indent(); - (*sink_) << "keys:\n"; + (*sink_) << "keys:"; + Newline(); auto keys_slice = array.keys()->Slice(array.value_offset(i), array.value_length(i)); - RETURN_NOT_OK(PrettyPrint(*keys_slice, {indent_, options_.window}, sink_)); - (*sink_) << "\n"; + RETURN_NOT_OK(PrettyPrint(*keys_slice, + PrettyPrintOptions{indent_, options_.window}, sink_)); + Newline(); Indent(); - (*sink_) << "values:\n"; + (*sink_) << "values:"; + Newline(); auto values_slice = array.items()->Slice(array.value_offset(i), array.value_length(i)); - RETURN_NOT_OK(PrettyPrint(*values_slice, {indent_, options_.window}, 
sink_)); + RETURN_NOT_OK(PrettyPrint(*values_slice, + PrettyPrintOptions{indent_, options_.window}, sink_)); } } (*sink_) << "\n"; @@ -325,6 +341,7 @@ class ArrayPrinter : public PrettyPrinter { int64_t length) { for (size_t i = 0; i < fields.size(); ++i) { Newline(); + Indent(); std::stringstream ss; ss << "-- child " << i << " type: " << fields[i]->type()->ToString() << "\n"; Write(ss.str()); @@ -352,12 +369,14 @@ class ArrayPrinter : public PrettyPrinter { RETURN_NOT_OK(WriteValidityBitmap(array)); Newline(); + Indent(); Write("-- type_ids: "); UInt8Array type_codes(array.length(), array.type_codes(), nullptr, 0, array.offset()); RETURN_NOT_OK(PrettyPrint(type_codes, indent_ + options_.indent_size, sink_)); if (array.mode() == UnionMode::DENSE) { Newline(); + Indent(); Write("-- value_offsets: "); Int32Array value_offsets( array.length(), checked_cast(array).value_offsets(), @@ -376,11 +395,13 @@ class ArrayPrinter : public PrettyPrinter { Status Visit(const DictionaryArray& array) { Newline(); + Indent(); Write("-- dictionary:\n"); RETURN_NOT_OK( PrettyPrint(*array.dictionary(), indent_ + options_.indent_size, sink_)); Newline(); + Indent(); Write("-- indices:\n"); return PrettyPrint(*array.indices(), indent_ + options_.indent_size, sink_); } @@ -431,6 +452,7 @@ Status ArrayPrinter::WriteValidityBitmap(const Array& array) { if (array.null_count() > 0) { Newline(); + Indent(); BooleanArray is_valid(array.length(), array.null_bitmap(), nullptr, 0, array.offset()); return PrettyPrint(is_valid, indent_ + options_.indent_size, sink_); @@ -470,19 +492,28 @@ Status PrettyPrint(const ChunkedArray& chunked_arr, const PrettyPrintOptions& op for (int i = 0; i < indent; ++i) { (*sink) << " "; } - (*sink) << "[\n"; + (*sink) << "["; + if (!options.skip_new_lines) { + *sink << "\n"; + } bool skip_comma = true; for (int i = 0; i < num_chunks; ++i) { if (skip_comma) { skip_comma = false; } else { - (*sink) << ",\n"; + (*sink) << ","; + if (!options.skip_new_lines) { + *sink << "\n"; + } } if ((i >= window) && (i < (num_chunks - window))) { for (int i = 0; i < indent; ++i) { (*sink) << " "; } - (*sink) << "...\n"; + (*sink) << "..."; + if (!options.skip_new_lines) { + *sink << "\n"; + } i = num_chunks - window - 1; skip_comma = true; } else { @@ -492,7 +523,9 @@ Status PrettyPrint(const ChunkedArray& chunked_arr, const PrettyPrintOptions& op RETURN_NOT_OK(printer.Print(*chunked_arr.chunk(i))); } } - (*sink) << "\n"; + if (!options.skip_new_lines) { + *sink << "\n"; + } for (int i = 0; i < indent; ++i) { (*sink) << " "; @@ -572,6 +605,7 @@ class SchemaPrinter : public PrettyPrinter { void PrintVerboseMetadata(const KeyValueMetadata& metadata) { for (int64_t i = 0; i < metadata.size(); ++i) { Newline(); + Indent(); Write(metadata.key(i) + ": '" + metadata.value(i) + "'"); } } @@ -579,6 +613,7 @@ class SchemaPrinter : public PrettyPrinter { void PrintTruncatedMetadata(const KeyValueMetadata& metadata) { for (int64_t i = 0; i < metadata.size(); ++i) { Newline(); + Indent(); size_t size = metadata.value(i).size(); size_t truncated_size = std::max(10, 70 - metadata.key(i).size() - indent_); if (size <= truncated_size) { @@ -594,6 +629,7 @@ class SchemaPrinter : public PrettyPrinter { void PrintMetadata(const std::string& metadata_type, const KeyValueMetadata& metadata) { if (metadata.size() > 0) { Newline(); + Indent(); Write(metadata_type); if (options_.truncate_metadata) { PrintTruncatedMetadata(metadata); @@ -607,6 +643,7 @@ class SchemaPrinter : public PrettyPrinter { for (int i = 0; i < 
schema_.num_fields(); ++i) { if (i > 0) { Newline(); + Indent(); } else { Indent(); } @@ -631,6 +668,7 @@ Status SchemaPrinter::PrintType(const DataType& type, bool nullable) { } for (int i = 0; i < type.num_fields(); ++i) { Newline(); + Indent(); std::stringstream ss; ss << "child " << i << ", "; diff --git a/cpp/src/arrow/pretty_print.h b/cpp/src/arrow/pretty_print.h index 9d2c72c7186..1bc086a6889 100644 --- a/cpp/src/arrow/pretty_print.h +++ b/cpp/src/arrow/pretty_print.h @@ -19,6 +19,7 @@ #include #include +#include #include "arrow/util/visibility.h" @@ -34,13 +35,14 @@ class Table; struct PrettyPrintOptions { PrettyPrintOptions() = default; - PrettyPrintOptions(int indent_arg, int window_arg = 10, int indent_size_arg = 2, + PrettyPrintOptions(int indent_arg, // NOLINT runtime/explicit + int window_arg = 10, int indent_size_arg = 2, std::string null_rep_arg = "null", bool skip_new_lines_arg = false, bool truncate_metadata_arg = true) : indent(indent_arg), indent_size(indent_size_arg), window(window_arg), - null_rep(null_rep_arg), + null_rep(std::move(null_rep_arg)), skip_new_lines(skip_new_lines_arg), truncate_metadata(truncate_metadata_arg) {} diff --git a/cpp/src/arrow/result.h b/cpp/src/arrow/result.h index 0172a852434..cb7437cd242 100644 --- a/cpp/src/arrow/result.h +++ b/cpp/src/arrow/result.h @@ -478,6 +478,11 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { /// /// WARNING: ARROW_ASSIGN_OR_RAISE `std::move`s its right operand. If you have /// an lvalue Result which you *don't* want to move out of cast appropriately. +/// +/// WARNING: ARROW_ASSIGN_OR_RAISE is not a single expression; it will not +/// maintain lifetimes of all temporaries in `rexpr` (e.g. +/// `ARROW_ASSIGN_OR_RAISE(auto x, MakeTemp().GetResultRef());` +/// will most likely segfault)! 
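
To make the lifetime caveat concrete, here is a sketch of the failing and the safe pattern. `MakeTemp()` and `GetResultRef()` are the hypothetical names from the warning's own example, not real Arrow APIs:

```cpp
// (inside a function returning Status or Result)

// BAD: the macro expands to more than one statement, so the temporary
// returned by the hypothetical MakeTemp() is destroyed before the
// assignment runs, leaving the Result reference from GetResultRef() dangling.
ARROW_ASSIGN_OR_RAISE(auto x, MakeTemp().GetResultRef());

// OK: name the temporary so it outlives the whole macro expansion.
auto temp = MakeTemp();
ARROW_ASSIGN_OR_RAISE(auto y, temp.GetResultRef());
```
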
#define ARROW_ASSIGN_OR_RAISE(lhs, rexpr) \ ARROW_ASSIGN_OR_RAISE_IMPL(ARROW_ASSIGN_OR_RAISE_NAME(_error_or_value, __COUNTER__), \ lhs, rexpr); @@ -485,7 +490,7 @@ class ARROW_MUST_USE_TYPE Result : public util::EqualityComparable> { namespace internal { template -inline Status GenericToStatus(const Result& res) { +inline const Status& GenericToStatus(const Result& res) { return res.status(); } @@ -496,9 +501,9 @@ inline Status GenericToStatus(Result&& res) { } // namespace internal -template -Result ToResult(T t) { - return Result(std::move(t)); +template ::type> +R ToResult(T t) { + return R(std::move(t)); } template diff --git a/cpp/src/arrow/result_test.cc b/cpp/src/arrow/result_test.cc index b71af9d8531..cb645bc7402 100644 --- a/cpp/src/arrow/result_test.cc +++ b/cpp/src/arrow/result_test.cc @@ -26,6 +26,8 @@ #include #include "arrow/testing/gtest_compat.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" namespace arrow { @@ -724,5 +726,74 @@ TEST(ResultTest, ViewAsStatus) { EXPECT_EQ(ViewAsStatus(&err), &err.status()); } +TEST(ResultTest, MatcherExamples) { + EXPECT_THAT(Result(Status::Invalid("arbitrary error")), + Raises(StatusCode::Invalid)); + + EXPECT_THAT(Result(Status::Invalid("arbitrary error")), + Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"))); + + // message doesn't match, so no match + EXPECT_THAT( + Result(Status::Invalid("arbitrary error")), + testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable")))); + + // different error code, so no match + EXPECT_THAT(Result(Status::TypeError("arbitrary error")), + testing::Not(Raises(StatusCode::Invalid))); + + // not an error, so no match + EXPECT_THAT(Result(333), testing::Not(Raises(StatusCode::Invalid))); + + EXPECT_THAT(Result("hello world"), + ResultWith(testing::HasSubstr("hello"))); + + EXPECT_THAT(Result(Status::Invalid("XXX")), + testing::Not(ResultWith(testing::HasSubstr("hello")))); + + // holds a value, but that value doesn't match the given pattern + EXPECT_THAT(Result("foo bar"), + testing::Not(ResultWith(testing::HasSubstr("hello")))); +} + +TEST(ResultTest, MatcherDescriptions) { + testing::Matcher> matcher = ResultWith(testing::HasSubstr("hello")); + + { + std::stringstream ss; + matcher.DescribeTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("value has substring \"hello\"")); + } + + { + std::stringstream ss; + matcher.DescribeNegationTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("value has no substring \"hello\"")); + } +} + +TEST(ResultTest, MatcherExplanations) { + testing::Matcher> matcher = ResultWith(testing::HasSubstr("hello")); + + { + testing::StringMatchResultListener listener; + EXPECT_TRUE(matcher.MatchAndExplain(Result("hello world"), &listener)); + EXPECT_THAT(listener.str(), testing::StrEq("whose value \"hello world\" matches")); + } + + { + testing::StringMatchResultListener listener; + EXPECT_FALSE(matcher.MatchAndExplain(Result("foo bar"), &listener)); + EXPECT_THAT(listener.str(), testing::StrEq("whose value \"foo bar\" doesn't match")); + } + + { + testing::StringMatchResultListener listener; + EXPECT_FALSE(matcher.MatchAndExplain(Status::TypeError("XXX"), &listener)); + EXPECT_THAT(listener.str(), + testing::StrEq("whose error \"Type error: XXX\" doesn't match")); + } +} + } // namespace } // namespace arrow diff --git a/cpp/src/arrow/status.h b/cpp/src/arrow/status.h index 43879e6c6a3..056d60d6f32 100644 --- a/cpp/src/arrow/status.h +++ b/cpp/src/arrow/status.h @@ -312,7 +312,10 @@ class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : 
public util::EqualityComparable< StatusCode code() const { return ok() ? StatusCode::OK : state_->code; } /// \brief Return the specific error message attached to this status. - std::string message() const { return ok() ? "" : state_->msg; } + const std::string& message() const { + static const std::string no_message = ""; + return ok() ? no_message : state_->msg; + } /// \brief Return the status detail attached to this message. const std::shared_ptr& detail() const { @@ -440,7 +443,7 @@ namespace internal { // Extract Status from Status or Result // Useful for the status check macros such as RETURN_NOT_OK. -inline Status GenericToStatus(const Status& st) { return st; } +inline const Status& GenericToStatus(const Status& st) { return st; } inline Status GenericToStatus(Status&& st) { return std::move(st); } } // namespace internal diff --git a/cpp/src/arrow/status_test.cc b/cpp/src/arrow/status_test.cc index fc5a7ec45cf..10a79d9b990 100644 --- a/cpp/src/arrow/status_test.cc +++ b/cpp/src/arrow/status_test.cc @@ -17,9 +17,12 @@ #include +#include #include #include "arrow/status.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" namespace arrow { @@ -114,6 +117,85 @@ TEST(StatusTest, TestEquality) { ASSERT_NE(Status::Invalid("error"), Status::Invalid("other error")); } +TEST(StatusTest, MatcherExamples) { + EXPECT_THAT(Status::Invalid("arbitrary error"), Raises(StatusCode::Invalid)); + + EXPECT_THAT(Status::Invalid("arbitrary error"), + Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"))); + + // message doesn't match, so no match + EXPECT_THAT( + Status::Invalid("arbitrary error"), + testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable")))); + + // different error code, so no match + EXPECT_THAT(Status::TypeError("arbitrary error"), + testing::Not(Raises(StatusCode::Invalid))); + + // not an error, so no match + EXPECT_THAT(Status::OK(), testing::Not(Raises(StatusCode::Invalid))); +} + +TEST(StatusTest, MatcherDescriptions) { + testing::Matcher matcher = Raises(StatusCode::Invalid); + + { + std::stringstream ss; + matcher.DescribeTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("raises StatusCode::Invalid")); + } + + { + std::stringstream ss; + matcher.DescribeNegationTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("does not raise StatusCode::Invalid")); + } +} + +TEST(StatusTest, MessageMatcherDescriptions) { + testing::Matcher matcher = + Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary")); + + { + std::stringstream ss; + matcher.DescribeTo(&ss); + EXPECT_THAT( + ss.str(), + testing::StrEq( + "raises StatusCode::Invalid and message has substring \"arbitrary\"")); + } + + { + std::stringstream ss; + matcher.DescribeNegationTo(&ss); + EXPECT_THAT(ss.str(), testing::StrEq("does not raise StatusCode::Invalid or message " + "has no substring \"arbitrary\"")); + } +} + +TEST(StatusTest, MatcherExplanations) { + testing::Matcher matcher = Raises(StatusCode::Invalid); + + { + testing::StringMatchResultListener listener; + EXPECT_TRUE(matcher.MatchAndExplain(Status::Invalid("XXX"), &listener)); + EXPECT_THAT(listener.str(), testing::StrEq("whose value \"Invalid: XXX\" matches")); + } + + { + testing::StringMatchResultListener listener; + EXPECT_FALSE(matcher.MatchAndExplain(Status::OK(), &listener)); + EXPECT_THAT(listener.str(), testing::StrEq("whose value \"OK\" doesn't match")); + } + + { + testing::StringMatchResultListener listener; + EXPECT_FALSE(matcher.MatchAndExplain(Status::TypeError("XXX"), &listener)); + 
EXPECT_THAT(listener.str(), + testing::StrEq("whose value \"Type error: XXX\" doesn't match")); + } +} + TEST(StatusTest, TestDetailEquality) { const auto status_with_detail = arrow::Status(StatusCode::IOError, "", std::make_shared()); diff --git a/cpp/src/arrow/testing/matchers.h b/cpp/src/arrow/testing/matchers.h new file mode 100644 index 00000000000..246f321e8fa --- /dev/null +++ b/cpp/src/arrow/testing/matchers.h @@ -0,0 +1,177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/result.h" +#include "arrow/status.h" + +namespace arrow { + +template +class ResultMatcher { + public: + explicit ResultMatcher(ValueMatcher value_matcher) + : value_matcher_(std::move(value_matcher)) {} + + template ::type::ValueType> + operator testing::Matcher() const { // NOLINT runtime/explicit + struct Impl : testing::MatcherInterface { + explicit Impl(const ValueMatcher& value_matcher) + : value_matcher_(testing::MatcherCast(value_matcher)) {} + + void DescribeTo(::std::ostream* os) const override { + *os << "value "; + value_matcher_.DescribeTo(os); + } + + void DescribeNegationTo(::std::ostream* os) const override { + *os << "value "; + value_matcher_.DescribeNegationTo(os); + } + + bool MatchAndExplain(const Res& maybe_value, + testing::MatchResultListener* listener) const override { + if (!maybe_value.status().ok()) { + *listener << "whose error " + << testing::PrintToString(maybe_value.status().ToString()) + << " doesn't match"; + return false; + } + const ValueType& value = GetValue(maybe_value); + testing::StringMatchResultListener value_listener; + const bool match = value_matcher_.MatchAndExplain(value, &value_listener); + *listener << "whose value " << testing::PrintToString(value) + << (match ? 
" matches" : " doesn't match"); + testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); + return match; + } + + const testing::Matcher value_matcher_; + }; + + return testing::Matcher(new Impl(value_matcher_)); + } + + private: + template + static const T& GetValue(const Result& maybe_value) { + return maybe_value.ValueOrDie(); + } + + template + static const T& GetValue(const Future& value_fut) { + return GetValue(value_fut.result()); + } + + const ValueMatcher value_matcher_; +}; + +class StatusMatcher { + public: + explicit StatusMatcher(StatusCode code, + util::optional> message_matcher) + : code_(code), message_matcher_(std::move(message_matcher)) {} + + template + operator testing::Matcher() const { // NOLINT runtime/explicit + struct Impl : testing::MatcherInterface { + explicit Impl(StatusCode code, + util::optional> message_matcher) + : code_(code), message_matcher_(std::move(message_matcher)) {} + + void DescribeTo(::std::ostream* os) const override { + *os << "raises StatusCode::" << Status::CodeAsString(code_); + if (message_matcher_) { + *os << " and message "; + message_matcher_->DescribeTo(os); + } + } + + void DescribeNegationTo(::std::ostream* os) const override { + *os << "does not raise StatusCode::" << Status::CodeAsString(code_); + if (message_matcher_) { + *os << " or message "; + message_matcher_->DescribeNegationTo(os); + } + } + + bool MatchAndExplain(const Res& maybe_value, + testing::MatchResultListener* listener) const override { + const Status& status = GetStatus(maybe_value); + testing::StringMatchResultListener value_listener; + + bool match = status.code() == code_; + if (message_matcher_) { + match = match && + message_matcher_->MatchAndExplain(status.message(), &value_listener); + } + + *listener << "whose value " << testing::PrintToString(status.ToString()) + << (match ? " matches" : " doesn't match"); + testing::internal::PrintIfNotEmpty(value_listener.str(), listener->stream()); + return match; + } + + const StatusCode code_; + const util::optional> message_matcher_; + }; + + return testing::Matcher(new Impl(code_, message_matcher_)); + } + + private: + static const Status& GetStatus(const Status& status) { return status; } + + template + static const Status& GetStatus(const Result& maybe_value) { + return maybe_value.status(); + } + + template + static const Status& GetStatus(const Future& value_fut) { + return value_fut.status(); + } + + const StatusCode code_; + const util::optional> message_matcher_; +}; + +// Returns a matcher that matches the value of a successful Result or Future. +// (Future will be waited upon to acquire its result for matching.) +template +ResultMatcher ResultWith(const ValueMatcher& value_matcher) { + return ResultMatcher(value_matcher); +} + +// Returns a matcher that matches the StatusCode of a Status, Result, or Future. +// (Future will be waited upon to acquire its result for matching.) +inline StatusMatcher Raises(StatusCode code) { + return StatusMatcher(code, util::nullopt); +} + +// Returns a matcher that matches the StatusCode and message of a Status, Result, or +// Future. (Future will be waited upon to acquire its result for matching.) 
+template +StatusMatcher Raises(StatusCode code, const MessageMatcher& message_matcher) { + return StatusMatcher(code, testing::MatcherCast(message_matcher)); +} + +} // namespace arrow diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 65c783ce847..41914f43663 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1195,6 +1195,10 @@ std::string FieldRef::ToString() const { } std::vector FieldRef::FindAll(const Schema& schema) const { + if (auto name = this->name()) { + return internal::MapVector([](int i) { return FieldPath{i}; }, + schema.GetAllFieldIndices(*name)); + } return FindAll(schema.fields()); } diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index 084720f9908..1ac10ad7ce8 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -259,43 +259,17 @@ class MappingGenerator { /// Note: Errors returned from the `map` function will be propagated /// /// If the source generator is async-reentrant then this generator will be also -template -AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, - std::function(const T&)> map) { - std::function(const T&)> future_map = [map](const T& val) -> Future { - return Future::MakeFinished(map(val)); - }; - return MappingGenerator(std::move(source_generator), std::move(future_map)); -} -template -AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, - std::function map) { - std::function(const T&)> maybe_future_map = [map](const T& val) -> Future { - return Future::MakeFinished(map(val)); - }; - return MappingGenerator(std::move(source_generator), std::move(maybe_future_map)); -} -template -AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, - std::function(const T&)> map) { - return MappingGenerator(std::move(source_generator), std::move(map)); -} - -template -AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, MapFunc map) { +template , + typename V = typename EnsureFuture::type::ValueType> +AsyncGenerator MakeMappedGenerator(AsyncGenerator source_generator, MapFn map) { struct MapCallback { - MapFunc map; + MapFn map_; - Future operator()(const T& val) { return EnsureFuture(map(val)); } - - Future EnsureFuture(Result val) { - return Future::MakeFinished(std::move(val)); - } - Future EnsureFuture(V val) { return Future::MakeFinished(std::move(val)); } - Future EnsureFuture(Future val) { return val; } + Future operator()(const T& val) { return ToFuture(map_(val)); } }; - std::function(const T&)> map_fn = MapCallback{map}; - return MappingGenerator(std::move(source_generator), map_fn); + + return MappingGenerator(std::move(source_generator), MapCallback{std::move(map)}); } /// \see MakeSequencingGenerator diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc index 14b528ade5e..29c8d73ab6c 100644 --- a/cpp/src/arrow/util/async_generator_test.cc +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" @@ -51,7 +52,7 @@ AsyncGenerator FailsAt(AsyncGenerator src, int failing_index) { template AsyncGenerator SlowdownABit(AsyncGenerator source) { - return MakeMappedGenerator(std::move(source), [](const T& res) -> Future { + return MakeMappedGenerator(std::move(source), [](const T& res) { return SleepABitAsync().Then([res]() { return res; }); }); } @@ -88,8 +89,7 @@ std::function()> BackgroundAsyncVectorIt( auto slow_iterator 
= PossiblySlowVectorIt(v, sleep); EXPECT_OK_AND_ASSIGN( auto background, - MakeBackgroundGenerator(std::move(slow_iterator), - internal::GetCpuThreadPool(), max_q, q_restart)); + MakeBackgroundGenerator(std::move(slow_iterator), pool, max_q, q_restart)); return MakeTransferredGenerator(background, pool); } @@ -106,8 +106,7 @@ std::function()> NewBackgroundAsyncVectorIt(std::vector }); EXPECT_OK_AND_ASSIGN(auto background, - MakeBackgroundGenerator(std::move(slow_iterator), - internal::GetCpuThreadPool())); + MakeBackgroundGenerator(std::move(slow_iterator), pool)); return MakeTransferredGenerator(background, pool); } @@ -176,7 +175,8 @@ class ReentrantChecker { template class ReentrantCheckerGuard { public: - explicit ReentrantCheckerGuard(ReentrantChecker checker) : checker_(checker) {} + explicit ReentrantCheckerGuard(ReentrantChecker checker) + : checker_(std::move(checker)) {} ARROW_DISALLOW_COPY_AND_ASSIGN(ReentrantCheckerGuard); ReentrantCheckerGuard(ReentrantCheckerGuard&& other) : checker_(other.checker_) { diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index d08c598a32b..c7c5ba802f9 100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -36,6 +36,9 @@ namespace arrow { +template +struct EnsureFuture; + namespace detail { template @@ -976,4 +979,28 @@ Future Loop(Iterate iterate) { return break_fut; } +inline Future<> ToFuture(Status status) { + return Future<>::MakeFinished(std::move(status)); +} + +template +Future ToFuture(T value) { + return Future::MakeFinished(std::move(value)); +} + +template +Future ToFuture(Result maybe_value) { + return Future::MakeFinished(std::move(maybe_value)); +} + +template +Future ToFuture(Future fut) { + return std::move(fut); +} + +template +struct EnsureFuture { + using type = decltype(ToFuture(std::declval())); +}; + } // namespace arrow diff --git a/cpp/src/arrow/util/future_test.cc b/cpp/src/arrow/util/future_test.cc index 33796a05bb1..b25d77c48cd 100644 --- a/cpp/src/arrow/util/future_test.cc +++ b/cpp/src/arrow/util/future_test.cc @@ -36,6 +36,7 @@ #include "arrow/testing/executor_util.h" #include "arrow/testing/future_util.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" #include "arrow/util/logging.h" #include "arrow/util/thread_pool.h" @@ -1704,5 +1705,45 @@ TEST(FnOnceTest, MoveOnlyDataType) { ASSERT_EQ(i0.moves, 0); ASSERT_EQ(i1.moves, 0); } + +TEST(FutureTest, MatcherExamples) { + EXPECT_THAT(Future::MakeFinished(Status::Invalid("arbitrary error")), + Raises(StatusCode::Invalid)); + + EXPECT_THAT(Future::MakeFinished(Status::Invalid("arbitrary error")), + Raises(StatusCode::Invalid, testing::HasSubstr("arbitrary"))); + + // message doesn't match, so no match + EXPECT_THAT( + Future::MakeFinished(Status::Invalid("arbitrary error")), + testing::Not(Raises(StatusCode::Invalid, testing::HasSubstr("reasonable")))); + + // different error code, so no match + EXPECT_THAT(Future::MakeFinished(Status::TypeError("arbitrary error")), + testing::Not(Raises(StatusCode::Invalid))); + + // not an error, so no match + EXPECT_THAT(Future::MakeFinished(333), testing::Not(Raises(StatusCode::Invalid))); + + EXPECT_THAT(Future::MakeFinished("hello world"), + ResultWith(testing::HasSubstr("hello"))); + + // Matcher waits on Futures + auto string_fut = Future::Make(); + auto finisher = std::thread([&] { + SleepABit(); + string_fut.MarkFinished("hello world"); + }); + EXPECT_THAT(string_fut, ResultWith(testing::HasSubstr("hello"))); + finisher.join(); + + 
EXPECT_THAT(Future::MakeFinished(Status::Invalid("XXX")), + testing::Not(ResultWith(testing::HasSubstr("hello")))); + + // holds a value, but that value doesn't match the given pattern + EXPECT_THAT(Future::MakeFinished("foo bar"), + testing::Not(ResultWith(testing::HasSubstr("hello")))); +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/thread_pool.cc b/cpp/src/arrow/util/thread_pool.cc index 672839b67d5..758295d01ed 100644 --- a/cpp/src/arrow/util/thread_pool.cc +++ b/cpp/src/arrow/util/thread_pool.cc @@ -321,13 +321,20 @@ void ThreadPool::CollectFinishedWorkersUnlocked() { state_->finished_workers_.clear(); } +thread_local ThreadPool* current_thread_pool_ = nullptr; + +bool ThreadPool::OwnsThisThread() { return current_thread_pool_ == this; } + void ThreadPool::LaunchWorkersUnlocked(int threads) { std::shared_ptr state = sp_state_; for (int i = 0; i < threads; i++) { state_->workers_.emplace_back(); auto it = --(state_->workers_.end()); - *it = std::thread([state, it] { WorkerLoop(state, it); }); + *it = std::thread([this, state, it] { + current_thread_pool_ = this; + WorkerLoop(state, it); + }); } } diff --git a/cpp/src/arrow/util/thread_pool.h b/cpp/src/arrow/util/thread_pool.h index d012aa02010..febbc997852 100644 --- a/cpp/src/arrow/util/thread_pool.h +++ b/cpp/src/arrow/util/thread_pool.h @@ -179,6 +179,10 @@ class ARROW_EXPORT Executor { // concurrently). This may be an approximate number. virtual int GetCapacity() = 0; + // Return true if the thread from which this function is called is owned by this + // Executor. Returns false if this Executor does not support this property. + virtual bool OwnsThisThread() { return false; } + protected: ARROW_DISALLOW_COPY_AND_ASSIGN(Executor); @@ -298,6 +302,8 @@ class ARROW_EXPORT ThreadPool : public Executor { // match this value. int GetCapacity() override; + bool OwnsThisThread() override; + // Return the number of tasks either running or in the queue. int GetNumTasks(); diff --git a/cpp/src/arrow/util/thread_pool_test.cc b/cpp/src/arrow/util/thread_pool_test.cc index 2cfb4c62613..399c755a8f9 100644 --- a/cpp/src/arrow/util/thread_pool_test.cc +++ b/cpp/src/arrow/util/thread_pool_test.cc @@ -395,6 +395,23 @@ TEST_F(TestThreadPool, StressSpawn) { SpawnAdds(pool.get(), 1000, task_add); } +TEST_F(TestThreadPool, OwnsCurrentThread) { + auto pool = this->MakeThreadPool(30); + std::atomic one_failed{false}; + + for (int i = 0; i < 1000; ++i) { + ASSERT_OK(pool->Spawn([&] { + if (pool->OwnsThisThread()) return; + + one_failed = true; + })); + } + + ASSERT_OK(pool->Shutdown()); + ASSERT_FALSE(pool->OwnsThisThread()); + ASSERT_FALSE(one_failed); +} + TEST_F(TestThreadPool, StressSpawnThreaded) { auto pool = this->MakeThreadPool(30); SpawnAddsThreaded(pool.get(), 20, 100, task_add); diff --git a/cpp/src/arrow/util/vector.h b/cpp/src/arrow/util/vector.h index 3ef0074aa9d..041bdb424a7 100644 --- a/cpp/src/arrow/util/vector.h +++ b/cpp/src/arrow/util/vector.h @@ -84,27 +84,49 @@ std::vector FilterVector(std::vector values, Predicate&& predicate) { return values; } -/// \brief Like MapVector, but where the function can fail. 
-template , - typename To = typename internal::call_traits::return_type::ValueType> -Result> MaybeMapVector(Fn&& map, const std::vector& src) { +template ()(std::declval()))> +std::vector MapVector(Fn&& map, const std::vector& source) { std::vector out; - out.reserve(src.size()); - ARROW_RETURN_NOT_OK(MaybeTransform(src.begin(), src.end(), std::back_inserter(out), - std::forward(map))); - return std::move(out); + out.reserve(source.size()); + std::transform(source.begin(), source.end(), std::back_inserter(out), + std::forward(map)); + return out; } template ()(std::declval()))> -std::vector MapVector(Fn&& map, const std::vector& source) { +std::vector MapVector(Fn&& map, std::vector&& source) { std::vector out; out.reserve(source.size()); - std::transform(source.begin(), source.end(), std::back_inserter(out), + std::transform(std::make_move_iterator(source.begin()), + std::make_move_iterator(source.end()), std::back_inserter(out), std::forward(map)); return out; } +/// \brief Like MapVector, but where the function can fail. +template , + typename To = typename internal::call_traits::return_type::ValueType> +Result> MaybeMapVector(Fn&& map, const std::vector& source) { + std::vector out; + out.reserve(source.size()); + ARROW_RETURN_NOT_OK(MaybeTransform(source.begin(), source.end(), + std::back_inserter(out), std::forward(map))); + return std::move(out); +} + +template , + typename To = typename internal::call_traits::return_type::ValueType> +Result> MaybeMapVector(Fn&& map, std::vector&& source) { + std::vector out; + out.reserve(source.size()); + ARROW_RETURN_NOT_OK(MaybeTransform(std::make_move_iterator(source.begin()), + std::make_move_iterator(source.end()), + std::back_inserter(out), std::forward(map))); + return std::move(out); +} + template std::vector FlattenVectors(const std::vector>& vecs) { std::size_t sum = 0; diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index bd93da9cb18..e7e8341c9d4 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -3001,7 +3001,7 @@ def _get_partition_keys(Expression partition_expression): pair[CFieldRef, CDatum] ref_val out = {} - for ref_val in GetResultValue(CExtractKnownFieldValues(expr)): + for ref_val in GetResultValue(CExtractKnownFieldValues(expr)).map: assert ref_val.first.name() != nullptr assert ref_val.second.kind() == DatumType_SCALAR val = pyarrow_wrap_scalar(ref_val.second.scalar()) diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index 8cab5536647..f9349f3a642 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -32,6 +32,26 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: pass +cdef extern from * namespace "arrow::compute": + # inlined from expression_internal.h to avoid + # proliferation of #include + """ + #include + + #include "arrow/type.h" + #include "arrow/datum.h" + + namespace arrow { + namespace compute { + struct KnownFieldValues { + std::unordered_map map; + }; + } // namespace compute + } // namespace arrow + """ + cdef struct CKnownFieldValues "arrow::compute::KnownFieldValues": + unordered_map[CFieldRef, CDatum, CFieldRefHash] map + cdef extern from "arrow/compute/exec/expression.h" \ namespace "arrow::compute" nogil: @@ -57,7 +77,7 @@ cdef extern from "arrow/compute/exec/expression.h" \ cdef CResult[CExpression] CDeserializeExpression \ "arrow::compute::Deserialize"(shared_ptr[CBuffer]) - cdef CResult[unordered_map[CFieldRef, CDatum, CFieldRefHash]] \ 
+    cdef CResult[CKnownFieldValues] \
         CExtractKnownFieldValues "arrow::compute::ExtractKnownFieldValues"(
             const CExpression& partition_expression)
diff --git a/r/src/dataset.cpp b/r/src/dataset.cpp
index 24c1a1343ea..7bb1e639e05 100644
--- a/r/src/dataset.cpp
+++ b/r/src/dataset.cpp
@@ -70,9 +70,9 @@ const char* r6_class_name::get(
 // [[dataset::export]]
 std::shared_ptr<ds::ScannerBuilder> dataset___Dataset__NewScan(
     const std::shared_ptr<ds::Dataset>& ds) {
-  auto options = std::make_shared<ds::ScanOptions>();
-  options->pool = gc_memory_pool();
-  return ValueOrStop(ds->NewScan(std::move(options)));
+  auto builder = ValueOrStop(ds->NewScan());
+  StopIfNotOk(builder->Pool(gc_memory_pool()));
+  return builder;
 }
 
 // [[dataset::export]]